Source code for data.TorchReader

import torch

from cetaceo.data import BaseDataset

class TorchDataset(BaseDataset):
    r"""
    Dataset that exposes inputs (and optionally targets) as PyTorch tensors.

    Data that is not already a ``torch.Tensor`` is converted with
    ``torch.tensor(..., dtype=dtype)``. Data that is already a tensor is kept
    as-is, so its existing dtype is NOT changed to ``dtype``.

    Args:
        x: Input data. Converted to a tensor of ``dtype`` unless it is already
            a ``torch.Tensor``.
        y (optional): Target data, converted like ``x``. When ``None``, items
            are returned as 1-tuples ``(x[idx],)``.
        isscaled (tuple, optional): Stored as ``self._isscaled``. Presumably
            ``(x_is_scaled, y_is_scaled)`` flags consumed by ``BaseDataset``;
            TODO confirm. Default: ``(False, False)``.
        x_scaler (object, optional): Scaler for the input data.
        y_scaler (object, optional): Scaler for the target data.
        dtype (torch.dtype, optional): dtype used when converting non-tensor
            data. Default: ``torch.float32``.
    """

    def __init__(
        self,
        x,
        y=None,
        isscaled=(False, False),
        x_scaler=None,
        y_scaler=None,
        dtype=torch.float32,
    ):
        super().__init__()
        self.x, self.y = x, y
        self.x_scaler, self.y_scaler = x_scaler, y_scaler
        self._isscaled = isscaled
        self.dtype = dtype
        # Turn x / y into tensors immediately so __getitem__ indexes tensors.
        self._convert_data_to_tensor()

    def scale_data(self, **kwargs):
        """Scale the data via the parent class, then restore tensor form.

        Keyword arguments are forwarded unchanged to ``BaseDataset.scale_data``.
        The re-conversion afterwards covers the case where the parent replaces
        ``x``/``y`` with non-tensor arrays (assumption; depends on the scalers).
        """
        super().scale_data(**kwargs)
        self._convert_data_to_tensor()

    def rescale_data(self, **kwargs):
        """Undo scaling via the parent class, then restore tensor form.

        Keyword arguments are forwarded unchanged to
        ``BaseDataset.rescale_data``. The re-conversion afterwards covers the
        case where the parent replaces ``x``/``y`` with non-tensor arrays
        (assumption; depends on the scalers).
        """
        super().rescale_data(**kwargs)
        self._convert_data_to_tensor()

    def _convert_data_to_tensor(self):
        """Convert ``x`` (and ``y`` when it is not ``None``) to tensors.

        Non-tensor values become ``torch.tensor(value, dtype=self.dtype)``.
        Values that are already ``torch.Tensor`` instances are left untouched.
        """
        # y is optional and only converted when present; x is always converted.
        attrs = ("x",) if self.y is None else ("x", "y")
        for attr in attrs:
            current = getattr(self, attr)
            if not isinstance(current, torch.Tensor):
                setattr(self, attr, torch.tensor(current, dtype=self.dtype))

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(x[idx],)``, or ``(x[idx], y[idx])`` when targets exist."""
        sample = (self.x[idx],)
        if self.y is not None:
            sample += (self.y[idx],)
        return sample