import abc
from typing import Tuple, Dict, Optional
from cetaceo.data import BaseDataset
from cetaceo.optimization import OptunaOptimizer
class Model(abc.ABC):
    """
    Abstract base class for models trained and evaluated on ``BaseDataset`` objects.

    Concrete subclasses must implement ``fit``, ``predict``, ``save`` and
    ``load``. ``create_optimized_model`` is optional: the base implementation
    only raises ``NotImplementedError``.
    """

    @abc.abstractmethod
    def fit(
        self,
        train_dataset: BaseDataset,
        eval_set: Optional[BaseDataset] = None,
        **kwargs,
    ):
        """
        Fit the model to the training data.

        Args:
            train_dataset (BaseDataset): The training dataset.
            eval_set (Optional[BaseDataset]): The evaluation dataset, or ``None``
                when no evaluation data is supplied. Default is ``None``.
            **kwargs: Additional parameters for the fit method.
        """
        # The default is a real ``None`` (not the ``Optional[...]`` typing
        # object) so implementations can reliably test ``eval_set is None``.

    @abc.abstractmethod
    def predict(self, X: BaseDataset, rescale_output: bool = True, **kwargs):
        """
        Predict the target values for the input data.

        Args:
            X (BaseDataset): The input data.
            rescale_output (bool, optional): Whether to rescale the output data.
                Default is ``True``.
            **kwargs: Additional parameters for the predict method.

        Returns:
            np.array: The predicted target values.
        """

    @classmethod
    def create_optimized_model(
        cls,
        train_dataset: BaseDataset,
        eval_dataset: Optional[BaseDataset],
        optuna_optimizer: OptunaOptimizer,
    ) -> Tuple["Model", Dict]:
        """
        Create an optimized model using Optuna.

        This base implementation is a placeholder; subclasses that support
        hyperparameter optimization override it.

        Args:
            train_dataset (BaseDataset): The training dataset.
            eval_dataset (Optional[BaseDataset]): The evaluation dataset.
            optuna_optimizer (OptunaOptimizer): The optimizer to use for optimization.

        Returns:
            Tuple[Model, Dict]: The optimized model and the best parameters found
            by the optimizer.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("create_optimized_model not implemented")

    @abc.abstractmethod
    def save(self, path: str):
        """
        Save the model to a file.

        Args:
            path (str): The path to save the model.
        """

    @classmethod
    @abc.abstractmethod
    def load(cls, path: str) -> "Model":
        """
        Load a model from a file.

        Args:
            path (str): The path to load the model from.

        Returns:
            Model: The loaded model.
        """