import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
[docs]
def plot_mesh_2d(
    ax,
    mesh_coords,
    title,
    field_name,
    field_data=None,
    color_bar=True,
    xlim=None,
    ylim=None,
    cmap="viridis",
    point_size=1,
):
    """
    Plot mesh nodes as a 2D scatter, optionally colored by field data.

    :param ax: The axes to plot on.
    :type ax: matplotlib.axes.Axes
    :param mesh_coords: Array-like with shape (N, 2) containing the x, y
        coordinates of the mesh nodes. Only the first two columns are used.
    :type mesh_coords: numpy.ndarray
    :param title: The title of the plot.
    :type title: str
    :param field_name: The name of the field data; used as the color bar label.
    :type field_name: str
    :param field_data: Array with shape (N,) containing the field data used to
        color the points. When None, every point gets the default color.
        Defaults to None.
    :type field_data: numpy.ndarray, optional
    :param color_bar: Whether to display a color bar. A color bar is only drawn
        when ``field_data`` is given, because otherwise there is no
        value-to-color mapping to show. Defaults to True.
    :type color_bar: bool, optional
    :param xlim: Optional x-axis limits (min, max). Defaults to None.
    :type xlim: tuple, optional
    :param ylim: Optional y-axis limits (min, max). Defaults to None.
    :type ylim: tuple, optional
    :param cmap: Colormap used to map ``field_data`` to colors. Ignored when
        ``field_data`` is None. Defaults to "viridis".
    :type cmap: str or matplotlib.colors.Colormap, optional
    :param point_size: Marker size (``s`` in ``Axes.scatter``). Defaults to 1.
    :type point_size: float, optional
    :returns: The scatter collection added to ``ax``, so callers can customize
        it further.
    :rtype: matplotlib.collections.PathCollection
    """
    mesh_coords = np.asarray(mesh_coords)
    x = mesh_coords[:, 0]
    y = mesh_coords[:, 1]

    scatter_kwargs = {"s": point_size}
    if field_data is not None:
        # Only pass colormapping arguments when there is data to map; recent
        # matplotlib versions warn that ``cmap`` is ignored when ``c`` is None.
        scatter_kwargs["c"] = field_data
        scatter_kwargs["cmap"] = cmap
    scatter = ax.scatter(x, y, **scatter_kwargs)

    if color_bar and field_data is not None:
        # Attach the color bar to the figure that owns ``ax`` rather than
        # pyplot's "current" figure, so figures not managed by pyplot work too.
        cbar = ax.figure.colorbar(scatter, ax=ax)
        cbar.set_label(field_name)

    ax.set_xlabel("X Coordinate")
    ax.set_ylabel("Y Coordinate")
    ax.set_title(title)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    return scatter
[docs]
def plot_mesh_3d_interactive(
    X, y=None, field_variable_idx=None, colorscale="plasma", marker_size=2
):
    """
    Show mesh nodes as an interactive Plotly 3D scatter plot.

    :param X: Array with shape (N, 3) containing the x, y, z coordinates of the
        mesh nodes. An (N, 2) array is also accepted; its points are drawn in
        the z = 0 plane.
    :type X: numpy.ndarray
    :param y: Array with shape (N, M) containing the field data. Only used when
        ``field_variable_idx`` is given. Defaults to None.
    :type y: numpy.ndarray, optional
    :param field_variable_idx: The column of ``y`` to color the mesh by. When
        None, the points are not color-mapped. Defaults to None.
    :type field_variable_idx: int, optional
    :param colorscale: Plotly colorscale name (e.g. "viridis", "plasma",
        "inferno") used when coloring by field data. Defaults to "plasma".
    :type colorscale: str, optional
    :param marker_size: Marker size of each point. Defaults to 2.
    :type marker_size: float, optional
    :raises ValueError: If ``field_variable_idx`` is given but ``y`` is None.
    :returns: None. The figure is displayed with ``fig.show()``.
    """
    X = np.asarray(X)

    marker = {"size": marker_size}
    if field_variable_idx is not None:
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if y is None:
            raise ValueError("Field data must be provided to color the mesh.")
        marker["color"] = np.asarray(y)[:, field_variable_idx]
        marker["colorscale"] = colorscale
        marker["showscale"] = True

    # Use the z column when it exists; 2D meshes are drawn flat at z = 0.
    z = X[:, 2] if X.shape[1] > 2 else np.zeros(len(X))

    fig = go.Figure(
        data=[
            go.Scatter3d(x=X[:, 0], y=X[:, 1], z=z, mode="markers", marker=marker)
        ]
    )
    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            camera_eye=dict(x=1, y=1, z=1.5),
            dragmode="orbit",
        ),
        height=800,
        margin=dict(l=0, r=0, t=0, b=0),
    )
    fig.show()