import os

import numpy as np
import pyvista as pv
import torch

from cetaceo.data import TorchDataset
class VTUReader:
    """
    Class for reading VTU files and processing the data.

    The flow field data is stored in the cell centers, so the data is read from there.

    Args:
        mesh_file (str): The name of the VTU file to read.
    """

    def __init__(self, mesh_file):
        self.mesh = pv.read(mesh_file)
        # Dataset whose points sit at the cell centers of ``self.mesh``; the
        # coordinates and field arrays below are read from it.
        self.centers = self.mesh.cell_centers()

    def get_centers_coordinates(self):
        """
        Get the coordinates of the cell centers.

        Returns:
            np.ndarray: The coordinates of the cell centers.
        """
        return self.centers.points

    def get_cell_data(self, fields=None, all_fields=False):
        """
        Get the cell data from the mesh.

        Each requested field is read from ``self.centers.cell_data``. 1-D
        (scalar) fields are turned into a single column; other fields are
        used as-is. The columns of all requested fields are then stacked
        horizontally in the order the fields are requested.

        Args:
            fields (list, optional): Names of the fields to read. Any iterable
                of names is accepted, and a single string is treated as one
                field name. Default is ``None`` (no fields).
            all_fields (bool, optional): Whether to read all fields. When
                ``True``, ``fields`` is ignored and every name returned by
                :meth:`get_data_names` is read. Default is ``False``.

        Returns:
            np.ndarray: The requested fields stacked side by side, one row per
            cell center.

        Raises:
            ValueError: If neither ``fields`` nor ``all_fields`` is provided,
                or if ``all_fields`` is set but the mesh has no fields.
        """
        if all_fields:
            names = list(self.get_data_names())
        elif fields is None:
            names = []
        elif isinstance(fields, str):
            # A bare string is one field name, not a sequence of names.
            names = [fields]
        else:
            names = list(fields)

        if not names:
            if not all_fields:
                raise ValueError("Either fields or all_fields must be provided.")
            # np.hstack([]) would otherwise fail with an opaque message.
            raise ValueError("The mesh has no cell data fields to read.")

        columns = []
        for name in names:
            data = self.centers.cell_data[name]
            # Promote scalar fields to a column so they stack with vector fields.
            if data.ndim == 1:
                data = data[:, np.newaxis]
            columns.append(data)
        return np.hstack(columns)

    def get_data_names(self):
        """
        Get the names of the fields in the mesh.

        Returns:
            list: The names of the cell-data fields available on the cell centers.
        """
        return self.centers.cell_data.keys()

    def get_groupID(self):
        """
        Get the CAD group ID for each cell.

        Note that this reads from the original mesh (``self.mesh``), not from
        the cell-centers dataset.

        Returns:
            np.ndarray: The CAD group ID for each cell.
        """
        return self.mesh.cell_data["CADGroupID"]
class VTUDataset(TorchDataset):
    """
    Dataset class for VTU files.

    Inputs (``x``) are the selected cell-center coordinates and targets (``y``)
    are the selected cell-data columns. Both are read from every mesh file and
    stacked row-wise in file order.

    Args:
        mesh_files (List[str]): List of file names to read. A single file name
            (``str`` or path-like) is also accepted.
        x_scaler (object, optional): Scaler for the input data. Default is `None`.
        y_scaler (object, optional): Scaler for the target data. Default is `None`.
        coordinates_idx (List[int], optional): Indexes of the coordinates to read.
            Default is `[0, 2]`.
        fields_idx (List[int], optional): Indexes of the field columns to read.
            Default is `[0, 1, 2, 4, 5, 6, 7]`.
        dtype (torch.dtype, optional): Data type forwarded to ``TorchDataset``.
            Default is `torch.float32`.

    Raises:
        ValueError: If ``mesh_files`` is empty.
    """

    def __init__(
        self,
        mesh_files,
        x_scaler=None,
        y_scaler=None,
        coordinates_idx=None,
        fields_idx=None,
        dtype=torch.float32,
    ):
        # A lone path would otherwise be iterated character by character.
        if isinstance(mesh_files, (str, os.PathLike)):
            mesh_files = [mesh_files]
        # Build fresh default lists per instance; list literals as default
        # arguments would be shared (and mutable) across all instances.
        if coordinates_idx is None:
            coordinates_idx = [0, 2]
        if fields_idx is None:
            fields_idx = [0, 1, 2, 4, 5, 6, 7]

        self.mesh_files = mesh_files
        self.coordinates_idx = coordinates_idx
        self.fields_idx = fields_idx
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler

        # read data from the files
        x, y = self._load_all_meshes(mesh_files)
        if not x:
            # np.vstack([]) would otherwise fail with an opaque message.
            raise ValueError("At least one mesh file must be provided.")

        # convert to tensors
        super().__init__(
            np.vstack(x), np.vstack(y), x_scaler=x_scaler, y_scaler=y_scaler, dtype=dtype
        )

    # NOTE(review): ``self.x`` / ``self.y`` are presumably set by
    # ``TorchDataset.__init__`` — confirm against cetaceo.data.
    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def load_mesh(self, file_name: str):
        """
        Load the mesh data from the file.

        Args:
            file_name (str): The name of the VTU file to read.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The selected cell-center coordinates
            (columns ``self.coordinates_idx``) and the selected field columns
            (columns ``self.fields_idx``).
        """
        reader = VTUReader(file_name)
        # The first cell-data array is skipped.
        # NOTE(review): presumably a non-physical array (e.g. an ID field);
        # confirm against the VTU files in use.
        field_names = list(reader.get_data_names())[1:]
        coordinates = reader.get_centers_coordinates()[:, self.coordinates_idx]
        fields = reader.get_cell_data(fields=field_names)[:, self.fields_idx]
        return coordinates, fields

    def _load_all_meshes(self, files: list):
        """Load every file; return (list of coordinate arrays, list of field arrays)."""
        loaded = [self.load_mesh(mesh_file) for mesh_file in files]
        x = [coordinates for coordinates, _ in loaded]
        y = [fields for _, fields in loaded]
        return x, y