Source code for plotting.true_vs_pred_plot

import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

from . import Plotter

class TrueVsPredPlotter(Plotter):
    r"""Plot true vs. predicted values as one scatter plot per output component.

    Each column of the target arrays is drawn in its own subplot (stacked
    vertically) and the resulting figure is saved to disk as an image.
    """

    def plot(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        path: str,
        name: Optional[str] = None,
    ) -> None:
        """Save a true-vs-predicted scatter plot for every output component.

        Args:
            y_true: Ground-truth values of shape ``(n_samples, n_components)``.
                A 1-D array is treated as a single component.
            y_pred: Predicted values; must have the same shape as ``y_true``
                (a 1-D array is likewise treated as a single component).
            path: Directory the image file is written into.
            name: File name of the image, including its extension. Defaults
                to ``"scatterplot.png"`` when ``None``.

        Raises:
            ValueError: If ``y_true`` and ``y_pred`` do not have the same shape.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        # Single-output models often yield 1-D arrays; promote them to one
        # column so the per-component loop below works uniformly.
        if y_true.ndim == 1:
            y_true = y_true.reshape(-1, 1)
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, "
                f"got {y_true.shape} and {y_pred.shape}"
            )

        num_plots = y_true.shape[1]
        # squeeze=False keeps ``axes`` 2-D even when there is a single subplot,
        # so it can always be indexed as axes[:, 0].
        fig, axes = plt.subplots(
            num_plots, 1, figsize=(10, 5 * num_plots), squeeze=False
        )
        try:
            for j, ax in enumerate(axes[:, 0]):
                ax.scatter(y_true[:, j], y_pred[:, j], s=1, c="b", alpha=0.5)
                ax.set_xlabel("True values")
                ax.set_ylabel("Predicted values")
                ax.set_title(f"Scatterplot for Component {j+1}")
                ax.grid(True)
            fig.tight_layout()

            file_name = name if name is not None else "scatterplot.png"
            file_path = os.path.join(str(path), file_name)
            fig.savefig(file_path, dpi=300)
        finally:
            # pyplot keeps every created figure alive until it is closed;
            # without this, repeated calls leak memory.
            plt.close(fig)