# Source code for solver.equation_parser

"""Safe parsing and evaluation of user-written ODE expressions.

Expressions may use either the legacy ``y[k]`` notation or the unified
``f[...]`` notation.  When ``f`` tokens are present they are automatically
rewritten to ``y[...]`` via :mod:`solver.notation` before compilation.
"""

from __future__ import annotations

import re
from typing import Any, Callable

import numpy as np

from solver.notation import FNotation, _rewrite_f_expression
from utils import (
    EquationParseError,
    build_eval_namespace,
    get_logger,
    normalize_params,
    normalize_unicode_escapes,
    safe_eval,
    validate_exclusive_args,
    validate_expression_ast,
)

logger = get_logger(__name__)


def _maybe_rewrite(expression: str, notation: FNotation | None) -> str:
    """Translate ``f[...]`` tokens to ``y[...]`` when a notation is supplied.

    Args:
        expression: Expression string, possibly written in f-notation.
        notation: Notation context passed to the rewriter; ``None`` disables
            rewriting altogether.

    Returns:
        The result of ``_rewrite_f_expression`` when ``notation`` is given,
        otherwise ``expression`` exactly as received.
    """
    if notation is None:
        # No context means nothing to rewrite against; hand the input back.
        return expression
    return _rewrite_f_expression(expression, notation)


def _compile_and_test(
    expression: str,
    namespace: dict[str, Any],
    var_names: str | tuple[str, ...] = ("x", "y"),
    test_values: dict[str, Any] | None = None,
) -> Any:
    """Compile an expression and smoke-test it with placeholder variable values.

    The expression is compiled in ``eval`` mode and then evaluated once via
    ``safe_eval`` against a copy of ``namespace`` extended with a test value
    for each name in ``var_names``.

    Test values are chosen as follows:

    * a name present in ``test_values`` uses that value;
    * ``"x"`` defaults to ``0.0`` and ``"n"`` to ``0``;
    * ``"y"`` defaults to ``np.zeros(size)`` where ``size`` is
      ``test_values["y_size"]`` if provided, else ``1``;
    * any other name without an override is left unbound.

    Args:
        expression: Python expression string.
        namespace: Base namespace for evaluation; it is copied, not mutated.
        var_names: Variable names to bind for the test evaluation (a single
            string or a tuple of strings).
        test_values: Optional overrides for test values, plus the optional
            ``"y_size"`` key described above.

    Returns:
        Compiled code object.

    Raises:
        EquationParseError: If compilation or the test evaluation fails.
    """
    # compile() raises SyntaxError for invalid source and ValueError for
    # source containing null bytes; surface both as the documented error type.
    try:
        compiled = compile(expression, "<expression>", "eval")
    except (SyntaxError, ValueError) as exc:
        raise EquationParseError(f"Expression compilation failed: {exc}") from exc

    overrides = test_values or {}
    if isinstance(var_names, str):
        var_names = (var_names,)

    # Work on a copy so the caller's namespace is never polluted with test values.
    test_ns = dict(namespace)
    for var_name in var_names:
        if var_name in overrides:
            test_ns[var_name] = overrides[var_name]
        elif var_name == "x":
            test_ns[var_name] = 0.0
        elif var_name == "n":
            test_ns[var_name] = 0
        elif var_name == "y":
            test_ns[var_name] = np.zeros(overrides.get("y_size", 1))

    try:
        safe_eval(compiled, test_ns)
    except Exception as exc:
        raise EquationParseError(f"Expression evaluation failed: {exc}") from exc

    return compiled


def _load_config_function(function_name: str, module_name: str = "config.equations") -> Callable:
    """Import ``module_name`` and return its attribute ``function_name``.

    Args:
        function_name: Name of the attribute to fetch from the module.
        module_name: Dotted module path (default: "config.equations").

    Returns:
        The attribute, once confirmed to be callable.

    Raises:
        EquationParseError: If the module cannot be imported, the attribute
            does not exist, or the attribute is not callable.
    """
    import importlib

    try:
        loaded_module = importlib.import_module(module_name)
    except ImportError as exc:
        raise EquationParseError(f"Cannot import {module_name}: {exc}") from exc

    # A private sentinel distinguishes "attribute missing" from any real value.
    absent = object()
    candidate = getattr(loaded_module, function_name, absent)
    if candidate is absent:
        raise EquationParseError(f"Function '{function_name}' not found in {module_name}")
    if not callable(candidate):
        raise EquationParseError(f"'{function_name}' in {module_name} is not callable")

    return candidate


def _parse_expression(
    expression: str,
    order: int,
    parameters: dict[str, float] | None = None,
    notation: FNotation | None = None,
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Turn a scalar ODE expression into a callable ``f(x, y) -> dy/dx``.

    The expression gives the highest derivative and may be written with
    ``f[k]`` notation (rewritten automatically) or legacy ``y[k]`` notation.

    Args:
        expression: Python expression string for the highest derivative.
        order: Order of the ODE (1, 2, ...).
        parameters: Named parameter values (e.g. ``{"omega": 2.0}``).
        notation: Notation context for ``f[...]`` rewriting. When ``None``,
            an ``FNotation(kind="ode", n_components=1, order=order)`` is used.

    Returns:
        A callable ``f(x, y)`` returning a 1-D array of length ``order``:
        entries ``0 .. order-2`` are ``y[1] .. y[order-1]`` and the last
        entry is the evaluated expression converted to ``float``.

    Raises:
        EquationParseError: If the expression is invalid.
    """
    expression = normalize_unicode_escapes(expression)
    active_notation = (
        FNotation(kind="ode", n_components=1, order=order) if notation is None else notation
    )
    expression = _maybe_rewrite(expression, active_notation)
    validate_expression_ast(expression, "ODE expression")

    params = normalize_params(parameters)
    logger.debug("Parsing expression (order=%d): %s, params=%s", order, expression, params)

    namespace = build_eval_namespace(params)
    # Test-evaluate with a zero state vector sized to the ODE order.
    code = _compile_and_test(
        expression,
        namespace,
        var_names=("x", "y"),
        test_values={"y_size": order},
    )

    def ode_func(x: float, y: np.ndarray) -> np.ndarray:
        top_derivative = safe_eval(code, {**namespace, "x": x, "y": y})
        dydt = np.empty(order)
        # First-order reduction: each lower slot is the next state component.
        for j in range(1, order):
            dydt[j - 1] = y[j]
        dydt[order - 1] = float(top_derivative)
        return dydt

    return ode_func


def get_ode_function(
    *,
    expression: str | None = None,
    function_name: str | None = None,
    order: int,
    parameters: dict[str, float] | None = None,
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Build an ODE right-hand side from an expression string or a named function.

    Exactly one of ``expression`` or ``function_name`` must be provided.

    Args:
        expression: Python expression for the highest derivative.
        function_name: Name of a function in ``config.equations`` to import.
        order: ODE order (1, 2, ...).
        parameters: Named parameter values.

    Returns:
        A callable ``f(x, y)``. For an expression this is the parsed ODE
        function; for a named function it is a wrapper that calls
        ``func(x, y, **params)`` with the normalized parameters.

    Raises:
        EquationParseError: If expression is invalid or function cannot be resolved.
        ValueError: If neither or both expression and function_name are provided.
    """
    params = normalize_params(parameters)
    validate_exclusive_args(expression, function_name, "expression", "function_name")

    if expression is None:
        assert function_name is not None  # Guaranteed by validation above
        user_func = _load_config_function(function_name, "config.equations")

        def ode_func(x: float, y: np.ndarray) -> np.ndarray:
            return user_func(x, y, **params)

        return ode_func

    return _parse_expression(expression, order, params)
def _parse_difference_expression(
    expression: str,
    order: int,
    parameters: dict[str, float] | None = None,
    notation: FNotation | None = None,
) -> Callable[[int, np.ndarray], float]:
    """Turn a recurrence expression into a callable ``f(n, y) -> y_next``.

    The expression may be written with ``f[k]`` notation (rewritten
    automatically) or legacy ``y[k]`` notation.

    Args:
        expression: Python expression string for the next value.
        order: Order of the recurrence (1, 2, ...).
        parameters: Named parameter values.
        notation: Notation context for ``f[...]`` rewriting. When ``None``,
            an ``FNotation(kind="difference", n_components=1, order=order)``
            is used.

    Returns:
        A callable ``f(n, y)`` returning the evaluated expression as ``float``.

    Raises:
        EquationParseError: If the expression is invalid.
    """
    expression = normalize_unicode_escapes(expression)
    active_notation = (
        FNotation(kind="difference", n_components=1, order=order)
        if notation is None
        else notation
    )
    expression = _maybe_rewrite(expression, active_notation)
    validate_expression_ast(expression, "difference expression")

    params = normalize_params(parameters)
    logger.debug(
        "Parsing difference expression (order=%d): %s, params=%s",
        order,
        expression,
        params,
    )

    namespace = build_eval_namespace(params)
    # Test-evaluate with n = 0 and a zero history vector sized to the order.
    code = _compile_and_test(
        expression,
        namespace,
        var_names=("n", "y"),
        test_values={"y_size": order},
    )

    def recur_func(n: int, y: np.ndarray) -> float:
        value = safe_eval(code, {**namespace, "n": n, "y": y})
        return float(value)

    return recur_func
def get_difference_function(
    *,
    expression: str | None = None,
    function_name: str | None = None,
    order: int,
    parameters: dict[str, float] | None = None,
) -> Callable[[int, np.ndarray], float]:
    """Build a recurrence function from an expression string or a named function.

    Exactly one of ``expression`` or ``function_name`` must be provided.

    A named function is looked up in ``config.difference_equations``; the
    lookup falls back to ``config.equations`` only when the former cannot be
    imported (not when the function is merely missing from it).

    Args:
        expression: Python expression for y_{n+order}.
        function_name: Name of the function to import from the config module.
        order: Recurrence order (1, 2, ...).
        parameters: Named parameter values.

    Returns:
        A callable ``f(n, y)`` returning the next value as ``float``. For a
        named function it wraps ``func(n, y, **params)``.

    Raises:
        EquationParseError: If expression is invalid or function cannot be resolved.
        ValueError: If neither or both expression and function_name are provided.
    """
    params = normalize_params(parameters)
    validate_exclusive_args(expression, function_name, "expression", "function_name")

    if expression is not None:
        return _parse_difference_expression(expression, order, params)

    assert function_name is not None  # Guaranteed by validation above

    try:
        from config import difference_equations as diff_module
    except ImportError:
        # Dedicated module unavailable: try the shared equations module instead.
        try:
            from config import equations as diff_module
        except ImportError as exc:
            raise EquationParseError(
                f"Cannot import config.difference_equations or config.equations: {exc}"
            ) from exc

    # A private sentinel distinguishes "attribute missing" from any real value.
    absent = object()
    user_func = getattr(diff_module, function_name, absent)
    if user_func is absent:
        raise EquationParseError(f"Function '{function_name}' not found in config")
    if not callable(user_func):
        raise EquationParseError(f"'{function_name}' is not callable")

    def recur_func(n: int, y: np.ndarray) -> float:
        return float(user_func(n, y, **params))

    return recur_func
# Plain variable names that indexed notation x[0] .. x[3] maps onto.
_INDEXED_VAR_NAMES = ["x", "y", "z", "w"]
# Matches x[0] .. x[3]; whitespace inside the brackets is tolerated so that
# spellings such as ``x[ 0 ]`` are rewritten as well.
_INDEXED_VAR_RE = re.compile(r"\bx\[\s*([0-3])\s*\]")

# PDE RHS notation: f[k] = f_{x[k]}, f[i,j] = f_{x[i],x[j]}
# x[0]=x, x[1]=y. So f[0]=fx, f[1]=fy, f[0,0]=fxx, f[0,1]=fxy, f[1,0]=fxy, f[1,1]=fyy
# Bare f (no brackets) = solution value
_PDE_F_SINGLE: dict[int, str] = {0: "fx", 1: "fy"}
_PDE_F_DOUBLE: dict[tuple[int, int], str] = {
    (0, 0): "fxx",
    (0, 1): "fxy",
    (1, 0): "fxy",
    (1, 1): "fyy",
}
# Optional whitespace around the indices and the comma is accepted, so the
# conventional spelling ``f[0, 1]`` is recognised alongside ``f[0,1]``.
_PDE_F_DOUBLE_RE = re.compile(r"\bf\[\s*([0-1])\s*,\s*([0-1])\s*\]")
_PDE_F_SINGLE_RE = re.compile(r"\bf\[\s*([0-1])\s*\]")


def _rewrite_pde_f_notation(expression: str) -> str:
    """Rewrite f[k], f[i,j] to fx, fy, fxx, fxy, fyy in PDE RHS context.

    Notation: f[k] = f_{x[k]}, f[i,j] = f_{x[i],x[j]}; both mixed orderings
    map to ``fxy``. Bare f (no brackets) is left untouched. Only indices 0
    and 1 are recognised; anything else (e.g. ``f[2]``) is left unchanged.
    Whitespace inside the brackets is ignored.

    Args:
        expression: PDE RHS expression string.

    Returns:
        Expression with f-notation rewritten to derivative names.
    """

    def _replace_double(m: re.Match) -> str:
        # The pattern only admits indices 0/1, so the key is always present.
        return _PDE_F_DOUBLE[int(m.group(1)), int(m.group(2))]

    def _replace_single(m: re.Match) -> str:
        return _PDE_F_SINGLE[int(m.group(1))]

    # Second-order tokens are replaced first, then first-order ones.
    expr = _PDE_F_DOUBLE_RE.sub(_replace_double, expression)
    return _PDE_F_SINGLE_RE.sub(_replace_single, expr)


def _rewrite_indexed_vars(expression: str) -> str:
    """Rewrite indexed variable notation ``x[0]``, ``x[1]``, ... to named variables.

    Maps ``x[0]`` -> ``x``, ``x[1]`` -> ``y``, ``x[2]`` -> ``z``, ``x[3]`` -> ``w``.
    Indices outside 0-3 are left unchanged. Whitespace inside the brackets is
    ignored.

    Args:
        expression: Expression string with indexed variables.

    Returns:
        Expression with indexed vars replaced by names.
    """
    return _INDEXED_VAR_RE.sub(
        lambda m: _INDEXED_VAR_NAMES[int(m.group(1))],
        expression,
    )
def parse_pde_rhs_expression(
    expression: str,
    variables: list[str],
    parameters: dict[str, float] | None = None,
) -> Callable[..., float]:
    """Parse a PDE RHS expression into a callable f(x, y, ...) -> float.

    The expression can use variable names (x, y, z, ...) or indexed notation
    (x[0], x[1], ...) and parameters. ``f[k]`` / ``f[i,j]`` tokens are
    rewritten to ``fx``, ``fy``, ``fxx``, ``fxy``, ``fyy`` before compilation.

    Entries of ``variables`` that start with ``"x["`` are mapped, by their
    position in the list, onto the plain names ``x``, ``y``, ``z``, ``w``;
    such entries at position 4 or later have no plain-name equivalent and are
    not bound. Plain-named entries are always bound, whatever their position.

    Args:
        expression: Python expression string (e.g. ``"k"`` or ``"x[0] * x[1]"``).
        variables: List of variable names (e.g. ``["x", "y"]`` or ``["x[0]", "x[1]"]``).
        parameters: Named parameter values.

    Returns:
        A callable taking the variable values positionally, in the order of
        ``variables``; surplus positional arguments are ignored. Keyword
        arguments are merged into the evaluation namespace (presumably used
        to supply ``f``, ``fx``, ... - confirm against callers). The result
        is converted to ``float``.

    Raises:
        EquationParseError: If the expression is invalid.
    """
    expression = normalize_unicode_escapes(expression)
    expression = _rewrite_indexed_vars(expression)
    expression = _rewrite_pde_f_notation(expression)

    # Pair each usable variable with its positional-argument index so that a
    # skipped entry never shifts the binding of the variables that follow it.
    bindings: list[tuple[int, str]] = []
    for position, name in enumerate(variables):
        if not name.startswith("x["):
            bindings.append((position, name))
        elif position < len(_INDEXED_VAR_NAMES):
            bindings.append((position, _INDEXED_VAR_NAMES[position]))
    internal_vars = [name for _, name in bindings]

    validate_expression_ast(expression, "PDE RHS")
    params = normalize_params(parameters)
    logger.debug(
        "Parsing PDE RHS expression: %s, variables=%s, internal_vars=%s, params=%s",
        expression,
        variables,
        internal_vars,
        params,
    )

    namespace = build_eval_namespace(params)
    pde_solution_vars = ("f", "fx", "fy", "fxx", "fxy", "fyy")
    # Every variable and every solution/derivative name is test-evaluated at 0.0.
    test_values: dict[str, Any] = {var: 0.0 for var in internal_vars}
    test_values.update({v: 0.0 for v in pde_solution_vars})
    compiled = _compile_and_test(
        expression,
        namespace,
        var_names=tuple(internal_vars) + pde_solution_vars,
        test_values=test_values,
    )

    def rhs_func(*args: float, **kwargs: Any) -> float:
        local_ns = {**namespace, **kwargs}
        # Positional values are applied after kwargs, so they win on a clash.
        for position, var in bindings:
            if position < len(args):
                local_ns[var] = args[position]
        return float(safe_eval(compiled, local_ns))

    return rhs_func
def _parse_vector_expression(
    expressions: list[str],
    order: int,
    parameters: dict[str, float] | None = None,
    notation: FNotation | None = None,
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Turn a list of ODE expressions into a single vector ODE callable.

    Expressions may use ``f[i,k]`` notation (rewritten automatically) or
    legacy ``y[j]`` flat indexing. The state vector has length
    ``len(expressions) * order``, with component ``i`` occupying the slots
    ``i*order .. i*order + order - 1``.

    Args:
        expressions: List of Python expressions, one per component.
        order: Order of each ODE (1, 2, ...).
        parameters: Named parameter values.
        notation: Notation context for ``f[...]`` rewriting. When ``None``,
            an ``FNotation(kind="vector_ode", ...)`` sized to the input is used.

    Returns:
        A callable f(x, y) that returns dy/dx as a 1-D array.

    Raises:
        EquationParseError: If ``expressions`` is empty or the test
            evaluation of any expression fails.
    """
    n_components = len(expressions)
    if not n_components:
        raise EquationParseError("vector_expressions must have at least one expression")

    active_notation = (
        FNotation(kind="vector_ode", n_components=n_components, order=order)
        if notation is None
        else notation
    )
    params = normalize_params(parameters)
    namespace = build_eval_namespace(params)

    def _prepare(index: int, raw: str) -> Any:
        # Normalise, rewrite, validate, then compile one component expression.
        text = _maybe_rewrite(normalize_unicode_escapes(raw), active_notation)
        validate_expression_ast(text, f"vector expression {index}")
        return compile(text, f"<vector_ode_{index}>", "eval")

    compiled_list: list[Any] = [_prepare(i, raw) for i, raw in enumerate(expressions)]

    state_size = n_components * order

    # Test-evaluate every component once against an all-zero state at x = 0.
    probe_ns = {**namespace, "x": 0.0, "y": np.zeros(state_size)}
    for i, code in enumerate(compiled_list):
        try:
            safe_eval(code, probe_ns)
        except Exception as exc:
            raise EquationParseError(f"Expression {i} evaluation failed: {exc}") from exc

    def ode_func(x: float, y: np.ndarray) -> np.ndarray:
        dydt = np.empty(state_size)
        local_ns = {**namespace, "x": x, "y": y}
        for i, code in enumerate(compiled_list):
            base = i * order
            top = base + order - 1
            # Lower slots of a component take the next state entry ...
            for slot in range(base, top):
                dydt[slot] = y[slot + 1]
            # ... and its last slot takes the evaluated highest derivative.
            dydt[top] = float(safe_eval(code, local_ns))
        return dydt

    return ode_func
def get_vector_ode_function(
    *,
    vector_expressions: list[str],
    function_name: str | None = None,
    order: int,
    vector_components: int,
    parameters: dict[str, float] | None = None,
) -> Callable[[float, np.ndarray], np.ndarray]:
    """Build a vector ODE right-hand side from expressions or a named function.

    Exactly one of ``vector_expressions`` or ``function_name`` must be
    provided; an empty list or empty string counts as "not provided".

    Args:
        vector_expressions: List of expressions for each component's highest derivative.
        function_name: Name of function in ``config.equations``.
        order: Order of each ODE component.
        vector_components: Number of components; must equal the number of
            expressions when expressions are given.
        parameters: Named parameter values.

    Returns:
        A callable f(x, y) that returns dy/dx. For a named function it wraps
        ``func(x, y, **params)`` with the normalized parameters.

    Raises:
        ValueError: If both or neither of vector_expressions and function_name provided.
        EquationParseError: If the expression count does not match
            ``vector_components``, expressions are invalid, or the function
            is not found.
    """
    params = normalize_params(parameters)

    have_expressions = bool(vector_expressions)
    have_function = bool(function_name)
    if have_expressions and have_function:
        raise ValueError("Provide either vector_expressions or function_name, not both")
    if not (have_expressions or have_function):
        raise ValueError("Provide either vector_expressions or function_name")

    if not have_expressions:
        assert function_name is not None  # Guaranteed by validation above
        user_func = _load_config_function(function_name, "config.equations")

        def ode_func(x: float, y: np.ndarray) -> np.ndarray:
            return user_func(x, y, **params)

        return ode_func

    if len(vector_expressions) != vector_components:
        raise EquationParseError(
            f"vector_expressions length ({len(vector_expressions)}) "
            f"must match vector_components ({vector_components})"
        )
    return _parse_vector_expression(vector_expressions, order, params)
def _validate_expression(expression: str) -> list[str]:
    """Collect validation errors for an expression without evaluating it.

    Args:
        expression: Python expression string.

    Returns:
        List of error messages (empty if valid). An empty or
        whitespace-only expression yields ``["Expression is empty"]``.
    """
    from solver.notation import _preprocess_prime_notation

    if not (expression and expression.strip()):
        return ["Expression is empty"]

    try:
        # Preprocess f'/f'' notation before AST validation so that
        # Python's parser doesn't confuse f' with an f-string literal.
        cleaned = normalize_unicode_escapes(expression.strip())
        validate_expression_ast(_preprocess_prime_notation(cleaned), "expression")
    except EquationParseError as exc:
        return [str(exc)]

    return []