import h5py
import torch
from cetaceo.data import TorchDataset


class HDF5Reader:
    r"""
    Class for reading HDF5 files. The data is stored in groups and datasets:
    the inputs are stored in the "inputs" group and the outputs are stored in
    the "outputs" group.

    Args:
        file_path (str): The path to the HDF5 file to read.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self._load_from_hdf5(file_path)

    def _load_from_hdf5(self, file_path):
        # Expose every top-level group of the file as an attribute of the reader.
        with h5py.File(file_path, "r") as f:
            for key in f.keys():
                setattr(self, key, self._create_class_for_group(f[key]))

    def _create_class_for_group(self, group):
        # Mirror an HDF5 group as a dynamically created class: subgroups become
        # nested classes and datasets become attributes holding the loaded data.
        class_name = group.name.split("/")[-1]
        group_class = type(class_name, (), {})
        dataset_count = 0
        for key in group.keys():
            if isinstance(group[key], h5py.Group):
                setattr(group_class, key, self._create_class_for_group(group[key]))
            elif isinstance(group[key], h5py.Dataset):
                data = group[key][()]  # read the dataset fully into memory
                setattr(group_class, key, data)
                dataset_count += 1
        setattr(group_class, "_dataset_count", dataset_count)
        return group_class
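
    # For a file laid out as /inputs/velocity and /inputs/pressure, the
    # mirroring above makes the raw arrays reachable as attributes. A sketch
    # with hypothetical file and dataset names:
    #
    #   reader = HDF5Reader("flow_data.h5")
    #   reader.inputs.velocity        # numpy array read from /inputs/velocity
    #   reader.inputs._dataset_count  # -> 2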

    def load_data_to_tensor(self, group_name: str) -> torch.Tensor:
        """
        Load the datasets of a group into a single stacked tensor.

        Args:
            group_name (str): The name of the group to load.

        Returns:
            tensor (torch.Tensor): The tensor with the data from the group, or
            `None` if the group does not exist or contains no datasets.
        """
        if hasattr(self, group_name):
            tensors = []
            # Reopen the file in a context manager so the handle is always
            # closed once the datasets have been read.
            with h5py.File(self.file_path, "r") as f:
                group = f[group_name]
                for key in group.keys():
                    if isinstance(group[key], h5py.Dataset):
                        tensors.append(torch.tensor(group[key][:]).float())
            if tensors:
                # torch.stack requires all datasets to share the same shape.
                return torch.stack(tensors, dim=0)
            print(f"No datasets were found in group '{group_name}'.")
            return None
        print(f"The group '{group_name}' does not exist.")
        return None

    def count_datasets(self, group_name: str) -> int:
        """
        Count the number of datasets in a group, including its subgroups.

        Args:
            group_name (str): The name of the group whose datasets are counted.

        Returns:
            count (int): The number of datasets in the group.
        """
        if hasattr(self, group_name):
            group = getattr(self, group_name)
            return self._recursive_count_datasets(group)
        print(f"The group '{group_name}' does not exist.")
        return 0

    def _recursive_count_datasets(self, group):
        # Each mirrored group class stores its own dataset count; nested group
        # classes (created by _create_class_for_group) are visited recursively.
        count = getattr(group, "_dataset_count", 0)
        for value in group.__dict__.values():
            if isinstance(value, type):
                count += self._recursive_count_datasets(value)
        return count
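

# A minimal usage sketch for HDF5Reader, assuming a file "flow_data.h5" whose
# top level holds "inputs" and "outputs" groups (names are illustrative only):
#
#   reader = HDF5Reader("flow_data.h5")
#   inputs = reader.load_data_to_tensor("inputs")  # (n_datasets, ...) float tensor
#   n_inputs = reader.count_datasets("inputs")     # datasets, subgroups included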


class HDF5Dataset(TorchDataset):
    r"""
    Dataset class for HDF5 files.

    Args:
        src_file (str): The path to the HDF5 file to read.
        x_scaler (object, optional): Scaler for the input data. Default is `None`.
        y_scaler (object, optional): Scaler for the target data. Default is `None`.
    """

    def __init__(
        self,
        src_file,
        x_scaler=None,
        y_scaler=None,
    ):
        self.reader = HDF5Reader(src_file)
        self.x = self._load_to_tensor("inputs")
        self.y = self._load_to_tensor("outputs")
        super().__init__(self.x, self.y, x_scaler=x_scaler, y_scaler=y_scaler)
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        self.inputs_size = self.reader.count_datasets("inputs")
        self.outputs_size = self.reader.count_datasets("outputs")

    def _load_to_tensor(self, group_name):
        tensor = self.reader.load_data_to_tensor(group_name)
        if tensor is None:
            raise ValueError(
                f"Could not load group '{group_name}' from '{self.reader.file_path}'."
            )
        # Reshape from (n_datasets, n_samples) to (n_samples, n_features).
        tensor = tensor.squeeze()
        if tensor.ndim > 1:
            return tensor.T
        elif tensor.ndim == 1:
            # A single dataset becomes a column vector.
            return tensor.unsqueeze(1)
        else:
            return tensor

    # def scale_data(self, **kwargs):
    #     super().scale_data(**kwargs)
    #     self.x = torch.tensor(self.x, dtype=torch.float32)
    #     self.y = torch.tensor(self.y, dtype=torch.float32)

    # def rescale_data(self, **kwargs):
    #     super().rescale_data(**kwargs)
    #     self.x = torch.tensor(self.x, dtype=torch.float32)
    #     self.y = torch.tensor(self.y, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return (self.x[idx], self.y[idx])
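

# A minimal end-to-end sketch: build a toy HDF5 file with the expected
# "inputs"/"outputs" layout and load it through HDF5Dataset. The file, group,
# and dataset names here are assumptions for illustration, and the DataLoader
# usage assumes cetaceo's TorchDataset behaves like a standard torch Dataset.
if __name__ == "__main__":
    import numpy as np
    from torch.utils.data import DataLoader

    # Hypothetical toy file: two input datasets and one output dataset.
    with h5py.File("toy.h5", "w") as f:
        inputs = f.create_group("inputs")
        # Datasets within a group must share the same shape to be stackable.
        inputs.create_dataset("x1", data=np.linspace(0.0, 1.0, 8))
        inputs.create_dataset("x2", data=np.linspace(1.0, 2.0, 8))
        outputs = f.create_group("outputs")
        outputs.create_dataset("y", data=np.random.rand(8))

    dataset = HDF5Dataset("toy.h5")
    print(len(dataset), dataset.x.shape, dataset.y.shape)  # 8, (8, 2), (8, 1)

    for xb, yb in DataLoader(dataset, batch_size=4):
        print(xb.shape, yb.shape)  # (4, 2), (4, 1)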