Source code for solver.difference_solver

"""Solver for difference equations (recurrence relations)."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

import numpy as np

from utils import get_logger

# Module-level logger from the project's ``utils.get_logger`` helper;
# ``solve_difference`` uses it to report successful and failed iterations.
logger = get_logger(__name__)


@dataclass
class DifferenceSolution:
    """Result of iterating a difference equation.

    Attributes:
        n: Discrete index values at which the solution was recorded.
        y: Solution values laid out as ``(n_vars, n_points)`` -- one row per
            state component, one column per entry of ``n``.
        success: ``True`` when the iteration finished without an error.
        message: Human-readable status or error text.
    """

    n: np.ndarray  # index values, one per recorded point
    y: np.ndarray  # (n_vars, n_points) solution matrix
    success: bool  # iteration outcome flag
    message: str  # status / error description
def solve_difference(
    recur_func: Callable[[int, np.ndarray], float],
    n_min: int,
    n_max: int,
    y0: list[float],
    order: int,
) -> DifferenceSolution:
    """Solve a difference equation by iterating the recurrence.

    The recurrence has the form::

        y_{n+order} = f(n, [y_n, y_{n+1}, ..., y_{n+order-1}])

    State vector at step n: ``y[0] = y_n``, ``y[1] = y_{n+1}``, ...,
    ``y[order-1] = y_{n+order-1}``.

    Args:
        recur_func: Function ``(n, state) -> y_{n+order}``. It receives a
            copy of the current state vector, so mutating its argument
            cannot corrupt the iteration.
        n_min: Start index (inclusive); ``y0`` is the state at this index.
        n_max: End index (inclusive); must be greater than ``n_min``.
        y0: Initial conditions ``[y_{n_min}, ..., y_{n_min+order-1}]``
            (i.e. ``[y_0, ..., y_{order-1}]`` when ``n_min == 0``). Must
            contain at least ``order`` values; extra values are ignored.
        order: Order of the recurrence; must be at least 1.

    Returns:
        A :class:`DifferenceSolution` with n and y arrays. ``y`` has shape
        ``(order, n_points)`` for compatibility with the ODE pipeline;
        column ``i`` is the state vector at index ``n_min + i``.

        Invalid arguments yield ``success=False``, empty arrays, and a
        message describing the problem. If ``recur_func`` (or converting
        its result to ``float``) raises, the error is logged and the result
        holds only the points computed before the failure, with
        ``success=False`` and the exception text as ``message``.
    """
    # Argument validation: report problems through the result object rather
    # than raising, like the original n_min/n_max check.
    problem = None
    if n_min >= n_max:
        problem = "n_min must be less than n_max"
    elif order < 1:
        problem = f"order must be at least 1, got {order}"
    elif len(y0) < order:
        problem = (
            f"y0 must contain at least order={order} values, got {len(y0)}"
        )
    if problem is not None:
        return DifferenceSolution(
            n=np.array([]),
            y=np.array([]).reshape(0, 0),
            success=False,
            message=problem,
        )

    n_points = n_max - n_min + 1
    n_arr = np.arange(n_min, n_max + 1, dtype=float)
    y_arr = np.zeros((order, n_points))

    state = np.array(y0[:order], dtype=float)
    y_arr[:, 0] = state
    # Number of leading columns of y_arr that hold computed values; used to
    # trim the partial result if the recurrence raises.
    filled = 1

    try:
        for i in range(1, n_points):
            # Column i-1 holds the state at n_min + i - 1; the recurrence
            # evaluated there yields the newest element of column i.
            n_curr = n_min + i - 1
            next_val = float(recur_func(n_curr, state.copy()))
            state = np.roll(state, -1)
            state[-1] = next_val
            y_arr[:, i] = state
            filled = i + 1
    except Exception as exc:
        logger.error("Difference equation iteration failed: %s", exc)
        # Only return columns that were actually computed; column `filled`
        # onwards is still the zero initialisation.
        return DifferenceSolution(
            n=n_arr[:filled],
            y=y_arr[:, :filled],
            success=False,
            message=str(exc),
        )

    logger.info(
        "Difference equation solved: %d points from n=%d to n=%d",
        n_points,
        n_min,
        n_max,
    )
    return DifferenceSolution(
        n=n_arr,
        y=y_arr,
        success=True,
        message="Solved successfully",
    )