import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

from . import Plotter
class TrueVsPredPlotter(Plotter):
    r"""
    Plotter that draws true-vs-predicted scatter plots, one subplot per output
    component, and saves them all to a single image file.
    """

    def plot(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        path: str,
        name: Optional[str] = None,
    ) -> None:
        """Save a true-vs-predicted scatter plot for every output component.

        Component ``j`` is taken from column ``j`` of each array and drawn in
        its own subplot. The subplots are stacked vertically, so the figure
        height grows by 5 inches per component.

        Args:
            y_true: Ground-truth values. A 2-D array is read column-wise, one
                column per component; a 0-D or 1-D array is treated as a
                single component.
            y_pred: Predicted values, with the same shape as ``y_true``
                (after 0-D/1-D inputs are promoted to a single column).
            path: Directory the image is written to. It is not created here.
            name: File name of the image inside ``path``. Defaults to
                ``"scatterplot.png"``.

        Raises:
            ValueError: If ``y_true`` and ``y_pred`` have different shapes, or
                if they contain no components to plot.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        # A single-output target has no column axis; add one so the
        # per-column indexing below works uniformly.
        if y_true.ndim < 2:
            y_true = y_true.reshape(-1, 1)
        if y_pred.ndim < 2:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                "y_true and y_pred must have the same shape, "
                f"got {y_true.shape} and {y_pred.shape}"
            )
        num_plots = y_true.shape[1]
        if num_plots == 0:
            raise ValueError("y_true and y_pred contain no components to plot")

        file_path = os.path.join(
            str(path), name if name is not None else "scatterplot.png"
        )

        fig = plt.figure(figsize=(10, 5 * num_plots))
        try:
            for j in range(num_plots):
                # Work on explicit Axes rather than pyplot's "current" state.
                ax = fig.add_subplot(num_plots, 1, j + 1)
                ax.scatter(y_true[:, j], y_pred[:, j], s=1, c="b", alpha=0.5)
                ax.set_xlabel("True values")
                ax.set_ylabel("Predicted values")
                ax.set_title(f"Scatterplot for Component {j+1}")
                ax.grid(True)
            fig.tight_layout()
            fig.savefig(file_path, dpi=300)
        finally:
            # pyplot keeps every figure it creates alive until it is closed;
            # without this, repeated calls (e.g. once per epoch) accumulate
            # open figures and their memory.
            plt.close(fig)