Source code for solver.ode_solver

"""Core ODE solving wrappers around SciPy integrators."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import fsolve

from config import get_default_solver_method, get_env_from_schema
from utils import SolverFailedError, get_logger

logger = get_logger(__name__)


[docs] @dataclass class ODESolution: """Container for ODE solution data. Attributes: x: Independent variable values. y: Solution array — shape ``(n_vars, n_points)``. success: Whether the solver converged. message: Solver status message. method_used: Name of the integration method. n_eval: Number of function evaluations (if available). raw: The original ``OdeResult`` from SciPy. """ x: np.ndarray y: np.ndarray success: bool message: str method_used: str n_eval: int = 0 raw: Any = field(default=None, repr=False)
def _resolve_solver_params( method: str | None, max_step: float | None, rtol: float | None, atol: float | None, t_span: tuple[float, float], ) -> tuple[str, float, float, float, np.ndarray]: """Resolve solver parameters using environment defaults if needed. Args: method: Integration method name (or None for default). max_step: Maximum step size (or None for default). rtol: Relative tolerance (or None for default). atol: Absolute tolerance (or None for default). t_span: Integration span to compute uniform grid over. Returns: Tuple of (method, effective_max_step, rtol, atol, t_eval). """ if method is None: method = get_default_solver_method() if max_step is None: max_step = get_env_from_schema("SOLVER_MAX_STEP") if rtol is None: rtol = get_env_from_schema("SOLVER_RTOL") if atol is None: atol = get_env_from_schema("SOLVER_ATOL") effective_max_step = np.inf if max_step <= 0 else max_step n_points: int = get_env_from_schema("SOLVER_NUM_POINTS") t_eval = np.linspace(t_span[0], t_span[1], n_points) return method, effective_max_step, rtol, atol, t_eval
[docs] def solve_ode( ode_func: Callable[[float, np.ndarray], np.ndarray], t_span: tuple[float, float], y0: list[float], method: str | None = None, t_eval: np.ndarray | None = None, max_step: float | None = None, rtol: float | None = None, atol: float | None = None, ) -> ODESolution: """Solve an initial-value ODE problem using ``scipy.integrate.solve_ivp``. Args: ode_func: Right-hand side of the ODE system ``dy/dx = f(x, y)``. t_span: Integration interval ``(x_min, x_max)``. y0: Initial conditions vector. method: Integration method name. Falls back to env default. t_eval: Times at which to store the solution. If ``None``, a uniform grid is generated from env settings. max_step: Maximum allowed step size (0 → ``np.inf``). rtol: Relative tolerance. atol: Absolute tolerance. Returns: An :class:`ODESolution` with the results. Raises: SolverFailedError: If the solver reports failure. """ if t_eval is None: method, effective_max_step, rtol, atol, t_eval = _resolve_solver_params( method, max_step, rtol, atol, t_span ) else: # If t_eval is provided, still resolve other params method = method or get_default_solver_method() max_step = max_step or get_env_from_schema("SOLVER_MAX_STEP") rtol = rtol or get_env_from_schema("SOLVER_RTOL") atol = atol or get_env_from_schema("SOLVER_ATOL") effective_max_step = np.inf if max_step <= 0 else max_step logger.info( "Solving IVP: method=%s, span=%s, y0=%s, rtol=%s, atol=%s", method, t_span, y0, rtol, atol, ) sol = solve_ivp( fun=ode_func, t_span=t_span, y0=y0, method=method, t_eval=t_eval, max_step=effective_max_step, rtol=rtol, atol=atol, dense_output=True, ) n_eval = getattr(sol, "nfev", 0) result = ODESolution( x=sol.t, y=sol.y, success=sol.success, message=sol.message, method_used=method, n_eval=n_eval, raw=sol, ) if not sol.success: logger.error("Solver failed: %s", sol.message) raise SolverFailedError(f"Solver failed ({method}): {sol.message}") logger.info("Solver succeeded: %d points, %d evaluations", len(sol.t), n_eval) return result
[docs] def solve_multipoint( ode_func: Callable[[float, np.ndarray], np.ndarray], conditions: list[tuple[int, float, float]], order: int, x_min: float, x_max: float, method: str | None = None, t_eval: np.ndarray | None = None, max_step: float | None = None, rtol: float | None = None, atol: float | None = None, ) -> ODESolution: """Solve an ODE with initial conditions specified at possibly different x points. Uses a shooting method: the full state at ``x_min`` is found via root-finding so that all given conditions ``y^(k)(x_i) = a_i`` are satisfied. Args: ode_func: Right-hand side of the ODE system ``dy/dx = f(x, y)``. conditions: List of ``(k, x_i, a_i)`` meaning ``y^(k)(x_i) = a_i``. order: ODE order (equals number of conditions). x_min: Domain start. x_max: Domain end. method: Integration method name. t_eval: Points at which to store the final solution. max_step: Maximum step size. rtol: Relative tolerance. atol: Absolute tolerance. Returns: An :class:`ODESolution` with the results. Raises: SolverFailedError: If the solver or shooting method fails. """ if t_eval is None: method, effective_max_step, rtol, atol, t_eval = _resolve_solver_params( method, max_step, rtol, atol, (x_min, x_max) ) else: # If t_eval is provided, still resolve other params method = method or get_default_solver_method() max_step = max_step or get_env_from_schema("SOLVER_MAX_STEP") rtol = rtol or get_env_from_schema("SOLVER_RTOL") atol = atol or get_env_from_schema("SOLVER_ATOL") effective_max_step = np.inf if max_step <= 0 else max_step all_at_start = all(abs(xi - x_min) < 1e-12 for (_, xi, _) in conditions) if all_at_start: y0 = [a for (_, _, a) in sorted(conditions, key=lambda c: c[0])] return solve_ode( ode_func, (x_min, x_max), y0, method=method, t_eval=t_eval, max_step=max_step, rtol=rtol, atol=atol, ) y0_guess = np.zeros(order) for k, _xi, ai in conditions: y0_guess[k] = ai x_max_needed = max(max(xi for (_, xi, _) in conditions), x_max) n_fine = max(2000, len(t_eval) * 2) t_eval_fine = np.linspace(x_min, x_max_needed, n_fine) def _residuals(y0: np.ndarray) -> np.ndarray: sol = solve_ivp( ode_func, (x_min, x_max_needed), y0.tolist(), method=method, t_eval=t_eval_fine, max_step=effective_max_step, rtol=rtol, atol=atol, dense_output=True, ) if not sol.success: return np.full(len(conditions), 1e10) return np.array([np.interp(xi, sol.t, sol.y[k]) - ai for (k, xi, ai) in conditions]) y0_opt, _, ier, mesg = fsolve(_residuals, y0_guess, full_output=True) if ier != 1: raise SolverFailedError(f"Shooting method did not converge: {mesg}") logger.info("Shooting method converged; y0_opt=%s", y0_opt.tolist()) return solve_ode( ode_func, (x_min, x_max), y0_opt.tolist(), method=method, t_eval=t_eval, max_step=max_step, rtol=rtol, atol=atol, )