"""Input validation for DifferentialLab solver parameters."""
from __future__ import annotations
import math
from config import SOLVER_METHODS
from solver.equation_parser import _validate_expression
from utils import get_logger
logger = get_logger(__name__)

# Module-level constants
# Largest evaluation grid accepted by _validate_grid.
_MAX_GRID_POINTS = 10**6
# Unicode subscript digits U+2080..U+2089 ("₀".."₉"), used to label
# x₀, x₁, ... in initial-condition error messages.
_SUBSCRIPTS = "".join(chr(0x2080 + digit) for digit in range(10))
def _ordinal(n: int) -> str:
"""Convert an integer to its ordinal string representation.
Args:
n: Integer to convert.
Returns:
Ordinal string (e.g. "1st", "2nd", "3rd", "4th").
Examples:
1 -> "1st", 2 -> "2nd", 3 -> "3rd", 4 -> "4th", etc.
"""
if n % 100 in (11, 12, 13):
suffix = "th"
else:
remainder = n % 10
suffix = {1: "st", 2: "nd", 3: "rd"}.get(remainder, "th")
return f"{n}{suffix}"
def _is_finite(value: float) -> bool:
"""Return ``True`` if *value* is a finite number (not NaN, not ±inf).
Args:
value: Value to check.
Returns:
True if finite, False otherwise.
"""
try:
return math.isfinite(float(value))
except (TypeError, ValueError):
return False
def _validate_domain(x_min: float, x_max: float) -> list[str]:
    """Validate the integration domain.

    Both bounds must be finite numbers, and ``x_min`` must be strictly
    less than ``x_max``.  Finiteness is checked first; the ordering check
    only runs when both bounds are finite.

    Args:
        x_min: Start of the domain.
        x_max: End of the domain.

    Returns:
        List of error messages (empty if valid).
    """
    # Check finiteness before ordering for two reasons:
    # - comparing non-numeric bounds (e.g. None) would raise TypeError;
    # - comparing inf/NaN bounds yields a misleading ordering message.
    if not (_is_finite(x_min) and _is_finite(x_max)):
        return ["Domain bounds must be finite numbers"]
    # _is_finite has just shown float() accepts both bounds, so compare
    # them as floats; the message still shows the values as passed in.
    if float(x_min) >= float(x_max):
        return [f"x_min ({x_min}) must be less than x_max ({x_max})"]
    return []
def _validate_initial_conditions(y0: list[float], expected_order: int) -> list[str]:
    """Check the count and finiteness of the initial-condition values.

    An ODE of order *expected_order* needs exactly that many initial
    conditions, and every supplied value must be a finite number.

    Args:
        y0: Initial condition values.
        expected_order: The ODE order (determines how many ICs are needed).

    Returns:
        Error messages: a count mismatch first (if any), then one message
        per non-finite entry, in index order.  Empty when valid.
    """
    provided = len(y0)
    problems: list[str] = []
    if provided != expected_order:
        problems.append(
            f"Expected {expected_order} initial condition(s) for a "
            f"{_ordinal(expected_order)}-order ODE, got {provided}"
        )
    problems.extend(
        f"Initial condition y0[{idx}] = {entry} is not a finite number"
        for idx, entry in enumerate(y0)
        if not _is_finite(entry)
    )
    return problems
def _validate_grid(num_points: int) -> list[str]:
    """Check that the requested grid size lies within the accepted range.

    The lower bound is 10 points; the upper bound is ``_MAX_GRID_POINTS``.

    Args:
        num_points: Requested grid size.

    Returns:
        Error messages for each violated bound (empty if valid).
    """
    bound_checks = (
        (num_points < 10, "Number of points must be at least 10"),
        (
            num_points > _MAX_GRID_POINTS,
            f"Number of points must not exceed {_MAX_GRID_POINTS:,}",
        ),
    )
    return [message for violated, message in bound_checks if violated]
def _validate_method(method: str) -> list[str]:
    """Check that *method* is one of the configured solver methods.

    Args:
        method: Solver method name.

    Returns:
        An empty list when *method* is in ``SOLVER_METHODS``; otherwise a
        single message listing the valid choices.
    """
    if method in SOLVER_METHODS:
        return []
    choices = ", ".join(SOLVER_METHODS)
    return [f"Unknown method '{method}'. Choose from: {choices}"]
def _validate_parameters(params: dict[str, object]) -> list[str]:
"""Validate parameter values.
Args:
params: Parameter name-value mapping. Values may be scalars or
array-like (e.g. ``numpy.ndarray``) for list parameters.
Returns:
List of error messages (empty if valid).
"""
import numpy as _np
errors: list[str] = []
for name, value in params.items():
if isinstance(value, _np.ndarray):
if not _np.all(_np.isfinite(value)):
errors.append(f"Parameter '{name}' contains non-finite values")
elif not _is_finite(value):
errors.append(f"Parameter '{name}' = {value} is not a finite number")
return errors
def _validate_ic_points(
    x0_list: list[float],
    x_min: float,
    x_max: float,
) -> list[str]:
    """Validate per-derivative initial condition points.

    Each point must be a finite number inside the closed domain
    ``[x_min, x_max]``.  Points are labelled x₀, x₁, ... in messages, with
    every digit of the index rendered as a Unicode subscript.

    Args:
        x0_list: x_i value for each derivative condition.
        x_min: Domain start.
        x_max: Domain end.

    Returns:
        List of error messages (empty if valid), one per offending point.
    """
    # Translate every ASCII digit to its subscript form, so multi-digit
    # indices render as x₁₀ rather than falling back to plain "x10".
    to_subscript = str.maketrans("0123456789", _SUBSCRIPTS)
    errors: list[str] = []
    for i, xi in enumerate(x0_list):
        sub = str(i).translate(to_subscript)
        if not _is_finite(xi):
            errors.append(f"x{sub} = {xi} is not a finite number")
        elif not (x_min <= xi <= x_max):
            errors.append(f"x{sub} = {xi} must lie within the domain [{x_min}, {x_max}]")
    return errors