from . import RegressionEvaluator
import numpy as np
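# [docs]  (appears to be a link marker left over from the rendered documentation page; not part of the code)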
class RegressionEvaluatorPlotter(RegressionEvaluator):
    r"""
    Evaluator class for regression tasks that also produces plots.

    Besides returning the evaluation metrics computed by the parent
    :class:`RegressionEvaluator`, it also invokes every given plotter so
    that each one can create its plot.

    Args:
        plots_path (str): The path to save the plot.
        plots_name (str): The name of the plot (default: ``None``).
        tolerance (float): Tolerance level to consider values close to zero
            for MRE calculation (default: ``1e-4``). Forwarded unchanged to
            the parent evaluator.
        plotters (list): List of plotters to be used. Each plotter must
            expose a ``plot(y_true, y_pred, plots_path, plots_name)`` method.
            When omitted (or ``None``), no plots are produced
            (default: ``None``).
    """

    def __init__(
        self,
        plots_path: str,
        plots_name: str = None,
        tolerance: float = 1e-4,
        plotters=None,
    ) -> None:
        super().__init__(tolerance)
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every evaluator constructed without an explicit ``plotters``.
        # A caller-supplied list is stored as-is (not copied).
        self.plotters = [] if plotters is None else plotters
        self.plots_path = plots_path
        self.plots_name = plots_name

    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray, x) -> dict:
        """
        Compute the regression metrics and run every configured plotter.

        Args:
            y_true (np.ndarray): Ground-truth target values.
            y_pred (np.ndarray): Predicted target values.
            x: Accepted as part of the call signature but not used by this
                method.

        Returns:
            dict: The metrics returned by the parent evaluator's
            ``__call__(y_true, y_pred)``.
        """
        metrics = super().__call__(y_true, y_pred)
        # Iterating an empty collection is a no-op, so no emptiness check
        # is needed before the loop.
        for plotter in self.plotters:
            plotter.plot(y_true, y_pred, self.plots_path, self.plots_name)
        return metrics