# Source code for complex_problems.coupled_oscillators.solver

"""Solver for coupled harmonic oscillators."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np
from scipy.integrate import solve_ivp

from complex_problems.coupled_oscillators.model import (
    build_ode_function,
    compute_normal_modes,
)
from config import get_default_solver_method, get_env_from_schema
from utils import SolverFailedError, get_logger

logger = get_logger(__name__)


@dataclass
class CoupledOscillatorsResult:
    """Result from solving coupled oscillators.

    Attributes:
        x: Time values (1D).
        y: State array shape (2*N, n_points):
            [x_0..x_{N-1}, v_0..v_{N-1}].
        n_oscillators: Number of oscillators.
        masses: Mass array (n_oscillators,).
        k_coupling: Coupling array (n_oscillators-1 or n_oscillators
            for periodic).
        M_modes: Mode matrix (n, n), columns are eigenvectors.
        omega_modes: Angular frequencies of modes (1D).
        has_modes: True if normal modes were computed (linear only).
        metadata: Additional solver metadata.
    """

    x: np.ndarray
    y: np.ndarray
    n_oscillators: int
    masses: np.ndarray
    k_coupling: np.ndarray
    M_modes: np.ndarray
    omega_modes: np.ndarray
    has_modes: bool
    # default_factory gives every instance its own dict instead of a shared one.
    metadata: dict[str, Any] = field(default_factory=dict)
def solve_coupled_oscillators(
    n_oscillators: int,
    masses: float | list[float] | Callable[[int], float],
    k_coupling: float | list[float] | Callable[[int], float],
    boundary: str = "fixed",
    coupling_types: list[str] | None = None,
    nonlinear_coeff: float = 0.0,
    nonlinear_fput_alpha: float = 0.0,
    nonlinear_quartic: float = 0.0,
    nonlinear_quintic: float = 0.0,
    k_2nn: float = 0.0,
    k_3nn: float = 0.0,
    k_4nn: float = 0.0,
    external_amplitude: float = 0.0,
    external_frequency: float = 1.0,
    t_min: float = 0.0,
    t_max: float = 30.0,
    n_points: int | None = None,
    y0: list[float] | None = None,
    method: str | None = None,
) -> CoupledOscillatorsResult:
    """Solve the coupled oscillators system.

    Args:
        n_oscillators: Number of oscillators (must be >= 1).
        masses: Mass specification.
        k_coupling: Coupling specification.
        boundary: "fixed" or "periodic".
        coupling_types: List of coupling types (defaults to ["linear"]).
        nonlinear_coeff: Cubic nonlinear coupling coefficient (FPUT-β).
        nonlinear_fput_alpha: FPUT-α quadratic nonlinear coefficient.
        nonlinear_quartic: Quartic nonlinear coefficient.
        nonlinear_quintic: Quintic nonlinear coefficient.
        k_2nn: 2nd-neighbor linear coupling (0 = disabled).
        k_3nn: 3rd-neighbor linear coupling (0 = disabled).
        k_4nn: 4th-neighbor linear coupling (0 = disabled).
        external_amplitude: External force amplitude.
        external_frequency: External force frequency.
        t_min: Start time.
        t_max: End time.
        n_points: Number of output points (default from env).
        y0: Initial conditions [x_0, ..., x_{N-1}, v_0, ..., v_{N-1}].
            When omitted, only the first oscillator is displaced (x_0 = 1).
        method: Solver method (default from configuration).

    Returns:
        CoupledOscillatorsResult with solution and metadata.

    Raises:
        ValueError: If n_oscillators < 1 or y0 does not have exactly
            2 * n_oscillators entries.
        SolverFailedError: If integration fails.
    """
    n = n_oscillators
    # Fail early with a clear message instead of an IndexError on y0[0] below.
    if n < 1:
        raise ValueError(f"n_oscillators must be >= 1, got {n}")

    coupling_types = coupling_types or ["linear"]

    if n_points is None:
        n_points = int(get_env_from_schema("SOLVER_NUM_POINTS"))
    t_eval = np.linspace(t_min, t_max, n_points)

    if y0 is None:
        y0_arr = np.zeros(2 * n, dtype=float)
        y0_arr[0] = 1.0  # Excite first oscillator
    else:
        y0_arr = np.array(y0, dtype=float)
        # The state vector is N positions followed by N velocities.
        if y0_arr.shape != (2 * n,):
            raise ValueError(
                f"y0 must contain exactly {2 * n} values "
                f"(positions then velocities), got shape {y0_arr.shape}"
            )

    ode_func = build_ode_function(
        n_oscillators=n,
        masses=masses,
        k_coupling=k_coupling,
        boundary=boundary,
        coupling_types=coupling_types,
        nonlinear_coeff=nonlinear_coeff,
        nonlinear_fput_alpha=nonlinear_fput_alpha,
        nonlinear_quartic=nonlinear_quartic,
        nonlinear_quintic=nonlinear_quintic,
        k_2nn=k_2nn,
        k_3nn=k_3nn,
        k_4nn=k_4nn,
        external_amplitude=external_amplitude,
        external_frequency=external_frequency,
    )

    method = method or get_default_solver_method()
    rtol = get_env_from_schema("SOLVER_RTOL")
    atol = get_env_from_schema("SOLVER_ATOL")
    max_step = get_env_from_schema("SOLVER_MAX_STEP")
    # A non-positive configured max step means "no limit".
    effective_max_step = np.inf if max_step <= 0 else max_step

    logger.info(
        "Solving coupled oscillators: n=%d, t=[%g, %g], method=%s",
        n,
        t_min,
        t_max,
        method,
    )

    # dense_output is not requested: the continuous interpolant (sol.sol) is
    # never used or returned, and sol.t / sol.y at t_eval do not depend on it.
    sol = solve_ivp(
        fun=ode_func,
        t_span=(t_min, t_max),
        y0=y0_arr,
        method=method,
        t_eval=t_eval,
        max_step=effective_max_step,
        rtol=rtol,
        atol=atol,
    )

    if not sol.success:
        logger.error("Solver failed: %s", sol.message)
        raise SolverFailedError(f"Solver failed ({method}): {sol.message}")

    # Resolve masses and k to arrays for result
    from complex_problems.coupled_oscillators.model import _resolve_k, _resolve_mass

    masses_arr = np.array([_resolve_mass(masses, i, n) for i in range(n)])
    # A periodic chain has one extra spring closing the ring.
    n_springs = n if boundary == "periodic" else n - 1
    k_arr = np.array([_resolve_k(k_coupling, i, n) for i in range(n_springs)])

    # Compute linear modes whenever linear coupling is present; use as projection basis
    # even for nonlinear/external systems (energy per mode not conserved in that case)
    has_modes = "linear" in coupling_types
    if has_modes:
        M_modes, omega_modes = compute_normal_modes(
            n,
            masses,
            k_coupling,
            boundary,
            k_2nn=k_2nn,
            k_3nn=k_3nn,
            k_4nn=k_4nn,
        )
    else:
        # Placeholder basis when no linear coupling is present.
        M_modes = np.eye(n)
        omega_modes = np.ones(n)

    metadata = {
        "method": method,
        "n_eval": getattr(sol, "nfev", 0),
        "boundary": boundary,
        "coupling_types": coupling_types,
        "nonlinear_coeff": nonlinear_coeff,
        "nonlinear_fput_alpha": nonlinear_fput_alpha,
        "nonlinear_quartic": nonlinear_quartic,
        "nonlinear_quintic": nonlinear_quintic,
        "k_2nn": k_2nn,
        "k_3nn": k_3nn,
        "k_4nn": k_4nn,
    }

    return CoupledOscillatorsResult(
        x=sol.t,
        y=sol.y,
        n_oscillators=n,
        masses=masses_arr,
        k_coupling=k_arr,
        M_modes=M_modes,
        omega_modes=omega_modes,
        has_modes=has_modes,
        metadata=metadata,
    )