from torch.utils.data import Dataset
import numpy as np
from sklearn.utils.validation import check_is_fitted
# [docs]  (appears to be Sphinx "view source" link residue; not executable code)
class BaseDataset(Dataset):
    r"""
    Base class for datasets.

    Stores input data ``x``, target data ``y`` and one optional scaler for
    each of them, and tracks whether each array is currently scaled.

    The scalers are used through ``transform``, ``fit_transform`` and
    ``inverse_transform`` (the scikit-learn transformer interface).
    ``__init__`` sets every attribute to ``None``; presumably subclasses
    populate them -- confirm against the concrete dataset classes.

    Attributes:
        x: Input data (``None`` until set).
        y: Target data (``None`` until set).
        x_scaler: Scaler applied to ``x`` or ``None`` for "no scaling".
        y_scaler: Scaler applied to ``y`` or ``None`` for "no scaling".
    """

    def __init__(self) -> None:
        super().__init__()
        self.x = None
        self.y = None
        self.x_scaler = None
        self.y_scaler = None
        # (x is scaled, y is scaled)
        self._isscaled = (False, False)

    @property
    def isscaled(self):
        r"""
        Tuple ``(x_scaled, y_scaled)`` telling whether ``x`` and ``y`` are
        currently stored in scaled form.
        """
        return self._isscaled

    def process_data(
        self,
        process_function: callable,
    ) -> None:
        r"""
        Process the data using the provided function.

        The `x` and `y` attributes of the dataset are updated with the transformed data.

        Args:
            process_function (callable): A function that takes in `x` and `y` as input and returns transformed `x` and `y`.
        """
        self.x, self.y = process_function(self.x, self.y)

    @staticmethod
    def _fit_or_transform(scaler, data: np.ndarray) -> np.ndarray:
        r"""
        Transform ``data`` with ``scaler``, fitting the scaler first if needed.

        Args:
            scaler: Object exposing ``transform`` and ``fit_transform``.
            data (np.ndarray): Data to scale.

        Returns:
            np.ndarray: The scaled data as returned by the scaler.
        """
        # Local import: keeps this block self-contained without touching the
        # file-level import section (scikit-learn is already a dependency).
        from sklearn.exceptions import NotFittedError

        try:
            check_is_fitted(scaler)
        except (NotFittedError, TypeError):
            # NotFittedError: scaler has not been fitted yet.
            # TypeError: check_is_fitted does not recognise the object as an
            # estimator; treat such duck-typed scalers as unfitted, as before.
            return scaler.fit_transform(data)
        # Deliberately outside the try: an error raised by transform() on a
        # fitted scaler must propagate instead of silently refitting it.
        return scaler.transform(data)

    def scale_data(
        self,
        scale_x: bool = True,
        scale_y: bool = True,
    ) -> None:
        r"""
        Transform the data using the provided scalers.

        If the scalers are fitted, the data will be transformed.
        If not, the scalers will be fitted to the data and then the data will be transformed.
        Data that is already scaled, or whose scaler is ``None``, is left
        untouched (mirroring :meth:`rescale_data`).

        Args:
            scale_x (bool): Whether to scale the input data (default: ``True``).
            scale_y (bool): Whether to scale the target data (default: ``True``).
        """
        x_scaled, y_scaled = self._isscaled
        if scale_x and not x_scaled and self.x_scaler is not None:
            self.x = np.array(self.x)
            self.x = self._fit_or_transform(self.x_scaler, self.x)
            x_scaled = True
        if scale_y and not y_scaled and self.y_scaler is not None:
            self.y = np.array(self.y)
            self.y = self._fit_or_transform(self.y_scaler, self.y)
            y_scaled = True
        self._isscaled = (x_scaled, y_scaled)

    def rescale_data(
        self,
        rescale_x: bool = True,
        rescale_y: bool = True,
    ) -> None:
        r"""
        Reverse the scaling of the data using the provided scalers.

        This function should be used only if the data has been scaled.
        The `x` and `y` attributes of the dataset are updated with the unscaled data.
        Data that is not currently scaled, or whose scaler is ``None``, is
        left untouched.

        Args:
            rescale_x (bool): Whether to unscale the input data (default: ``True``).
            rescale_y (bool): Whether to unscale the target data (default: ``True``).
        """
        x_scaled, y_scaled = self._isscaled
        if rescale_x and x_scaled and self.x_scaler is not None:
            self.x = np.array(self.x)
            self.x = self.x_scaler.inverse_transform(self.x)
            x_scaled = False
        if rescale_y and y_scaled and self.y_scaler is not None:
            self.y = np.array(self.y)
            self.y = self.y_scaler.inverse_transform(self.y)
            y_scaled = False
        self._isscaled = (x_scaled, y_scaled)

    def rescale_x(self, x: np.ndarray) -> np.ndarray:
        r"""
        Rescale the input data using the scaler.

        Does not modify the dataset's own ``x`` or its scaled flags.

        Args:
            x (np.ndarray): The input data to be rescaled.

        Returns:
            np.ndarray: The rescaled input data, or ``x`` unchanged when no
            input scaler is set.
        """
        if self.x_scaler is not None:
            return self.x_scaler.inverse_transform(x)
        return x

    def rescale_y(self, y: np.ndarray) -> np.ndarray:
        r"""
        Rescale the target data using the scaler.

        Does not modify the dataset's own ``y`` or its scaled flags.

        Args:
            y (np.ndarray): The target data to be rescaled.

        Returns:
            np.ndarray: The rescaled target data, or ``y`` unchanged when no
            target scaler is set.
        """
        if self.y_scaler is not None:
            return self.y_scaler.inverse_transform(y)
        return y