Source code for data.VTUReader

import pyvista as pv
import numpy as np
import torch

from cetaceo.data import TorchDataset


class VTUReader:
    """
    Class for reading VTU files and processing the data.

    The flow field data is stored in the cell centers, so the data is read from there.

    Args:
        mesh_file (str): The name of the VTU file to read.
    """

    def __init__(self, mesh_file):
        # Mesh as read by pyvista; ``get_groupID`` reads its cell data directly.
        self.mesh = pv.read(mesh_file)
        # Cell centers of the mesh; coordinates and field arrays are read from here.
        self.centers = self.mesh.cell_centers()

    def get_centers_coordinates(self):
        """
        Get the coordinates of the cell centers.

        Returns:
            np.ndarray: The coordinates of the cell centers.
        """
        return self.centers.points

    def get_cell_data(self, fields=None, all_fields=False):
        """
        Get the cell data from the mesh.

        The requested fields are stacked column-wise in the order given:
        a 1-D field contributes a single column, and a 2-D field contributes
        all of its columns.

        Args:
            fields (list, optional): Names of the fields to read. Default is
                ``None`` (no fields).
            all_fields (bool, optional): Whether to read all fields. When
                ``True``, ``fields`` is ignored. Default is ``False``.

        Returns:
            np.ndarray: The requested fields stacked column-wise.

        Raises:
            ValueError: If neither ``fields`` nor ``all_fields`` is provided,
                or if ``all_fields`` is requested but the mesh has no fields.
        """
        if all_fields:
            fields = self.get_data_names()
        elif fields is None or len(fields) == 0:
            # Raise explicitly instead of ``assert`` so the check survives ``python -O``.
            raise ValueError("Either fields or all_fields must be provided.")

        columns = []
        for field in fields:
            data = self.centers.cell_data[field]
            # Promote scalar fields to a single column so everything can be hstacked.
            if data.ndim == 1:
                data = data[:, np.newaxis]
            columns.append(data)

        if not columns:
            raise ValueError("No cell data fields found in the mesh.")
        return np.hstack(columns)

    def get_data_names(self):
        """
        Get the names of the fields in the mesh.

        Returns:
            list: The names of the fields in the mesh.
        """
        return self.centers.cell_data.keys()

    def get_groupID(self):
        """
        Get the CAD group ID for each cell.

        Returns:
            np.ndarray: The CAD group ID for each cell.
        """
        return self.mesh.cell_data["CADGroupID"]
class VTUDataset(TorchDataset):
    """
    Dataset class for VTU files.

    Inputs (``x``) are the cell-center coordinate columns selected by
    ``coordinates_idx``; targets (``y``) are the field columns selected by
    ``fields_idx``. The rows of all files are stacked together with
    ``np.vstack``.

    Args:
        mesh_files (List[str] or str): List of file names to read. A single
            file name (or path-like object) is also accepted.
        x_scaler (object, optional): Scaler for the input data. Default is `None`.
        y_scaler (object, optional): Scaler for the target data. Default is `None`.
        coordinates_idx (List[int], optional): Indexes of the coordinates to read.
            Default is `[0, 2]`.
        fields_idx (List[int], optional): Indexes of the field columns to read.
            Default is `[0, 1, 2, 4, 5, 6, 7]`.
        dtype (torch.dtype, optional): Data type forwarded to the parent dataset.
            Default is `torch.float32`.

    Raises:
        ValueError: If ``mesh_files`` is empty.
    """

    # Immutable defaults. Each instance gets its own fresh list, so mutating
    # ``self.coordinates_idx`` / ``self.fields_idx`` never leaks into other instances.
    _DEFAULT_COORDINATES_IDX = (0, 2)
    _DEFAULT_FIELDS_IDX = (0, 1, 2, 4, 5, 6, 7)

    def __init__(
        self,
        mesh_files,
        x_scaler=None,
        y_scaler=None,
        coordinates_idx=None,
        fields_idx=None,
        dtype=torch.float32,
    ):
        # A bare filename would otherwise be iterated character by character.
        if isinstance(mesh_files, (str, bytes)) or hasattr(mesh_files, "__fspath__"):
            mesh_files = [mesh_files]
        else:
            mesh_files = list(mesh_files)
        if not mesh_files:
            raise ValueError("At least one mesh file must be provided.")

        if coordinates_idx is None:
            coordinates_idx = list(self._DEFAULT_COORDINATES_IDX)
        if fields_idx is None:
            fields_idx = list(self._DEFAULT_FIELDS_IDX)

        self.mesh_files = mesh_files
        self.coordinates_idx = coordinates_idx
        self.fields_idx = fields_idx
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler

        # read data from the files
        x, y = self._load_all_meshes(mesh_files)
        # concatenate per-file arrays; conversion to tensors is left to TorchDataset
        super().__init__(
            np.vstack(x), np.vstack(y), x_scaler=x_scaler, y_scaler=y_scaler, dtype=dtype
        )

    def __len__(self):
        # NOTE(review): self.x / self.y are presumably set by TorchDataset.__init__.
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def load_mesh(self, file_name: str):
        """
        Load the mesh data from the file.

        Args:
            file_name (str): The name of the VTU file to read.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The selected cell-center coordinate
            columns and the selected field columns (numpy arrays; no tensor
            conversion happens here).
        """
        reader = VTUReader(file_name)
        # The first cell-data array is skipped by position.
        # NOTE(review): presumably a non-flow array such as "CADGroupID" — confirm.
        fields = reader.get_data_names()[1:]
        coordinates = reader.get_centers_coordinates()[:, self.coordinates_idx]
        field_data = reader.get_cell_data(fields=fields)[:, self.fields_idx]
        return coordinates, field_data

    def _load_all_meshes(self, files: list):
        """Load every file and return parallel lists of coordinate and field arrays."""
        x = []
        y = []
        for mesh_file in files:
            coordinates, field_data = self.load_mesh(mesh_file)
            x.append(coordinates)
            y.append(field_data)
        return x, y