import torch
from cetaceo.data import BaseDataset
class TorchDataset(BaseDataset):
    r"""
    Dataset class for PyTorch tensors.

    Stores input data ``x`` and optional target data ``y`` and converts
    them to ``torch.Tensor`` on construction and again after every
    :meth:`scale_data` / :meth:`rescale_data` call.

    Args:
        x: Input data. Anything accepted by ``torch.tensor`` (e.g. a NumPy
            array or a nested list) is copied into a new tensor of
            ``dtype``. An object that is already a ``torch.Tensor`` is
            stored unchanged, and its dtype is NOT converted to ``dtype``.
        y (optional): Target data, converted the same way as ``x``. If
            ``None``, ``__getitem__`` returns only the input sample.
            Default: ``None``.
        isscaled (tuple, optional): Pair of flags stored as
            ``self._isscaled``; presumably whether ``x`` and ``y`` are
            already scaled (TODO confirm against ``BaseDataset``).
            Default: ``(False, False)``.
        x_scaler (object, optional): Scaler for the input data.
        y_scaler (object, optional): Scaler for the target data.
        dtype (torch.dtype, optional): dtype used when converting
            non-tensor data. Default: ``torch.float32``.
    """

    def __init__(
        self,
        x,
        y=None,
        isscaled=(False, False),
        x_scaler=None,
        y_scaler=None,
        dtype=torch.float32,
    ):
        super().__init__()
        self.x = x
        self.y = y
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        self._isscaled = isscaled
        # Must be set before conversion, which reads self.dtype.
        self.dtype = dtype
        self._convert_data_to_tensor()

    def scale_data(self, **kwargs):
        """Scale the data via ``BaseDataset.scale_data``, then ensure tensors.

        Args:
            **kwargs: Forwarded unchanged to ``BaseDataset.scale_data``.
        """
        super().scale_data(**kwargs)
        # Re-convert in case the parent left self.x / self.y as
        # non-tensor objects (e.g. raw scaler output).
        self._convert_data_to_tensor()

    def rescale_data(self, **kwargs):
        """Undo scaling via ``BaseDataset.rescale_data``, then ensure tensors.

        Args:
            **kwargs: Forwarded unchanged to ``BaseDataset.rescale_data``.
        """
        super().rescale_data(**kwargs)
        # Re-convert in case the parent left self.x / self.y as
        # non-tensor objects.
        self._convert_data_to_tensor()

    def _convert_data_to_tensor(self):
        """Convert ``self.x`` and, if present, ``self.y`` to tensors in place.

        Existing tensors are left untouched (dtype and device unchanged);
        any other object is copied into a new tensor of ``self.dtype``.
        """
        if not isinstance(self.x, torch.Tensor):
            self.x = torch.tensor(self.x, dtype=self.dtype)
        if self.y is not None and not isinstance(self.y, torch.Tensor):
            self.y = torch.tensor(self.y, dtype=self.dtype)

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Returns:
            tuple: ``(x[idx],)`` when there is no target data, otherwise
            ``(x[idx], y[idx])``.
        """
        if self.y is None:
            return (self.x[idx],)
        return (self.x[idx], self.y[idx])