Source code for solver.validators

"""Input validation for DifferentialLab solver parameters."""

from __future__ import annotations

import math

from config import SOLVER_METHODS
from solver.equation_parser import _validate_expression
from utils import get_logger

# Module logger, obtained from the project's ``get_logger`` helper; used by
# ``validate_all_inputs`` to emit a warning when validation errors are found.
logger = get_logger(__name__)

# Module-level constants
# Upper bound on the number of grid points accepted by ``_validate_grid`` and
# by the difference-equation branch of ``validate_all_inputs``.
_MAX_GRID_POINTS = 1_000_000
# Unicode subscript digits 0-9, indexed by position; ``_validate_ic_points``
# uses them to label the i-th initial-condition point in error messages.
_SUBSCRIPTS = "₀₁₂₃₄₅₆₇₈₉"


def _ordinal(n: int) -> str:
    """Convert an integer to its ordinal string representation.

    Args:
        n: Integer to convert.

    Returns:
        Ordinal string (e.g. "1st", "2nd", "3rd", "4th").

    Examples:
        1 -> "1st", 2 -> "2nd", 3 -> "3rd", 4 -> "4th", etc.
    """
    if n % 100 in (11, 12, 13):
        suffix = "th"
    else:
        remainder = n % 10
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(remainder, "th")
    return f"{n}{suffix}"


def _is_finite(value: float) -> bool:
    """Return ``True`` if *value* is a finite number (not NaN, not ±inf).

    Args:
        value: Value to check.

    Returns:
        True if finite, False otherwise.
    """
    try:
        return math.isfinite(float(value))
    except (TypeError, ValueError):
        return False


def _validate_domain(x_min: float, x_max: float) -> list[str]:
    """Check that the integration domain is ordered and finite.

    Args:
        x_min: Start of the domain.
        x_max: End of the domain.

    Returns:
        List of error messages (empty if valid). Both checks are always
        evaluated, so a single call can report two problems.
    """
    # The ordering comparison is done first, then the finiteness check,
    # and the messages are appended in that same order.
    out_of_order = x_min >= x_max
    bounds_finite = _is_finite(x_min) and _is_finite(x_max)

    problems: list[str] = []
    if out_of_order:
        problems.append(f"x_min ({x_min}) must be less than x_max ({x_max})")
    if not bounds_finite:
        problems.append("Domain bounds must be finite numbers")
    return problems


def _validate_initial_conditions(y0: list[float], expected_order: int) -> list[str]:
    """Validate the initial conditions vector.

    Args:
        y0: Initial conditions values.
        expected_order: The ODE order (determines how many ICs are needed).

    Returns:
        List of error messages (empty if valid).
    """
    errors: list[str] = []
    if len(y0) != expected_order:
        errors.append(
            f"Expected {expected_order} initial condition(s) for a "
            f"{_ordinal(expected_order)}-order ODE, "
            f"got {len(y0)}"
        )
    for i, val in enumerate(y0):
        if not _is_finite(val):
            errors.append(f"Initial condition y0[{i}] = {val} is not a finite number")
    return errors


def _validate_grid(num_points: int) -> list[str]:
    """Check that the requested grid size lies within the accepted range.

    Args:
        num_points: Requested grid size.

    Returns:
        List of error messages (empty if valid). The size must be at least 10
        and must not exceed ``_MAX_GRID_POINTS``.
    """
    # Each entry pairs "did this check fail?" with the message to report.
    checks = (
        (num_points < 10, "Number of points must be at least 10"),
        (
            num_points > _MAX_GRID_POINTS,
            f"Number of points must not exceed {_MAX_GRID_POINTS:,}",
        ),
    )
    return [message for failed, message in checks if failed]


def _validate_method(method: str) -> list[str]:
    """Check that the solver method is one of ``SOLVER_METHODS``.

    Args:
        method: Solver method name.

    Returns:
        List of error messages (empty if valid). An unknown method yields a
        single message that lists the available choices.
    """
    if method in SOLVER_METHODS:
        return []

    choices = ", ".join(SOLVER_METHODS)
    return [f"Unknown method '{method}'. Choose from: {choices}"]


def _validate_parameters(params: dict[str, object]) -> list[str]:
    """Validate parameter values.

    Args:
        params: Parameter name-value mapping.  Values may be scalars or
            array-like (e.g. ``numpy.ndarray``) for list parameters.

    Returns:
        List of error messages (empty if valid).
    """
    import numpy as _np

    errors: list[str] = []
    for name, value in params.items():
        if isinstance(value, _np.ndarray):
            if not _np.all(_np.isfinite(value)):
                errors.append(f"Parameter '{name}' contains non-finite values")
        elif not _is_finite(value):
            errors.append(f"Parameter '{name}' = {value} is not a finite number")
    return errors


def _validate_ic_points(
    x0_list: list[float],
    x_min: float,
    x_max: float,
) -> list[str]:
    """Validate per-derivative initial condition points.

    Args:
        x0_list: x_i value for each derivative condition.
        x_min: Domain start.
        x_max: Domain end.

    Returns:
        List of error messages (empty if valid).
    """
    errors: list[str] = []
    for i, xi in enumerate(x0_list):
        sub = _SUBSCRIPTS[i] if i < len(_SUBSCRIPTS) else str(i)
        if not _is_finite(xi):
            errors.append(f"x{sub} = {xi} is not a finite number")
        elif not (x_min <= xi <= x_max):
            errors.append(f"x{sub} = {xi} must lie within the domain [{x_min}, {x_max}]")
    return errors


def validate_all_inputs(
    *,
    expression: str | None = None,
    function_name: str | None = None,
    order: int,
    x_min: float,
    x_max: float,
    y0: list[float],
    num_points: int,
    method: str,
    params: dict[str, float] | None = None,
    x0_list: list[float] | None = None,
    equation_type: str = "ode",
    vector_expressions: list[str] | None = None,
    vector_components: int = 1,
) -> list[str]:
    """Run all validations and return accumulated errors.

    Either expression or function_name must be provided (for vector
    equations: either vector_expressions or function_name).

    Args:
        expression: ODE/difference expression (optional).
        function_name: Name of function in config (optional).
        order: Equation order.
        x_min: Domain start (n_min for difference).
        x_max: Domain end (n_max for difference).
        y0: Initial conditions.
        num_points: Grid points (ODE only).
        method: Solver method (ODE only).
        params: Named parameters.
        x0_list: Per-derivative initial condition points (ODE only).
        equation_type: ``"ode"``, ``"difference"`` or ``"vector_ode"``.
        vector_expressions: One expression per vector component (optional).
            A non-empty list switches validation to vector mode regardless of
            *equation_type*.
        vector_components: Number of vector components; in vector mode the
            expected number of initial conditions is
            ``order * vector_components``.

    Returns:
        List of all error messages (empty if everything is valid).
    """
    errors: list[str] = []

    is_vector = (
        vector_expressions is not None and len(vector_expressions) > 0
    ) or equation_type == "vector_ode"
    expected_order = order * vector_components if is_vector else order

    # --- Equation source: exactly one of expression(s) / function_name ---
    if is_vector:
        if vector_expressions and function_name:
            errors.append("Provide either vector_expressions or function_name, not both")
        elif not vector_expressions and not function_name:
            errors.append("Vector ODE requires vector_expressions or function_name")
        elif vector_expressions:
            for i, expr in enumerate(vector_expressions):
                errors.extend(
                    f"Vector expression {i}: {e}" for e in _validate_expression(expr)
                )
        # else: function_name only — no expression validation needed
    else:
        if expression is not None and function_name is not None:
            errors.append("Provide either expression or function_name, not both")
        elif expression is None and function_name is None:
            errors.append("Provide either expression or function_name")
        elif expression is not None:
            errors.extend(_validate_expression(expression))

    # --- Domain and initial conditions (all equation types) ---
    errors.extend(_validate_domain(x_min, x_max))
    errors.extend(_validate_initial_conditions(y0, expected_order))

    # --- Grid / method checks ---
    if equation_type == "difference":
        # int() raises ValueError on NaN and OverflowError on ±inf. Such
        # bounds have already been reported by _validate_domain, so skip the
        # point count instead of letting the validator crash.
        if _is_finite(x_min) and _is_finite(x_max):
            n_min, n_max = int(x_min), int(x_max)
            n_points = n_max - n_min + 1
            if n_points < 2:
                errors.append("Difference equation needs at least 2 points (n_max > n_min)")
            if n_points > _MAX_GRID_POINTS:
                errors.append(f"Number of points must not exceed {_MAX_GRID_POINTS:,}")
    else:
        errors.extend(_validate_grid(num_points))
        errors.extend(_validate_method(method))
        if x0_list:
            errors.extend(_validate_ic_points(x0_list, x_min, x_max))

    if params:
        errors.extend(_validate_parameters(params))

    if errors:
        logger.warning("Validation errors: %s", errors)
    return errors