"""Load and manage predefined ODE equations from YAML."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import yaml
from utils import get_logger
# Closed set of equation kinds; YAML entries without an explicit
# ``equation_type`` get one inferred from the stem of the file they came from.
EquationType = Literal["ode", "difference", "pde", "vector_ode"]

logger = get_logger(__name__)

# Equation YAML files live in ``config/equations`` under the directory that
# contains this module's parent directory.
_EQUATIONS_DIR = Path(__file__).resolve().parent.parent.joinpath("config", "equations")

# Files are read in this order; it determines the insertion order of the
# loaded equations dict.
_EQUATION_FILES = [f"{stem}.yaml" for stem in ("ode", "vector_ode", "difference", "pde")]

# Populated by load_predefined_equations() after its first successful load.
_cache: dict[str, PredefinedEquation] | None = None
[docs]
@dataclass
class PredefinedEquation:
    """A ready-made equation definition read from one of the YAML config files.

    Four kinds are supported: scalar ODEs, vector (system) ODEs, difference
    equations and PDEs. ``formula`` is always needed because it is what gets
    displayed. To be executable, an entry needs either ``expression`` or
    ``function_name``. When ``function_name`` is given, the callable is
    imported from ``config.equations`` rather than evaluating ``expression``.
    Vector ODEs supply ``vector_expressions`` or ``function_name``.

    Attributes:
        key: Unique identifier, taken from the YAML mapping key.
        name: Human-readable name.
        formula: Compact display string, e.g. ``"y'' + ω²y = 0"``.
        description: Longer, possibly multi-line explanation with context.
        order: Equation order (1, 2, …) for ODEs and difference equations;
            for vector ODEs, the order of each component.
        parameters: Mapping of parameter name to ``{default, description}``.
        expression: Python expression used for execution; may be ``None``
            when ``function_name`` is set.
        function_name: Name of a function in ``config.equations`` to import.
        default_initial_conditions: Default ``y0`` vector.
        default_domain: ``[x_min, x_max]`` for ODEs, ``[n_min, n_max]`` for
            difference equations, or ``[x_min, x_max, y_min, y_max, ...]``
            (one pair per variable) for PDEs. Defaults to ``[0.0, 10.0]``.
        vector_expressions: Per-component expressions for vector ODEs.
        vector_components: Number of components ``f_0, f_1, ...`` of a
            vector ODE; defaults to 1.
        equation_type: One of ``"ode"``, ``"difference"``, ``"pde"`` or
            ``"vector_ode"``; defaults to ``"ode"``.
        category: UI grouping label such as ``"Oscillators"`` or
            ``"Population"``.
        variables: Independent variable names, e.g. ``["x"]`` (1D) or
            ``["x", "y"]`` (2D). When absent or ``["x"]``, the equation is
            treated as a 1D ODE.
        partial_derivatives: PDE only. Maps a derivative key (``"f_xx"``,
            ``"f_xy"``, …) to its expression string.
    """

    # Required fields. Their order defines the positional constructor
    # signature, so it must not change.
    key: str
    name: str
    formula: str
    description: str
    order: int
    parameters: dict[str, dict[str, Any]]
    expression: str | None
    function_name: str | None
    default_initial_conditions: list[float]

    # Optional fields. The list defaults come from bound ``copy`` methods of
    # template lists, so every instance receives its own fresh list.
    default_domain: list[float] = field(default_factory=[0.0, 10.0].copy)
    vector_expressions: list[str] | None = None
    vector_components: int = 1
    equation_type: EquationType = "ode"
    category: str = "Oscillators"
    variables: list[str] = field(default_factory=["x"].copy)
    partial_derivatives: dict[str, str] | None = None
[docs]
def _value_or(data: dict[str, Any], key: str, default: Any) -> Any:
    """Return ``data[key]``, or *default* when the key is absent or explicitly null."""
    # YAML turns an empty value (e.g. ``parameters:`` with nothing under it)
    # into None, which must not override the default.
    value = data.get(key)
    return default if value is None else value


def _read_raw_equations() -> dict[str, dict[str, Any]]:
    """Read every equation file and merge their entries into one ordered dict.

    Entries lacking ``equation_type`` (or setting it to null) get the type
    inferred from the file stem, e.g. ``pde.yaml`` -> ``"pde"``.

    Raises:
        FileNotFoundError: If the equations directory or any expected file is missing.
        ValueError: If a file's top-level YAML value is not a mapping.
    """
    if not _EQUATIONS_DIR.exists():
        logger.error("Equations directory not found: %s", _EQUATIONS_DIR)
        raise FileNotFoundError(f"Equations directory not found: {_EQUATIONS_DIR}")

    raw: dict[str, dict[str, Any]] = {}
    for filename in _EQUATION_FILES:
        filepath = _EQUATIONS_DIR / filename
        if not filepath.exists():
            logger.error("Equations file not found: %s", filepath)
            raise FileNotFoundError(f"Equations file not found: {filepath}")
        inferred_type = filepath.stem
        with filepath.open("r", encoding="utf-8") as f:
            # An empty file yields None; treat it as "no equations".
            chunk = yaml.safe_load(f) or {}
        if not isinstance(chunk, dict):
            logger.error("Equations file %s must contain a mapping at top level", filepath)
            raise ValueError(f"Equations file must contain a mapping at top level: {filepath}")
        for key, eq_data in chunk.items():
            if not isinstance(eq_data, dict):
                logger.warning("Equation '%s' in %s is not a mapping; skipping", key, filename)
                continue
            # Copy so the parsed YAML structure is not mutated.
            entry = dict(eq_data)
            if entry.get("equation_type") is None:
                entry["equation_type"] = inferred_type
            if key in raw:
                # Later files win, matching the previous override behavior,
                # but make the collision visible.
                logger.warning(
                    "Duplicate equation key '%s' in %s overrides an earlier definition",
                    key,
                    filename,
                )
            raw[key] = entry
    return raw


def _build_equation(key: str, data: dict[str, Any]) -> PredefinedEquation | None:
    """Convert one raw YAML entry into a PredefinedEquation, or return None to skip it."""
    formula: str = data.get("formula", "")
    if not formula:
        logger.warning("Equation '%s' has no formula (required for display); skipping", key)
        return None

    expression: str | None = data.get("expression")
    function_name: str | None = data.get("function_name")
    vector_expressions: list[str] | None = data.get("vector_expressions")
    eq_type = str(_value_or(data, "equation_type", "ode"))

    # Executability is checked independently of the declared type. A
    # ``vector_ode`` entry with no expression, function_name or
    # vector_expressions has nothing to run and is skipped like any other.
    if not (expression or function_name or vector_expressions):
        logger.warning(
            "Equation '%s' has neither expression, function_name, nor vector_expressions; "
            "skipping",
            key,
        )
        return None

    # An entry is vector-valued if it lists component expressions or declares
    # the vector_ode type (e.g. components provided via function_name).
    is_vector = bool(vector_expressions) or eq_type == "vector_ode"
    vector_components = int(_value_or(data, "vector_components", 1)) if is_vector else 1

    partial_derivatives = data.get("partial_derivatives")

    return PredefinedEquation(
        key=key,
        name=_value_or(data, "name", key),
        formula=formula,
        description=_value_or(data, "description", ""),
        order=int(_value_or(data, "order", 1)),
        parameters=dict(_value_or(data, "parameters", {})),
        expression=expression,
        function_name=function_name,
        vector_expressions=vector_expressions,
        vector_components=vector_components,
        default_initial_conditions=list(_value_or(data, "default_initial_conditions", [0.0])),
        default_domain=list(_value_or(data, "default_domain", [0.0, 10.0])),
        equation_type=eq_type,
        category=str(_value_or(data, "category", "Oscillators")),
        variables=list(_value_or(data, "variables", ["x"])),
        partial_derivatives=dict(partial_derivatives) if partial_derivatives else None,
    )


def load_predefined_equations() -> dict[str, PredefinedEquation]:
    """Load all predefined equations from the YAML files in the equations directory.

    Each file in ``_EQUATION_FILES`` is read from ``_EQUATIONS_DIR``. An entry
    is skipped, with a warning, when it has no ``formula``, when it has
    neither ``expression``, ``function_name`` nor ``vector_expressions``, or
    when it is not a mapping.

    Results are cached after the first successful load to avoid redundant
    disk I/O on repeated calls. Later calls return the same dict object.

    Returns:
        Insertion-ordered dict mapping equation key to :class:`PredefinedEquation`.

    Raises:
        FileNotFoundError: If the equations directory or any expected YAML file is missing.
        ValueError: If a YAML file's top level is not a mapping.
    """
    global _cache
    if _cache is not None:
        return _cache

    equations: dict[str, PredefinedEquation] = {}
    for key, data in _read_raw_equations().items():
        eq = _build_equation(key, data)
        if eq is None:
            continue
        equations[key] = eq
        logger.debug("Loaded predefined equation: %s", key)

    logger.info("Loaded %d predefined equations", len(equations))
    _cache = equations
    return equations
[docs]
def is_multivariate(variables: list[str] | None) -> bool:
    """Tell whether an equation depends on several independent variables.

    Args:
        variables: Independent variable names such as ``["x"]`` or
            ``["x", "y"]``. ``None`` or an empty list counts as univariate.

    Returns:
        ``True`` when more than one variable is listed, otherwise ``False``.
    """
    # Substituting an empty tuple for None/[] lets a single length check cover every case.
    return len(variables or ()) > 1