Source code for utils.expression_parser_shared

"""Shared constants and validation for safe expression parsing.

Used by solver.equation_parser and transforms.function_parser to avoid
duplication of SAFE_MATH, _ALLOWED_NODE_TYPES, and AST validation logic.
"""

from __future__ import annotations

import ast
import re
from typing import Any

import numpy as np

from utils.exceptions import EquationParseError
from utils.logger import get_logger

logger = get_logger(__name__)

# Matches a literal backslash, the letter ``u``, and exactly four hex digits;
# the digits are captured as group 1.
_UNICODE_ESCAPE_RE = re.compile(r"\\u([0-9A-Fa-f]{4})")


def normalize_unicode_escapes(text: str) -> str:
    """Expand textual ``\\uXXXX`` escapes into the characters they name.

    This lets an expression typed as ``\\u03C9**2 * y[0]`` be handled
    exactly like ``ω**2 * y[0]``. Text that contains no four-hex-digit
    escape is returned unchanged.

    Args:
        text: Input string that may contain ``\\uXXXX`` sequences.

    Returns:
        A copy of *text* in which every ``\\uXXXX`` sequence has been
        replaced by the character with that hexadecimal code point.
    """

    def _decode(match: re.Match[str]) -> str:
        # Group 1 holds the four hex digits of the code point.
        code_point = int(match.group(1), 16)
        return chr(code_point)

    return _UNICODE_ESCAPE_RE.sub(_decode, text)
# Names exposed to evaluated expressions: NumPy functions plus the
# constants ``pi`` and ``e``.
SAFE_MATH: dict[str, Any] = {
    "sin": np.sin,
    "cos": np.cos,
    "tan": np.tan,
    "exp": np.exp,
    "log": np.log,
    "log10": np.log10,
    "sqrt": np.sqrt,
    "abs": np.abs,
    "pi": np.pi,
    "e": np.e,
    "sinh": np.sinh,
    "cosh": np.cosh,
    "tanh": np.tanh,
    "arcsin": np.arcsin,
    "arccos": np.arccos,
    "arctan": np.arctan,
    "floor": np.floor,
    "ceil": np.ceil,
    "sign": np.sign,
    "heaviside": np.heaviside,
}

# Whitelist of AST node classes an expression may contain. Any node class
# that is not listed here is treated as a disallowed construct by the
# validation code that consults this set.
_ALLOWED_NODE_TYPES: frozenset[type[ast.AST]] = frozenset(
    {
        # Containers / wrappers produced by the parser.
        ast.Module,
        ast.Expr,
        ast.Expression,
        # Expression forms.
        ast.BinOp,
        ast.UnaryOp,
        ast.Call,
        ast.Name,
        ast.Constant,
        ast.Load,
        # Arithmetic operators.
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Pow,
        ast.USub,
        ast.UAdd,
        # Indexing and attribute access.
        ast.Subscript,
        ast.Attribute,
        ast.FloorDiv,
        ast.Mod,
        # Comparisons, conditionals and boolean logic.
        ast.Compare,
        ast.IfExp,
        ast.BoolOp,
        ast.And,
        ast.Or,
        ast.Eq,
        ast.NotEq,
        ast.Lt,
        ast.LtE,
        ast.Gt,
        ast.GtE,
        # Literal sequences.
        ast.Tuple,
        ast.List,
    }
)
def normalize_params(parameters: dict[str, float] | None) -> dict[str, float]:
    """Give back a fresh dict holding the entries of *parameters*.

    Args:
        parameters: Mapping of parameter names to values, or ``None``.

    Returns:
        A new empty dict when *parameters* is ``None`` or empty, otherwise
        a shallow copy, so the caller's mapping is never mutated through
        the returned object.
    """
    if not parameters:
        return {}
    return dict(parameters)
def build_eval_namespace(params: dict[str, float]) -> dict[str, Any]:
    """Merge SAFE_MATH with the user-supplied parameters into one namespace.

    Args:
        params: User parameter names mapped to their values.

    Returns:
        A new dict containing every SAFE_MATH entry followed by *params*.
        Because *params* is applied last, a parameter whose name matches a
        SAFE_MATH key (for example ``"e"``) replaces that entry.
    """
    namespace: dict[str, Any] = dict(SAFE_MATH)
    namespace.update(params)
    return namespace
def safe_eval(compiled: Any, namespace: dict[str, Any]) -> Any:
    """Evaluate *compiled* with no builtins and *namespace* as its locals.

    Args:
        compiled: Code object (or expression source) to hand to ``eval``.
        namespace: Names visible to the expression.

    Returns:
        Whatever the expression evaluates to.
    """
    # NOTE(review): emptying ``__builtins__`` does not make ``eval`` safe on
    # its own; presumably callers validate the expression first -- confirm.
    restricted_globals: dict[str, Any] = {"__builtins__": {}}
    return eval(compiled, restricted_globals, namespace)  # noqa: S307
def validate_exclusive_args(
    arg_a: object | None,
    arg_b: object | None,
    name_a: str,
    name_b: str,
) -> None:
    """Require that exactly one of two alternative arguments was supplied.

    An argument counts as supplied when it is not ``None``; falsy values
    such as ``0`` or ``""`` are treated as supplied.

    Args:
        arg_a: First alternative.
        arg_b: Second alternative.
        name_a: Name of the first alternative, used in error messages.
        name_b: Name of the second alternative, used in error messages.

    Raises:
        ValueError: If both arguments are supplied, or neither is.
    """
    a_given = arg_a is not None
    b_given = arg_b is not None

    if a_given and b_given:
        raise ValueError(f"Provide either {name_a} or {name_b}, not both")
    if not (a_given or b_given):
        raise ValueError(f"Provide either {name_a} or {name_b}")
def validate_expression_ast(expression: str, context: str = "expression") -> None:
    """Check that the expression contains only allowed AST nodes.

    The expression is parsed in ``eval`` mode and every node of the
    resulting tree is compared, by exact type, against
    ``_ALLOWED_NODE_TYPES``. Attribute access to double-underscore names
    (``__class__``, ``__globals__``, ...) is rejected as well, because the
    whitelist permits ``Attribute``, ``Call`` and ``Subscript`` nodes, which
    would otherwise let an expression walk from any object to arbitrary
    classes.

    Args:
        expression: Python expression string to validate.
        context: Short description for error messages (e.g. "ODE expression").

    Raises:
        EquationParseError: If the expression cannot be parsed or contains
            disallowed constructs.
    """
    try:
        tree = ast.parse(expression, mode="eval")
    except (SyntaxError, ValueError, RecursionError) as exc:
        # Besides SyntaxError, ast.parse can raise ValueError (NUL bytes on
        # older Pythons, UnicodeEncodeError for lone surrogates) and
        # RecursionError (deeply nested input). All of them mean "this is not
        # a usable expression", so report them through the same error type.
        logger.debug("%s syntax error in %r: %s", context, expression[:80], exc)
        raise EquationParseError(f"Syntax error in expression: {exc}") from exc

    for node in ast.walk(tree):
        node_type = type(node)
        # Exact-type membership on purpose: this is a whitelist, so subclasses
        # of allowed node classes are not implicitly accepted.
        if node_type not in _ALLOWED_NODE_TYPES:
            raise EquationParseError(
                f"Disallowed construct in {context}: {node_type.__name__}"
            )
        # Block introspection chains such as ``().__class__.__base__``.
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            raise EquationParseError(
                f"Disallowed construct in {context}: attribute {node.attr!r}"
            )