# Source code for utils.zone_metrics

import torch
import json
from typing import List, Dict
import matplotlib.pyplot as plt

from cetaceo.data import BaseDataset, TorchDataset
from cetaceo.models import Model

def _evaluate_zone_groups(model, dataset, groups, evaluators, n_ipt, n_cond):
    """Run ``model`` on every zone in ``groups`` and collect its metrics.

    Args:
        model: Object exposing ``predict(dataset, rescale_output=True, return_targets=True)``.
        dataset: Source dataset; only its ``x_scaler``, ``y_scaler`` and ``isscaled`` are read.
        groups: Tensors whose first ``n_ipt`` columns are inputs and the rest targets.
        evaluators: Callables ``evaluator(y_true, y_pred)`` returning a dict of metrics.
        n_ipt: Number of input columns in each group tensor.
        n_cond: Number of leading input columns that identify the zone.

    Returns:
        List[Dict]: One metrics dict per group, each with an extra ``"flight_condition"`` entry.
    """
    results = []
    for group in groups:
        x, y = group[:, :n_ipt], group[:, n_ipt:]
        group_dataset = TorchDataset(x, y, dataset.x_scaler, dataset.y_scaler, dataset.isscaled)
        y_pred, y_true = model.predict(group_dataset, rescale_output=True, return_targets=True)

        metrics = {}
        for evaluator in evaluators:
            metrics.update(evaluator(y_true, y_pred))

        # The zone identifier is read after rescale_data(), i.e. from the
        # dataset's inputs once they have been rescaled.
        group_dataset.rescale_data()
        metrics["flight_condition"] = group_dataset.x[0, :n_cond].tolist()
        results.append(metrics)
    return results


def zone_metrics_calculation(
    model: Model,
    dataset: BaseDataset,
    evaluators: List = None,
    savedir: str = None,
):
    r"""
    Evaluate the model separately on each evaluation zone of a dataset.

    Samples are grouped by the unique combinations of their leading input
    columns (all inputs except the last three). The groups with the largest
    sample count are treated as flight-condition interpolation zones and the
    groups with the second-largest sample count as geometric interpolation
    zones; groups of any other size are ignored.

    Args:
        model (Model): The model to evaluate. Its ``predict`` method is called once per zone.
        dataset (BaseDataset): The dataset to evaluate the model on.
        evaluators (List, optional): Callables ``evaluator(y_true, y_pred)`` returning a dict
            of metrics. Default is ``None``, which is treated as an empty list.
        savedir (str, optional): Directory where ``metrics_flight.json`` and
            ``metrics_coord.json`` are written. It is created if missing.
            Default is ``None`` (nothing is saved).

    Returns:
        Tuple[List[Dict], List[Dict]]: ``(metrics_flight, metrics_coord)``, one dict per zone.
        Each dict holds the evaluator outputs plus a ``"flight_condition"`` entry.

    Raises:
        ValueError: If the dataset yields fewer than two distinct zone sizes, so the
            two kinds of zones cannot be told apart.
    """
    import os

    # A shared mutable default is avoided; None stands for "no evaluators".
    evaluators = [] if evaluators is None else evaluators

    print("\nComputing evaluation zones...")
    print(" Original dataset shape:", dataset.x.shape, dataset.y.shape)

    n_ipt = dataset.x.shape[1]
    # NOTE(review): the last three input columns are presumably the point
    # coordinates, the preceding ones the flight condition -- confirm.
    n_cond = n_ipt - 3

    tensor = torch.cat((dataset.x, dataset.y), dim=1)
    unique_combinations, inverse_indices = torch.unique(
        tensor[:, :n_cond], dim=0, return_inverse=True
    )

    # Bucket every zone by its number of samples.
    groups_by_size = {}
    for i in range(unique_combinations.size(0)):
        group = tensor[inverse_indices == i]
        groups_by_size.setdefault(group.size(0), []).append(group)

    sizes = sorted(groups_by_size, reverse=True)
    if len(sizes) < 2:
        raise ValueError(
            "Expected at least two distinct evaluation zone sizes "
            f"(flight conditions and geometric interpolation), found {len(sizes)}."
        )
    groups_flight = groups_by_size[sizes[0]]
    groups_coord = groups_by_size[sizes[1]]
    print(" Number and size of flight conditions evaluation zones:", len(groups_flight), ", ", sizes[0])
    print(" Number and size of geometric interpolation evaluation zones:", len(groups_coord), ", ", sizes[1])

    print("\nComputing performance in flight conditions interpolation...")
    metrics_flight = _evaluate_zone_groups(model, dataset, groups_flight, evaluators, n_ipt, n_cond)

    print("\nComputing performance in geometric interpolation...")
    metrics_coord = _evaluate_zone_groups(model, dataset, groups_coord, evaluators, n_ipt, n_cond)

    if savedir is not None:
        print("\nSaving results as .json files...")
        os.makedirs(savedir, exist_ok=True)
        with open(os.path.join(savedir, "metrics_flight.json"), "w", encoding="utf-8") as f:
            json.dump(metrics_flight, f)
        with open(os.path.join(savedir, "metrics_coord.json"), "w", encoding="utf-8") as f:
            json.dump(metrics_coord, f)

    return metrics_flight, metrics_coord
def _generate_zone_plot(records, config):
    """Build one figure with a subplot per metric for a list of zone metrics.

    Args:
        records: List of metric dicts. Each must contain ``"flight_condition"``
            (at least three values, plotted as Mach, angle of attack and
            altitude according to the axis labels) and every metric name
            listed in ``config["metrics"]``.
        config: Dict with keys ``"metrics"`` (list of ``{"name", "label"}``
            dicts), ``"group_label"`` (used in titles) and ``"3Dplot"``
            (``True`` for 3D scatter plots coloured by the metric, ``False``
            for 2D metric-vs-angle-of-attack plots).

    Returns:
        The created matplotlib figure.
    """
    markers = ["o", "s", "D", "v", "^", "<", ">", "p", "P", "*", "X", "d",
               "1", "2", "3", "4", "8", "h", "H", "+", "x", "|", "_"]
    # White is left out of the palette: it is invisible on the default background.
    colors = ["b", "g", "r", "c", "m", "y", "k"]

    n_fig = len(config["metrics"])
    n_row = 2
    # Ceiling division so an odd number of metrics still gets enough slots.
    n_col = max(1, -(-n_fig // n_row))

    # First-seen order of each Mach / altitude value fixes its marker / colour,
    # and the mapping is shared by all subplots of the figure.
    mach_index = {}
    alt_index = {}

    fig = plt.figure(figsize=(8 * n_col, 8 * n_row))
    if not config["3Dplot"]:
        fig.subplots_adjust(left=0.075, right=0.925, top=0.95, bottom=0.05, wspace=0.3, hspace=0.3)
        plt.tight_layout()

    x = [d["flight_condition"][0] for d in records]
    y = [d["flight_condition"][1] for d in records]
    z = [d["flight_condition"][2] for d in records]

    for i, metric in enumerate(config["metrics"]):
        values = [d[metric["name"]] for d in records]
        title = metric["label"] + " by " + config["group_label"]

        if config["3Dplot"]:
            ax = fig.add_subplot(n_row, n_col, i + 1, projection='3d')
            scatter = ax.scatter(x, y, z, c=values, cmap='viridis', s=100)
            cbar = plt.colorbar(scatter, ax=ax, shrink=0.5, aspect=10)
            cbar.set_label(metric["label"])
            ax.set_xlabel('Mach [-]')
            ax.set_ylabel('Angle of Attack')
            ax.set_zlabel('Altitude [FL]')
            ax.set_title(title)
            continue

        ax = fig.add_subplot(n_row, n_col, i + 1)
        for xi, yi, zi, value in zip(x, y, z, values):
            # setdefault assigns the next free index on first sight; the
            # modulo cycles the palettes instead of running off their end.
            marker = markers[mach_index.setdefault(xi, len(mach_index)) % len(markers)]
            color = colors[alt_index.setdefault(zi, len(alt_index)) % len(colors)]
            ax.scatter(yi, y=value, marker=marker, color=color)
        ax.set_xlabel('Angle of Attack [deg]')
        ax.set_ylabel(metric["label"])
        ax.set_title(title)

        # Empty scatters serve as legend handles; entries are ordered by
        # numeric value so that e.g. FL 90 comes before FL 100.
        mach_handles = [
            ax.scatter([], [], marker=markers[idx % len(markers)], color='k',
                       label='Mach ' + str(round(mach, 1)))
            for mach, idx in sorted(mach_index.items())
        ]
        alt_handles = [
            ax.scatter([], [], marker='o', color=colors[idx % len(colors)],
                       label='FL ' + str(int(alt)))
            for alt, idx in sorted(alt_index.items())
        ]
        ax.legend(handles=mach_handles + alt_handles)

    if config["3Dplot"]:
        fig.subplots_adjust(left=0.075, right=0.925, top=0.95, bottom=0.05, wspace=0.3, hspace=0.3)
        plt.tight_layout()
    return fig


def zone_metrics_plotter(
    metrics_flight: List[Dict],
    metrics_coord: List[Dict],
    savedir: str = None,
):
    r"""
    Plot the per-zone metrics produced for both kinds of evaluation zones.

    For each of the two metric lists, two figures are generated: a 3D scatter
    view and a 2D view against angle of attack. Each figure has one subplot
    per metric (``mre``, ``mae``, ``mse``, ``r2``, ``ae_95``, ``ae_99``), so
    every metrics dict must contain those keys plus ``"flight_condition"``.

    Args:
        metrics_flight (List[Dict]): Metrics of the flight-condition interpolation zones.
        metrics_coord (List[Dict]): Metrics of the geometric interpolation zones.
        savedir (str, optional): Base directory; figures are written as PNG files into
            its ``plots`` subdirectory, which is created if missing. Default is ``None``,
            in which case the figures are created but not saved.

    Returns:
        None. The figures are not closed, so they stay registered with pyplot.
    """
    import os

    metric_specs = [
        {"name": "mre", "label": "Mean Relative Error (MRE)"},
        {"name": "mae", "label": "Mean Absolute Error (MAE)"},
        {"name": "mse", "label": "Mean Squared Error (MSE)"},
        {"name": "r2", "label": "R2 Score"},
        {"name": "ae_95", "label": "Absolute Error Quantile 95%"},
        {"name": "ae_99", "label": "Absolute Error Quantile 99%"},
    ]

    plots_dir = None
    if savedir is not None:
        plots_dir = os.path.join(savedir, "plots")
        os.makedirs(plots_dir, exist_ok=True)

    zones = [
        ("flight_cond", "flight conditions interpolation", metrics_flight),
        ("coord_interp", "geometric interpolation", metrics_coord),
    ]
    for group_name, description, records in zones:
        print(" Plotting results for " + description + "...")
        for is_3d, suffix in ((True, ""), (False, "_alpha")):
            config = {
                "group_name": group_name,
                "group_label": "Flight Conditions",
                "metrics": metric_specs,
                "3Dplot": is_3d,
            }
            fig = _generate_zone_plot(records, config)
            if plots_dir is not None:
                fig.savefig(os.path.join(plots_dir, group_name + "_metrics" + suffix + ".png"))