import torch
import json
from typing import List, Dict
import matplotlib.pyplot as plt
from cetaceo.data import BaseDataset, TorchDataset
from cetaceo.models import Model
[docs]
def zone_metrics_calculation(
    model: Model,
    dataset: BaseDataset,
    evaluators: List = None,
    savedir: str = None,
):
    r"""
    Evaluate the model separately on every evaluation zone found in a dataset.

    The rows of ``dataset.x`` are grouped by the values of all input columns except
    the last three. Groups are then bucketed by their number of rows: the groups with
    the largest row count are reported as "flight conditions" zones and the groups
    with the second largest row count as "geometric interpolation" zones. Groups of
    any other size are ignored. If the dataset yields fewer than two distinct group
    sizes, the missing zone kind is simply reported as an empty list.

    Args:
        model (Model): The model to evaluate. ``model.predict`` is called with
            ``rescale_output=True`` and ``return_targets=True`` and must return
            ``(y_pred, y_true)``.
        dataset (BaseDataset): The dataset to evaluate the model on. Its ``x``, ``y``,
            ``x_scaler``, ``y_scaler`` and ``isscaled`` attributes are read.
        evaluators (List, optional): Callables invoked as ``evaluator(y_true, y_pred)``;
            each must return a dict that is merged into the metrics of the zone.
            Default is ``None``, which is treated as an empty list.
        savedir (str, optional): Directory where ``metrics_flight.json`` and
            ``metrics_coord.json`` are written. It is created if it does not exist.
            Default is ``None`` (nothing is written).
    Returns:
        metrics (Tuple[List[Dict], List[Dict]]): ``(metrics_flight, metrics_coord)``,
        each holding one dict per zone with the evaluator outputs plus a
        ``"flight_condition"`` entry (the grouping columns of the first row of the zone).
    """
    evaluators = [] if evaluators is None else evaluators

    print("\nComputing evaluation zones...")
    print(" Original dataset shape:", dataset.x.shape, dataset.y.shape)
    n_ipt = dataset.x.shape[1]
    # The last three input columns are excluded from the grouping key
    # (presumably they are the point coordinates -- TODO confirm).
    n_col = n_ipt - 3
    tensor = torch.cat((dataset.x, dataset.y), dim=1)
    unique_combinations, inverse_indices = torch.unique(tensor[:, :n_col], dim=0, return_inverse=True)

    # Bucket the groups by their number of rows.
    groups_by_size = {}
    for i in range(unique_combinations.size(0)):
        group = tensor[inverse_indices == i]
        groups_by_size.setdefault(group.size(0), []).append(group)

    # Largest groups -> flight conditions zones, second largest -> geometric zones.
    # A size of 0 is never a key, so a missing zone kind falls back to an empty list
    # instead of raising IndexError.
    sizes = sorted(groups_by_size, reverse=True)
    flight_size = sizes[0] if sizes else 0
    coord_size = sizes[1] if len(sizes) > 1 else 0
    groups_flight = groups_by_size.get(flight_size, [])
    groups_coord = groups_by_size.get(coord_size, [])
    print(" Number and size of flight conditions evaluation zones:", len(groups_flight), ", ", flight_size)
    print(" Number and size of geometric interpolation evaluation zones:", len(groups_coord), ", ", coord_size)

    print("\nComputing performance in flight conditions interpolation...")
    metrics_flight = _evaluate_zone_groups(model, dataset, groups_flight, evaluators, n_ipt, n_col)

    print("\nComputing performance in geometric interpolation...")
    metrics_coord = _evaluate_zone_groups(model, dataset, groups_coord, evaluators, n_ipt, n_col)

    if savedir is not None:
        import os  # local import: only needed when results are written to disk

        print("\nSaving results as .json files...")
        os.makedirs(savedir, exist_ok=True)
        with open(os.path.join(savedir, "metrics_flight.json"), "w") as f:
            json.dump(metrics_flight, f)
        with open(os.path.join(savedir, "metrics_coord.json"), "w") as f:
            json.dump(metrics_coord, f)

    return metrics_flight, metrics_coord


def _evaluate_zone_groups(model, dataset, groups, evaluators, n_ipt, n_col):
    """Run every evaluator on each zone in ``groups`` and return one metrics dict per zone."""
    results = []
    for group in groups:
        # Each group row is [inputs | targets]; split it back at the input width.
        x, y = group[:, :n_ipt], group[:, n_ipt:]
        group_dataset = TorchDataset(x, y, dataset.x_scaler, dataset.y_scaler, dataset.isscaled)
        y_pred, y_true = model.predict(group_dataset, rescale_output=True, return_targets=True)
        metrics = {}
        for evaluator in evaluators:
            metrics.update(evaluator(y_true, y_pred))
        # NOTE(review): rescale_data() is called after predicting and before reading x,
        # presumably so the stored flight condition is in original units -- confirm.
        group_dataset.rescale_data()
        metrics["flight_condition"] = group_dataset.x[0, :n_col].tolist()
        results.append(metrics)
    return results
[docs]
def zone_metrics_plotter(
    metrics_flight: List[Dict],
    metrics_coord: List[Dict],
    savedir: str = None,
):
    r"""
    Plot per-zone metrics against the flight condition of each zone.

    For each of the two zone kinds, two figures are produced, each with one subplot
    per metric (``mre``, ``mae``, ``mse``, ``r2``, ``ae_95``, ``ae_99``):

    * a 3D scatter over Mach, angle of attack and altitude, coloured by the metric;
    * a 2D scatter of the metric against the angle of attack, where the marker shape
      encodes the Mach value and the marker colour encodes the altitude.

    Args:
        metrics_flight (List[Dict]): Metrics of the flight conditions zones. Every dict
            must contain the six metric keys listed above and a ``"flight_condition"``
            sequence whose first three entries are plotted as Mach, angle of attack
            and altitude.
        metrics_coord (List[Dict]): Metrics of the geometric interpolation zones, with
            the same layout as ``metrics_flight``.
        savedir (str, optional): Base directory. Figures are written to
            ``<savedir>/plots/`` (created if missing) as ``flight_cond_metrics.png``,
            ``flight_cond_metrics_alpha.png``, ``coord_interp_metrics.png`` and
            ``coord_interp_metrics_alpha.png``, and are closed after being saved.
            Default is ``None``: nothing is written and the figures are left open so
            the caller can display them (e.g. with ``plt.show()``).
    """
    import os  # local import: only used to build/create the output directory

    metric_specs = [
        {"name": "mre", "label": "Mean Relative Error (MRE)"},
        {"name": "mae", "label": "Mean Absolute Error (MAE)"},
        {"name": "mse", "label": "Mean Squared Error (MSE)"},
        {"name": "r2", "label": "R2 Score"},
        {"name": "ae_95", "label": "Absolute Error Quantile 95%"},
        {"name": "ae_99", "label": "Absolute Error Quantile 99%"},
    ]
    markers = ["o", "s", "D", "v", "^", "<", ">", "p", "P", "*", "X", "d", "1", "2", "3", "4", "8", "h", "H", "+", "x", "|", "_"]
    # NOTE(review): "w" (white) is hard to see on a white axes background; kept
    # to preserve the existing colour assignment.
    colors = ["b", "g", "r", "c", "m", "y", "k", "w"]

    def generate_plot(entries, specs, group_label, use_3d):
        """Build one figure with a subplot per metric for the given zone metrics."""
        n_rows = 2
        # Ceiling division so an odd number of metrics still fits in the grid.
        n_cols = max(1, (len(specs) + n_rows - 1) // n_rows)
        fig = plt.figure(figsize=(8 * n_cols, 8 * n_rows))
        # NOTE(review): the 2D figure configures its layout before any axes exist and
        # the 3D figure after all axes exist; this ordering is preserved as found.
        if not use_3d:
            fig.subplots_adjust(left=0.075, right=0.925, top=0.95, bottom=0.05, wspace=0.3, hspace=0.3)
            fig.tight_layout()

        # The flight condition of every zone is the same for all metrics.
        mach = [d["flight_condition"][0] for d in entries]
        aoa = [d["flight_condition"][1] for d in entries]
        alt = [d["flight_condition"][2] for d in entries]

        if not use_3d:
            # Assign one marker per distinct Mach and one colour per distinct altitude,
            # in order of first appearance; cycle when the style lists run out.
            marker_of = {m: markers[k % len(markers)] for k, m in enumerate(dict.fromkeys(mach))}
            color_of = {a: colors[k % len(colors)] for k, a in enumerate(dict.fromkeys(alt))}

        for i, spec in enumerate(specs):
            values = [d[spec["name"]] for d in entries]
            title = spec["label"] + " by " + group_label
            if use_3d:
                ax = fig.add_subplot(n_rows, n_cols, i + 1, projection='3d')
                scatter = ax.scatter(mach, aoa, alt, c=values, cmap='viridis', s=100)
                cbar = plt.colorbar(scatter, ax=ax, shrink=0.5, aspect=10)
                cbar.set_label(spec["label"])
                ax.set_xlabel('Mach [-]')
                ax.set_ylabel('Angle of Attack')
                ax.set_zlabel('Altitude [FL]')
                ax.set_title(title)
            else:
                ax = fig.add_subplot(n_rows, n_cols, i + 1)
                for mach_i, aoa_i, alt_i, value_i in zip(mach, aoa, alt, values):
                    ax.scatter(aoa_i, value_i, marker=marker_of[mach_i], color=color_of[alt_i])
                ax.set_xlabel('Angle of Attack [deg]')
                ax.set_ylabel(spec["label"])
                ax.set_title(title)
                # Empty scatters serve only as legend handles.
                mach_handles = [
                    ax.scatter([], [], marker=marker, color='k', label=f"Mach {round(m, 1)}")
                    for m, marker in marker_of.items()
                ]
                mach_handles.sort(key=lambda handle: handle.get_label())
                alt_handles = [
                    ax.scatter([], [], marker='o', color=color, label=f"FL {int(a)}")
                    for a, color in color_of.items()
                ]
                alt_handles.sort(key=lambda handle: handle.get_label())
                ax.legend(handles=mach_handles + alt_handles)

        if use_3d:
            fig.subplots_adjust(left=0.075, right=0.925, top=0.95, bottom=0.05, wspace=0.3, hspace=0.3)
            fig.tight_layout()
        return fig

    plots_dir = None
    if savedir is not None:
        plots_dir = os.path.join(savedir, "plots")
        os.makedirs(plots_dir, exist_ok=True)

    zone_groups = (
        ("flight conditions interpolation", "flight_cond", metrics_flight),
        ("geometric interpolation", "coord_interp", metrics_coord),
    )
    for description, file_stem, entries in zone_groups:
        print(f" Plotting results for {description}...")
        for use_3d, suffix in ((True, ""), (False, "_alpha")):
            fig = generate_plot(entries, metric_specs, "Flight Conditions", use_3d)
            if plots_dir is not None:
                fig.savefig(os.path.join(plots_dir, f"{file_stem}_metrics{suffix}.png"))
                # Release the figure; pyplot otherwise keeps it alive indefinitely.
                plt.close(fig)