"""Matplotlib plotting utilities for ODE solutions."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from config import get_env_from_schema
from utils import get_logger
# Module-level logger: used for the colormap-fallback debug message and for
# MP4 export success/failure logging.
logger = get_logger(__name__)
_MAX_ELEMENTS_PLOT = 50
_SUB_DIGS = "\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089"
def _get_colors(color_scheme: str, n: int) -> list:
    """Sample ``n`` evenly spaced colors from a Matplotlib colormap.

    Colors are taken at positions ``i / max(1, n - 1)`` for ``i`` in
    ``range(n)``, i.e. spanning the whole colormap (a single color is taken at
    0). If ``color_scheme`` cannot be resolved, ``Set1`` is used instead.
    Both paths sample the same way and return the same kind of result.

    Args:
        color_scheme: Matplotlib colormap name (e.g., 'Set1', 'tab10').
        n: Number of colors to generate; ``n <= 0`` yields an empty list.

    Returns:
        List of RGBA tuples, one per requested color.
    """
    import matplotlib.pyplot as plt

    # Keep the try narrow: only the lookup can legitimately fail on a bad name.
    try:
        cmap = plt.colormaps.get_cmap(color_scheme)
    except (ValueError, AttributeError):
        logger.debug("Colormap '%s' invalid, using Set1 fallback", color_scheme)
        cmap = plt.colormaps.get_cmap("Set1")
    denom = max(1, n - 1)
    # Calling a Colormap with a scalar returns an RGBA tuple, so both the
    # primary and the fallback path yield the same element type.
    return [cmap(i / denom) for i in range(n)]
def _apply_plot_style() -> None:
    """Push font and DPI settings from the env schema into ``matplotlib.rcParams``.

    Each rcParams key is paired with the schema variable that supplies it; the
    values are read in the listed order and applied in one ``update`` call.
    """
    import matplotlib

    rc_to_env = (
        ("font.family", "FONT_FAMILY"),
        ("font.size", "FONT_TICK_SIZE"),
        ("axes.titlesize", "FONT_TITLE_SIZE"),
        ("axes.titleweight", "FONT_TITLE_WEIGHT"),
        ("axes.labelsize", "FONT_AXIS_SIZE"),
        ("figure.dpi", "DPI"),
    )
    settings = {rc_key: get_env_from_schema(env_key) for rc_key, env_key in rc_to_env}
    matplotlib.rcParams.update(settings)
def _new_figure() -> tuple[Any, Any]:
    """Build a 2D figure/axes pair sized from the env schema.

    The rcParams style is applied first so the new figure picks it up.

    Returns:
        Tuple of ``(fig, ax)`` as produced by ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    _apply_plot_style()
    figsize = (
        get_env_from_schema("PLOT_FIGSIZE_WIDTH"),
        get_env_from_schema("PLOT_FIGSIZE_HEIGHT"),
    )
    return plt.subplots(figsize=figsize, dpi=get_env_from_schema("DPI"))
def _finalize_plot(
    ax: Axes,
    title: str,
    xlabel: str,
    ylabel: str,
    *,
    legend: bool = False,
) -> None:
    """Decorate a 2D axes according to the env-driven plot settings.

    The title is drawn only when ``PLOT_SHOW_TITLE`` is truthy and ``title`` is
    non-empty. Both axis labels use the ``FONT_AXIS_STYLE`` font style. A grid
    with ``PLOT_GRID_ALPHA`` transparency is added when ``PLOT_SHOW_GRID`` is
    truthy.

    Args:
        ax: Matplotlib axes to decorate.
        title: Plot title.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        legend: Whether to call ``ax.legend()``.
    """
    fontstyle: str = get_env_from_schema("FONT_AXIS_STYLE")
    show_title = get_env_from_schema("PLOT_SHOW_TITLE")
    if show_title and title:
        ax.set_title(title)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontstyle=fontstyle)
    show_grid = get_env_from_schema("PLOT_SHOW_GRID")
    if show_grid:
        ax.grid(True, alpha=get_env_from_schema("PLOT_GRID_ALPHA"))
    if legend:
        ax.legend()
def _new_3d_figure() -> tuple[Any, Any]:
    """Build a figure holding one 3D axes, sized from the env schema.

    Returns:
        Tuple of ``(fig, ax)`` where ``ax`` uses the ``"3d"`` projection.
    """
    import matplotlib.pyplot as plt

    # Imported only for its side effect (presumably registering the "3d"
    # projection on older Matplotlib versions).
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    _apply_plot_style()
    size = (
        get_env_from_schema("PLOT_FIGSIZE_WIDTH"),
        get_env_from_schema("PLOT_FIGSIZE_HEIGHT"),
    )
    figure = plt.figure(figsize=size, dpi=get_env_from_schema("DPI"))
    return figure, figure.add_subplot(111, projection="3d")
def _finalize_3d_plot(
    ax: Axes,
    title: str,
    xlabel: str,
    ylabel: str,
    zlabel: str,
    *,
    legend: bool = False,
) -> None:
    """Decorate a 3D axes with an optional title, styled axis labels and legend.

    The title is drawn only when ``PLOT_SHOW_TITLE`` is truthy and ``title`` is
    non-empty. All three axis labels use ``FONT_AXIS_STYLE``. Unlike the 2D
    helper, no grid setting is applied here.
    """
    fontstyle: str = get_env_from_schema("FONT_AXIS_STYLE")
    show_title = get_env_from_schema("PLOT_SHOW_TITLE")
    if show_title and title:
        ax.set_title(title)
    setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
    for set_label, text in zip(setters, (xlabel, ylabel, zlabel)):
        set_label(text, fontstyle=fontstyle)
    if legend:
        ax.legend()
def _component_labels(n: int) -> list[str]:
    """Generate ``f₀, f₁, …`` labels for component indices ``0 .. n-1``.

    Every decimal digit of the index is rendered as a Unicode subscript, so
    indices of 10 and above keep the same notation (``f₁₀``, ``f₁₁``, …)
    instead of switching to ``f_10``.

    Args:
        n: Number of components; ``n <= 0`` yields an empty list.

    Returns:
        List of ``n`` label strings.
    """
    # _SUB_DIGS[d] is the subscript form of digit d.
    to_subscript = str.maketrans("0123456789", _SUB_DIGS)
    return [f"f{str(i).translate(to_subscript)}" for i in range(n)]
[docs]
def create_solution_plot(
    x: np.ndarray,
    y: np.ndarray,
    title: str = "f(x)",
    xlabel: str = "x",
    ylabel: str = "f",
    show_markers: bool = False,
    selected_derivatives: list[int] | None = None,
    labels: list[str] | None = None,
) -> Figure:
    """Create a publication-ready plot of the ODE solution.

    The first requested component is drawn in ``PLOT_LINE_COLOR``. Later ones
    use colors sampled from ``PLOT_COLOR_SCHEME``. Requested indices that are
    not below the number of components in ``y`` are skipped. The legend is
    shown only when more than one curve is actually drawn.

    Args:
        x: Independent variable values.
        y: Solution values — shape ``(n_vars, n_points)`` or ``(n_points,)``.
            If the second axis does not match ``len(x)``, ``y`` is transposed.
        title: Plot title.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        show_markers: Whether to overlay data-point markers, drawn every
            ``max(1, len(x) // _MAX_ELEMENTS_PLOT)`` points.
        selected_derivatives: Indices of solution components to plot
            (default: all components).
        labels: Custom legend labels, indexed by component index (f-notation).
            Default: ``"f"`` for a single component, else ``"f[i]"``.

    Returns:
        A matplotlib :class:`Figure`.
    """
    import numpy as np

    # Resolve data, indices and labels before creating the figure, so a bad
    # index or label raises without leaving an open pyplot figure behind.
    y_2d = np.atleast_2d(y)
    if y_2d.shape[1] != len(x):
        y_2d = y_2d.T
    n_rows = y_2d.shape[0]
    if selected_derivatives is None:
        selected_derivatives = list(range(n_rows))
    if labels is None:
        labels = ["f"] if n_rows == 1 else [f"f[{i}]" for i in range(n_rows)]

    line_color: str = get_env_from_schema("PLOT_LINE_COLOR")
    line_width: float = get_env_from_schema("PLOT_LINE_WIDTH")
    line_style: str = get_env_from_schema("PLOT_LINE_STYLE")
    color_scheme: str = get_env_from_schema("PLOT_COLOR_SCHEME")
    # Slot 0 is the configured line color; the rest come from the colormap.
    # Slots follow the position in ``selected_derivatives`` (a skipped index
    # still consumes its slot), so colors are stable for a given request.
    colors = [line_color] + _get_colors(color_scheme, max(1, len(selected_derivatives) - 1))
    curves = [
        (colors[pos] if pos < len(colors) else None, idx, labels[idx])
        for pos, idx in enumerate(selected_derivatives)
        if idx < n_rows
    ]

    fig, ax = _new_figure()
    for color, idx, label in curves:
        ax.plot(
            x,
            y_2d[idx],
            color=color,
            linewidth=line_width,
            linestyle=line_style,
            label=label,
        )
    if show_markers:
        marker: str = get_env_from_schema("PLOT_MARKER_FORMAT")
        msize: int = get_env_from_schema("PLOT_MARKER_SIZE")
        mfc: str = get_env_from_schema("PLOT_MARKER_FACE_COLOR")
        mec: str = get_env_from_schema("PLOT_MARKER_EDGE_COLOR")
        # Subsample so long series do not get one marker per point.
        step = max(1, len(x) // _MAX_ELEMENTS_PLOT)
        for _, idx, _ in curves:
            ax.plot(
                x[::step],
                y_2d[idx, ::step],
                marker=marker,
                markersize=msize,
                markerfacecolor=mfc,
                markeredgecolor=mec,
                linestyle="none",
            )
    # Base the legend on what was actually drawn, not on what was requested.
    _finalize_plot(ax, title, xlabel, ylabel, legend=len(curves) > 1)
    fig.tight_layout()
    return fig
[docs]
def create_energy_evolution_plot(
    t: np.ndarray,
    E_kin: np.ndarray,
    E_pot: np.ndarray,
    E_tot: np.ndarray,
    title: str = "Energy vs time",
    xlabel: str = "t",
) -> Figure:
    """Plot kinetic, potential and total energy against time on one axes.

    The three curves share the configured line width and style, are labelled
    "Kinetic", "Potential" and "Total", and a legend is always shown.

    Args:
        t: Time values (1D).
        E_kin: Kinetic energy at each time step.
        E_pot: Potential energy at each time step.
        E_tot: Total energy at each time step.
        title: Plot title.
        xlabel: Label for x-axis.

    Returns:
        A matplotlib :class:`Figure`.
    """
    fig, ax = _new_figure()
    lw: float = get_env_from_schema("PLOT_LINE_WIDTH")
    ls: str = get_env_from_schema("PLOT_LINE_STYLE")
    series = (("Kinetic", E_kin), ("Potential", E_pot), ("Total", E_tot))
    for name, energy in series:
        ax.plot(t, energy, linewidth=lw, linestyle=ls, label=name)
    _finalize_plot(ax, title, xlabel, "Energy", legend=True)
    fig.tight_layout()
    return fig
[docs]
def create_energy_per_mode_plot(
    t: np.ndarray,
    E_modes: np.ndarray,
    selected_indices: list[int],
    labels: list[str],
    title: str = "Energy per mode",
    xlabel: str = "t",
) -> Figure:
    """Plot the energy of selected modes (or oscillators) against time.

    One line is drawn per ``(index, label)`` pair from
    ``zip(selected_indices, labels)``. Indices that are not below
    ``E_modes.shape[0]`` are skipped. Each line takes the color for its position
    in ``selected_indices`` from ``PLOT_COLOR_SCHEME``.

    Args:
        t: Time values (1D).
        E_modes: Energy array shape (n_modes, n_points).
        selected_indices: Indices of modes/oscillators to plot.
        labels: Legend labels for each selected index.
        title: Plot title.
        xlabel: Label for x-axis.

    Returns:
        A matplotlib :class:`Figure`.
    """
    fig, ax = _new_figure()
    lw: float = get_env_from_schema("PLOT_LINE_WIDTH")
    ls: str = get_env_from_schema("PLOT_LINE_STYLE")
    palette = _get_colors(get_env_from_schema("PLOT_COLOR_SCHEME"), len(selected_indices))
    for position, (mode, label) in enumerate(zip(selected_indices, labels)):
        if mode >= E_modes.shape[0]:
            continue
        ax.plot(
            t,
            E_modes[mode],
            color=palette[position],
            linewidth=lw,
            linestyle=ls,
            label=label,
        )
    _finalize_plot(ax, title, xlabel, "Energy", legend=True)
    fig.tight_layout()
    return fig
[docs]
def create_phase_plot(
    y: np.ndarray,
    title: str = "Phase Portrait",
    xlabel: str = "f",
    ylabel: str = "f'",
    x: np.ndarray | None = None,
) -> Figure:
    """Create a phase portrait for an ODE.

    For second-order (or higher): plots y vs y' (position vs velocity).
    For first-order: plots y vs dy/dx using numerical derivative (requires x).
    The trajectory's start is marked with a circle and its end with a square.

    Inputs are validated and the plotted data computed before the figure is
    created, so an invalid call does not leave an open pyplot figure behind.

    Args:
        y: Solution array — shape ``(n_vars, n_points)``.
        title: Plot title.
        xlabel: Label for horizontal axis.
        ylabel: Label for vertical axis.
        x: Independent variable (required for first-order to compute dy/dx).

    Returns:
        A matplotlib :class:`Figure`.

    Raises:
        ValueError: If ``x`` is missing for first-order data, or the data is empty.
    """
    import numpy as np

    y_2d = np.atleast_2d(y)
    if y_2d.shape[0] >= 2:
        # Two or more state rows: row 0 against row 1.
        horiz, vert = y_2d[0], y_2d[1]
    else:
        if x is None:
            raise ValueError("x is required for first-order phase portrait")
        horiz = y_2d[0]
        # First-order: estimate dy/dx numerically using the x coordinates.
        vert = np.gradient(horiz, x)
    if horiz.size == 0:
        raise ValueError("cannot draw a phase portrait from empty data")

    fig, ax = _new_figure()
    line_color: str = get_env_from_schema("PLOT_LINE_COLOR")
    line_width: float = get_env_from_schema("PLOT_LINE_WIDTH")
    ax.plot(horiz, vert, color=line_color, linewidth=line_width)
    phase_start_color: str = get_env_from_schema("PLOT_PHASE_START_COLOR")
    phase_end_color: str = get_env_from_schema("PLOT_PHASE_END_COLOR")
    phase_marker_size: int = get_env_from_schema("PLOT_PHASE_MARKER_SIZE")
    ax.plot(
        horiz[0],
        vert[0],
        "o",
        color=phase_start_color,
        markersize=phase_marker_size,
        label="Start",
    )
    ax.plot(
        horiz[-1],
        vert[-1],
        "s",
        color=phase_end_color,
        markersize=phase_marker_size,
        label="End",
    )
    _finalize_plot(ax, title, xlabel, ylabel, legend=True)
    fig.tight_layout()
    return fig
[docs]
def create_phase_3d_plot(
    data_x: np.ndarray,
    data_y: np.ndarray,
    data_z: np.ndarray,
    title: str = "Phase Space 3D",
    xlabel: str = "f\u2080",
    ylabel: str = "f\u2081",
    zlabel: str = "f\u2082",
) -> Figure:
    """Draw a 3D phase-space trajectory with its start and end points marked.

    The trajectory uses ``PLOT_LINE_COLOR``. The first sample gets a circle
    ("Start") and the last a square ("End"), using the phase marker settings.

    Args:
        data_x: Values for the x-axis.
        data_y: Values for the y-axis.
        data_z: Values for the z-axis.
        title: Plot title.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        zlabel: Label for z-axis.

    Returns:
        A matplotlib :class:`Figure`.
    """
    fig, ax = _new_3d_figure()
    trajectory_color: str = get_env_from_schema("PLOT_LINE_COLOR")
    trajectory_width: float = get_env_from_schema("PLOT_LINE_WIDTH")
    start_color: str = get_env_from_schema("PLOT_PHASE_START_COLOR")
    end_color: str = get_env_from_schema("PLOT_PHASE_END_COLOR")
    endpoint_size: int = get_env_from_schema("PLOT_PHASE_MARKER_SIZE")
    ax.plot(data_x, data_y, data_z, color=trajectory_color, linewidth=trajectory_width)
    endpoints = ((0, "o", start_color, "Start"), (-1, "s", end_color, "End"))
    for pos, fmt, color, name in endpoints:
        ax.plot(
            [data_x[pos]],
            [data_y[pos]],
            [data_z[pos]],
            fmt,
            color=color,
            markersize=endpoint_size,
            label=name,
        )
    _finalize_3d_plot(ax, title, xlabel, ylabel, zlabel, legend=True)
    fig.tight_layout()
    return fig
[docs]
def create_surface_plot(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    title: str = "f(x, y)",
    xlabel: str = "x",
    ylabel: str = "y",
    zlabel: str = "f",
) -> Figure:
    """Create a 3D surface plot for 2D scalar field data.

    NaN/inf values in ``z`` are masked. ``z`` is converted and validated before
    the figure is created, so a shape mismatch does not leave an open figure.

    Args:
        x: 1D array of x values.
        y: 1D array of y values.
        z: 2D array of values, shape (len(y), len(x)). Array-likes such as
            nested lists are accepted; an existing mask on a masked array is kept.
        title: Plot title.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        zlabel: Label for z-axis.

    Returns:
        A matplotlib :class:`Figure`.

    Raises:
        ValueError: If ``z`` does not have shape ``(len(y), len(x))``.
    """
    import numpy as np

    # Convert before touching ``.shape`` so array-likes work; asanyarray (not
    # asarray) keeps an incoming MaskedArray's mask.
    z = np.asanyarray(z)
    X, Y = np.meshgrid(x, y)
    if z.shape != X.shape:
        raise ValueError(f"z shape {z.shape} does not match grid {X.shape}")
    # Mask NaN values (e.g. exterior points in non-rectangular PDE domains)
    z_masked = np.ma.masked_invalid(z)

    fig, ax = _new_3d_figure()
    surface_cmap: str = get_env_from_schema("PLOT_SURFACE_CMAP")
    surface_alpha: float = get_env_from_schema("PLOT_SURFACE_ALPHA")
    colorbar_shrink: float = get_env_from_schema("PLOT_COLORBAR_SHRINK")
    surf = ax.plot_surface(
        X,
        Y,
        z_masked,
        cmap=surface_cmap,
        alpha=surface_alpha,
        edgecolor="none",
    )
    fig.colorbar(surf, ax=ax, shrink=colorbar_shrink)
    _finalize_3d_plot(ax, title, xlabel, ylabel, zlabel)
    fig.tight_layout()
    return fig
[docs]
def create_contour_plot(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    title: str = "f(x, y)",
    xlabel: str = "x",
    ylabel: str = "f",
) -> Figure:
    """Create a 2D filled-contour plot for 2D scalar field data.

    NaN/inf values in ``z`` are masked. ``z`` is converted and validated before
    the figure is created, so a shape mismatch does not leave an open figure.

    Args:
        x: 1D array of x values.
        y: 1D array of y values.
        z: 2D array of values, shape (len(y), len(x)). Array-likes such as
            nested lists are accepted; an existing mask on a masked array is kept.
        title: Plot title.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis. NOTE(review): the default is ``"f"`` although
            the vertical axis shows ``y`` — confirm this default is intended.

    Returns:
        A matplotlib :class:`Figure`.

    Raises:
        ValueError: If ``z`` does not have shape ``(len(y), len(x))``.
    """
    import numpy as np

    # Convert before touching ``.shape`` so array-likes work; asanyarray (not
    # asarray) keeps an incoming MaskedArray's mask.
    z = np.asanyarray(z)
    X, Y = np.meshgrid(x, y)
    if z.shape != X.shape:
        raise ValueError(f"z shape {z.shape} does not match grid {X.shape}")
    # Mask NaN values (e.g. exterior points in non-rectangular PDE domains)
    z_masked = np.ma.masked_invalid(z)

    fig, ax = _new_figure()
    contour_levels: int = get_env_from_schema("PLOT_CONTOUR_LEVELS")
    surface_cmap: str = get_env_from_schema("PLOT_SURFACE_CMAP")
    contour = ax.contourf(X, Y, z_masked, levels=contour_levels, cmap=surface_cmap)
    fig.colorbar(contour, ax=ax)
    _finalize_plot(ax, title, xlabel, ylabel)
    fig.tight_layout()
    return fig
[docs]
def create_vector_animation_plot(
    x: np.ndarray,
    y: np.ndarray,
    order: int,
    vector_components: int,
    title: str = "f_i(x) vs component",
    deriv_offset: int = 0,
    component_labels: list[str] | None = None,
) -> Figure:
    """Create an interactive plot: x-axis = component index i, y-axis = f_i(x).

    The figure stores _animation_update(idx) and _animation_n_points for use
    with a Tkinter Scale (matplotlib Slider is unreliable when embedded in Tk).
    For vector ODE: y has shape (n_state, n_points), with f_i at y[i*order].

    All validation and data slicing happens before the figure is created, so
    a bad argument does not leave an orphaned pyplot figure open.

    Args:
        x: Independent variable values (1D).
        y: Solution array, shape (n_state, n_points). An (n_points, n_state)
            array is transposed automatically.
        order: Order per component.
        vector_components: Number of components.
        title: Plot title.
        deriv_offset: Derivative order to display (0=value, 1=first derivative, etc.).
        component_labels: Optional tick labels for the component axis; defaults
            to ``f₀, f₁, …``. Must contain ``vector_components`` entries
            (an empty sequence leaves the ticks unlabelled).

    Returns:
        A matplotlib Figure (use with embed_animation_plot_in_tk). These
        attributes are set on it: ``_animation_update``, ``_animation_n_points``,
        ``_animation_initial_index``, ``_animation_x``, ``_animation_f_values``
        and ``_animation_vector_components``.

    Raises:
        ValueError: If ``x`` is empty, ``vector_components < 1``, or
            ``component_labels`` has the wrong length.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # --- Data preparation (no figure exists yet) ---
    n_points = len(x)
    if n_points == 0:
        raise ValueError("x must contain at least one point")
    if vector_components < 1:
        raise ValueError("vector_components must be at least 1")
    y_2d = np.atleast_2d(y)
    if y_2d.shape[1] != n_points and y_2d.shape[0] == n_points:
        y_2d = y_2d.T
    # The requested derivative of component i is stored at row i*order + deriv_offset.
    row_indices = np.arange(vector_components) * order + deriv_offset
    f_values = y_2d[row_indices]
    tick_labels = (
        component_labels if component_labels is not None else _component_labels(vector_components)
    )
    if len(tick_labels) not in (0, vector_components):
        raise ValueError(
            f"component_labels has {len(tick_labels)} entries, expected {vector_components}"
        )
    time_index = 0
    # Fixed y-limits over all time steps so the axis does not jump while scrubbing.
    y_margin: float = get_env_from_schema("PLOT_ANIMATION_Y_MARGIN")
    y_min_global = float(np.min(f_values)) - y_margin
    y_max_global = float(np.max(f_values)) + y_margin
    color_scheme: str = get_env_from_schema("PLOT_COLOR_SCHEME")
    colors = _get_colors(color_scheme, vector_components)
    marker_size: int = get_env_from_schema("PLOT_PHASE_MARKER_SIZE")
    anim_line_width: float = get_env_from_schema("PLOT_ANIMATION_LINE_WIDTH")
    vlines_line_width: float = get_env_from_schema("PLOT_VLINES_LINE_WIDTH")
    vlines_alpha: float = get_env_from_schema("PLOT_VLINES_ALPHA")

    # --- Figure construction ---
    _apply_plot_style()
    width: int = get_env_from_schema("PLOT_FIGSIZE_WIDTH")
    height: int = get_env_from_schema("PLOT_FIGSIZE_HEIGHT")
    dpi: int = get_env_from_schema("DPI")
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    ax_main = fig.add_axes([0.12, 0.15, 0.78, 0.78])
    ax_main.set_ylim(y_min_global, y_max_global)

    indices = np.arange(vector_components)
    vals = f_values[:, time_index]
    (line_chain,) = ax_main.plot(
        indices,
        vals,
        "o-",
        color=colors[0],
        markersize=marker_size,
        linewidth=anim_line_width,
    )
    vlines_coll = ax_main.vlines(
        indices, 0, vals, colors=colors, linewidth=vlines_line_width, alpha=vlines_alpha
    )
    ax_main.set_xticks(indices)
    ax_main.set_xticklabels(tick_labels)
    # Lower endpoints of the vertical lines never change; build them once.
    baseline = np.column_stack([indices, np.zeros(vector_components)])

    def update(idx: int) -> None:
        """Redraw chain and vertical lines for time step ``idx`` (clamped to range)."""
        i = max(0, min(idx, n_points - 1))
        new_vals = f_values[:, i]
        line_chain.set_ydata(new_vals)
        # Segment array of shape (n, 2, 2): for each component, (i, 0) -> (i, f_i).
        segments = np.stack([baseline, np.column_stack([indices, new_vals])], axis=1)
        vlines_coll.set_segments(segments)
        fig.canvas.draw_idle()

    if get_env_from_schema("PLOT_SHOW_TITLE") and title:
        ax_main.set_title(title)
    axis_style: str = get_env_from_schema("FONT_AXIS_STYLE")
    ax_main.set_xlabel("Component index i", fontstyle=axis_style)
    ax_main.set_ylabel("f_i(x)", fontstyle=axis_style)
    if get_env_from_schema("PLOT_SHOW_GRID"):
        grid_alpha: float = get_env_from_schema("PLOT_GRID_ALPHA")
        ax_main.grid(True, alpha=grid_alpha)
    fig.subplots_adjust(bottom=0.12, left=0.12, right=0.95, top=0.92)
    # Hooks consumed by the Tk embedding code (names are part of that contract).
    fig._animation_update = update
    fig._animation_n_points = n_points
    fig._animation_initial_index = time_index
    fig._animation_x = x
    fig._animation_f_values = f_values
    fig._animation_vector_components = vector_components
    return fig
[docs]
def create_vector_animation_3d(
    x: np.ndarray,
    y: np.ndarray,
    order: int,
    vector_components: int,
    title: str = "f_i(x) — 3D",
    deriv_offset: int = 0,
) -> Figure:
    """Create a 3D plot: x (independent), component index i, f_i(x).

    The data is sliced before the figure is created, so invalid
    ``order``/``deriv_offset`` values do not leave an open pyplot figure.

    Args:
        x: Independent variable values.
        y: Solution array, shape (n_state, n_points). An (n_points, n_state)
            array is transposed automatically.
        order: Order per component.
        vector_components: Number of components.
        title: Plot title.
        deriv_offset: Derivative order to display (0=value, 1=first derivative, etc.).

    Returns:
        A matplotlib Figure with 3D surface.
    """
    import numpy as np

    n_points = len(x)
    y_2d = np.atleast_2d(y)
    if y_2d.shape[1] != n_points and y_2d.shape[0] == n_points:
        y_2d = y_2d.T
    # The requested derivative of component i is stored at row i*order + deriv_offset.
    row_indices = np.arange(vector_components) * order + deriv_offset
    Z_grid = y_2d[row_indices]
    # Grid rows follow the component index i, columns follow x.
    X_grid, I_grid = np.meshgrid(x, np.arange(vector_components))

    fig, ax = _new_3d_figure()
    surface_cmap: str = get_env_from_schema("PLOT_SURFACE_CMAP")
    surface_alpha: float = get_env_from_schema("PLOT_SURFACE_ALPHA")
    ax.plot_surface(
        X_grid, I_grid, Z_grid, cmap=surface_cmap, alpha=surface_alpha, edgecolor="none"
    )
    _finalize_3d_plot(ax, title, "x", "Component i", "f_i(x)")
    fig.tight_layout()
    return fig
_MAX_MP4_FRAMES = 500
[docs]
def export_animation_to_mp4(
    x: np.ndarray,
    y: np.ndarray,
    order: int,
    vector_components: int,
    filepath: Path,
    *,
    title: str = "f_i(x) vs component",
    duration_seconds: float = 10.0,
    deriv_offset: int = 0,
) -> Path:
    """Export vector animation as MP4 video.

    Frames are downsampled to at most _MAX_MP4_FRAMES to avoid memory exhaustion.
    Requires ffmpeg to be installed on the system. Availability is checked
    before any rendering work is done. The figure is always closed, even when
    building or saving the animation fails.

    Args:
        x: Independent variable values.
        y: Solution array, shape (n_state, n_points).
        order: Order per component.
        vector_components: Number of components.
        filepath: Output path for the MP4 file (a ``str`` is also accepted).
        title: Plot title.
        duration_seconds: Desired video duration in seconds (values below 0.5
            are treated as 0.5). FPS is computed from it, with a minimum of 1.
        deriv_offset: Derivative order to display (0=value, 1=first derivative, etc.).

    Returns:
        The path that was written, as a :class:`~pathlib.Path`.

    Raises:
        RuntimeError: If ffmpeg is not available or the export fails.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.animation import FuncAnimation, writers

    # Fail fast: no point preparing frames if they cannot be encoded.
    if not writers.is_available("ffmpeg"):
        raise RuntimeError("FFMpeg is not available. Install ffmpeg and ensure it is in your PATH.")
    filepath = Path(filepath)

    _apply_plot_style()
    dpi: int = get_env_from_schema("DPI")
    y_2d = np.atleast_2d(y)
    if y_2d.shape[1] != len(x):
        y_2d = y_2d.T if y_2d.shape[0] == len(x) else y_2d
    # The requested derivative of component i is stored at row i*order + deriv_offset.
    row_indices = np.arange(vector_components) * order + deriv_offset
    f_values = y_2d[row_indices]
    n_points = len(x)
    # Evenly spaced sample of at most _MAX_MP4_FRAMES time indices.
    frame_indices = np.linspace(0, n_points - 1, min(n_points, _MAX_MP4_FRAMES), dtype=int)
    num_frames = len(frame_indices)
    fps = max(1, num_frames / max(0.5, duration_seconds))
    color_scheme: str = get_env_from_schema("PLOT_COLOR_SCHEME")
    colors = _get_colors(color_scheme, vector_components)
    width: int = get_env_from_schema("PLOT_FIGSIZE_WIDTH")
    height: int = get_env_from_schema("PLOT_FIGSIZE_HEIGHT")
    marker_size: int = get_env_from_schema("PLOT_PHASE_MARKER_SIZE")
    anim_line_width: float = get_env_from_schema("PLOT_ANIMATION_LINE_WIDTH")
    vlines_line_width: float = get_env_from_schema("PLOT_VLINES_LINE_WIDTH")
    vlines_alpha: float = get_env_from_schema("PLOT_VLINES_ALPHA")
    y_margin: float = get_env_from_schema("PLOT_ANIMATION_Y_MARGIN")
    axis_style: str = get_env_from_schema("FONT_AXIS_STYLE")

    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
    try:
        # Fixed y-limits over all frames so the axis does not jump in the video.
        y_min = float(np.min(f_values)) - y_margin
        y_max = float(np.max(f_values)) + y_margin
        ax.set_ylim(y_min, y_max)
        # Same label styling as the interactive animation plot.
        ax.set_xlabel("Component index i", fontstyle=axis_style)
        ax.set_ylabel("f_i(x)", fontstyle=axis_style)
        if get_env_from_schema("PLOT_SHOW_TITLE") and title:
            ax.set_title(title)
        if get_env_from_schema("PLOT_SHOW_GRID"):
            grid_alpha: float = get_env_from_schema("PLOT_GRID_ALPHA")
            ax.grid(True, alpha=grid_alpha)
        indices = np.arange(vector_components)
        vals = f_values[:, frame_indices[0]]
        (line_chain,) = ax.plot(
            indices,
            vals,
            "o-",
            color=colors[0],
            markersize=marker_size,
            linewidth=anim_line_width,
        )
        vlines_coll = ax.vlines(
            indices, 0, vals, colors=colors, linewidth=vlines_line_width, alpha=vlines_alpha
        )
        ax.set_xticks(indices)
        ax.set_xticklabels(_component_labels(vector_components))
        # Lower endpoints of the vertical lines never change; build them once.
        baseline = np.column_stack([indices, np.zeros(vector_components)])

        def _frame(idx: int) -> None:
            """Update chain and vertical lines to time step ``idx``."""
            new_vals = f_values[:, idx]
            line_chain.set_ydata(new_vals)
            segments = np.stack([baseline, np.column_stack([indices, new_vals])], axis=1)
            vlines_coll.set_segments(segments)

        anim = FuncAnimation(
            fig,
            lambda i: _frame(frame_indices[i]),
            frames=num_frames,
            interval=int(1000 / fps),
            blit=False,
        )
        try:
            filepath.parent.mkdir(parents=True, exist_ok=True)
            anim.save(str(filepath), writer="ffmpeg", fps=fps)
        except Exception as exc:
            logger.error("MP4 export failed: %s", exc, exc_info=True)
            raise RuntimeError(f"Failed to export MP4. Is ffmpeg installed? {exc}") from exc
    finally:
        # Always release the pyplot figure, whichever step above failed.
        plt.close(fig)
    logger.info("Animation exported: %s (%d frames)", filepath, num_frames)
    return filepath