"""Notation translation between user-facing f[...] and internal flat y[j] arrays.
The user writes expressions using ``f[i,k]`` (vector ODE), ``f[k]`` (scalar ODE /
difference), or ``f[a,b,...]`` (PDE derivatives). Internally, scipy's solve_ivp
operates on a flat state vector ``y``. This module bridges the two representations.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Literal
# Closed set of equation kinds; used as the type of ``FNotation.kind``.
EquationKind = Literal["ode", "vector_ode", "difference", "pde"]
@dataclass(frozen=True)
class FNotation:
    """Describes the notation context for a particular equation.

    Attributes:
        kind: One of ``"ode"``, ``"vector_ode"``, ``"difference"``, ``"pde"``.
        n_components: Number of components (1 for scalar ODE / PDE / difference).
        order: ODE order per component, or max derivative order for PDE.
        n_independent_vars: Number of independent variables (1 for ODE, >=1 for PDE).
        component_orders: Optional per-component orders. When non-empty, the
            state vector is sized from these values instead of
            ``n_components * order``. The empty tuple (default) means every
            component uses ``order``.
    """

    kind: EquationKind
    n_components: int = 1
    order: int = 1
    n_independent_vars: int = 1
    component_orders: tuple[int, ...] = ()

    def state_size(self) -> int:
        """Total number of entries in the flat state vector.

        Returns:
            ``sum(component_orders)`` when per-component orders are given,
            otherwise ``n_components * order``.
        """
        # Per-component orders, when present, take precedence over the
        # uniform n_components * order layout.
        if self.component_orders:
            return sum(self.component_orders)
        return self.n_components * self.order
# ---------------------------------------------------------------------------
# f' -> f[k] preprocessing (prime notation)
# ---------------------------------------------------------------------------
# Matches: f followed by one or more prime chars, optionally followed by [...]
_F_PRIME_TOKEN = re.compile(
    r"""
    \bf                        # word-boundary then literal 'f'
    (['\u2032\u2033\u2034]+)   # one or more primes: ASCII ' or Unicode ′ ″ ‴ (group 1)
    (?:\[([^\]]*)\])?          # optional bracketed content (group 2)
    """,
    re.VERBOSE,
)

# Derivative order contributed by each accepted prime character.  The double
# and triple prime glyphs are accepted because the plot/CSV labels produced by
# this module use them, so a label pasted back into an expression still parses.
_PRIME_WEIGHTS = {"'": 1, "\u2032": 1, "\u2033": 2, "\u2034": 3}


def _preprocess_prime_notation(expression: str) -> str:
    """Convert ``f'`` notation to ``f[k]`` notation before main rewriting.

    Scalar ODE: ``f'`` → ``f[1]``, ``f''`` → ``f[2]``, ``f'''`` → ``f[3]``
    Vector ODE: ``f'[i]`` → ``f[i,1]``, ``f''[i]`` → ``f[i,2]``

    The Unicode glyphs ``′``, ``″`` and ``‴`` count as one, two and three
    primes respectively and may be mixed with ASCII ``'``.  Empty brackets
    (``f'[]``) are treated like the scalar form.  Bare ``f`` (no primes) is
    left unchanged.

    Args:
        expression: Python expression string with prime notation.

    Returns:
        Expression with primes rewritten to bracket notation.
    """

    def _replace(m: re.Match) -> str:
        deriv_order = sum(_PRIME_WEIGHTS[ch] for ch in m.group(1))
        bracket = m.group(2)  # content inside optional [...]
        component = "" if bracket is None else bracket.strip()
        if component:
            # Vector case: f'[i] -> f[i, deriv_order]
            return f"f[{component},{deriv_order}]"
        # Scalar case (also f'[]): emitting "f[,k]" would be a syntax error
        return f"f[{deriv_order}]"

    return _F_PRIME_TOKEN.sub(_replace, expression)
# ---------------------------------------------------------------------------
# f[...] -> y[...] rewriting
# ---------------------------------------------------------------------------
# Matches a bare ``f`` or an indexed ``f[...]`` token.
# The bracket content is captured as group 1 and runs up to the FIRST ``]``,
# so an index that itself contains ``]`` ends the capture early.
# The trailing negative lookahead rejects an ``f`` that is immediately followed
# by ``(`` or a word character.
_F_TOKEN = re.compile(
r"""
\bf # word-boundary then literal 'f'
(?:\[([^\]]*)\])? # optional bracketed content (group 1)
(?![(\w]) # NOT followed by '(' or word char (avoid matching func names)
""",
re.VERBOSE,
)
def _rewrite_match_ode_scalar(m: re.Match, order: int) -> str:
    """Translate one matched f-token into its ``y[...]`` form (scalar case).

    A bare ``f`` and an empty or whitespace-only bracket both map to ``y[0]``;
    any other bracket content is stripped and copied verbatim into ``y[...]``.

    Args:
        m: Regex match whose group 1 holds the bracket content (or ``None``).
        order: ODE order. Not consulted by this function.

    Returns:
        Equivalent y-index expression.
    """
    index = (m.group(1) or "").strip()
    return f"y[{index}]" if index else "y[0]"
def _parse_int_literal(text: str) -> int | None:
    """Return ``int(text)`` when *text* is an integer literal, else ``None``."""
    try:
        return int(text)
    except ValueError:
        return None


def _rewrite_match_vector_ode(m: re.Match, notation: FNotation) -> str:
    """Rewrite a single f-token for vector ODE.

    Mapping:
        * ``f`` / ``f[]``  -> ``y[0]``
        * ``f[k]``         -> ``y[k]`` (implicit component 0, offset 0)
        * ``f[i,k]``       -> flat index of derivative ``k`` of component ``i``
        * more than two indices -> token left unchanged

    With uniform orders the flat index is ``i*order + k``.  When
    ``notation.component_orders`` is set, the offset of component ``i`` is the
    sum of the orders of the components before it.  Symbolic ``i`` with
    heterogeneous orders is emitted as a lookup into a literal tuple of those
    offsets, so the index is still correct when the expression is evaluated.

    Args:
        m: Regex match for f-token (group 1 is the bracket content or ``None``).
        notation: Vector ODE notation context; only ``order`` and
            ``component_orders`` are read.

    Returns:
        Equivalent y-index expression.
    """
    bracket = m.group(1)
    content = "" if bracket is None else bracket.strip()
    if not content:
        return "y[0]"

    parts = [p.strip() for p in content.split(",")]
    if len(parts) == 1:
        # f[k]: implicit component 0, which always starts at offset 0.
        return f"y[{parts[0]}]"
    if len(parts) != 2:
        # More than 2 indices is invalid for vector ODE: leave unchanged.
        return m.group(0)

    i_expr, k_expr = parts
    i_val = _parse_int_literal(i_expr)
    k_val = _parse_int_literal(k_expr)
    orders = tuple(notation.component_orders)

    if i_val is not None and k_val is not None:
        # Both indices literal: fold everything into a single integer.
        base = sum(orders[:i_val]) if orders else i_val * notation.order
        return f"y[{base + k_val}]"

    if not orders:
        # Uniform layout, at least one symbolic index.
        return f"y[({i_expr})*{notation.order}+({k_expr})]"

    if i_val is not None:
        # Literal component, symbolic derivative: the offset is known exactly.
        return f"y[{sum(orders[:i_val])}+({k_expr})]"

    if len(set(orders)) == 1:
        # All components share one order; use it rather than notation.order,
        # which may have been left at its default when component_orders is set.
        return f"y[({i_expr})*{orders[0]}+({k_expr})]"

    # Heterogeneous orders with a symbolic component index: look the offset up
    # at evaluation time in a literal tuple of per-component start offsets.
    offsets = tuple(sum(orders[:c]) for c in range(len(orders)))
    return f"y[{offsets}[{i_expr}]+({k_expr})]"
def _rewrite_f_expression(expression: str, notation: FNotation) -> str:
    """Rewrite user-facing ``f[...]`` tokens to internal ``y[...]`` form.

    Prime notation (``f'``, ``f''``, ``f'[i]``) is first converted to bracket
    notation; the resulting ``f`` / ``f[...]`` tokens are then replaced
    according to ``notation.kind``:

    * ``"vector_ode"`` uses the vector rewriter.
    * ``"ode"``, ``"difference"`` and ``"pde"`` use the scalar rewriter.
    * Any other kind returns the prime-preprocessed expression as is.

    NOTE(review): because the PDE kind reuses the scalar rewriter, bracketed
    ``f[a,b,...]`` tokens are also turned into ``y[a,b,...]`` here, not only
    bare ``f``. Confirm this is what the PDE parser expects.

    Args:
        expression: Python expression using ``f``, ``f[k]``, ``f[i,k]``,
            ``f'``, ``f''``, ``f'[i]``, etc.
        notation: Context describing the equation type and dimensions.

    Returns:
        Equivalent expression with ``y[...]`` indexing suitable for the solver.
    """
    bracketed = _preprocess_prime_notation(expression)
    kind = notation.kind

    if kind == "vector_ode":

        def substitute(match: re.Match) -> str:
            return _rewrite_match_vector_ode(match, notation)

    elif kind in ("ode", "difference", "pde"):

        def substitute(match: re.Match) -> str:
            return _rewrite_match_ode_scalar(match, notation.order)

    else:
        return bracketed

    return _F_TOKEN.sub(substitute, bracketed)
# ---------------------------------------------------------------------------
# Flat index -> label (for plots, CSV headers)
# ---------------------------------------------------------------------------
_PRIME_SYMBOLS = ["", "\u2032", "\u2033", "\u2034"]  # '', ′, ″, ‴
_SUBSCRIPT_DIGITS = "\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089"


def _prime_str(k: int) -> str:
    """Return a prime string for derivative order k.

    Orders 0-3 map to ``""``, ``′``, ``″``, ``‴``.  Any other order, including
    a negative one, is rendered as ``"(k)"``.

    Args:
        k: Derivative order.

    Returns:
        Prime glyph(s) or a parenthesised order.
    """
    # The lower bound matters: without it a negative k would wrap around via
    # Python's negative indexing and return an unrelated glyph.
    if 0 <= k < len(_PRIME_SYMBOLS):
        return _PRIME_SYMBOLS[k]
    return f"({k})"


def _subscript(n: int) -> str:
    """Return subscript digits for integer n (e.g. 12 → '₁₂').

    Args:
        n: Non-negative integer.

    Returns:
        Unicode subscript string.  Characters of ``str(n)`` that are not
        digits (such as a leading ``-``) are passed through unchanged.
    """
    if 0 <= n < len(_SUBSCRIPT_DIGITS):
        return _SUBSCRIPT_DIGITS[n]
    return "".join(
        _SUBSCRIPT_DIGITS[int(d)] if d.isdigit() else d for d in str(n)
    )
def _flat_index_to_label(j: int, notation: FNotation) -> str:
    """Convert flat state-vector index *j* to a human-readable label.

    * ``"ode"`` / ``"difference"``: ``"f"`` when the order is 1, otherwise
      ``"f"`` followed by the prime string for *j*.
    * ``"pde"``: ``"f"`` for index 0, otherwise the raw ``"y[j]"`` form.
    * ``"vector_ode"``: ``"f"`` + prime string for the derivative + subscript
      of the component; falls back to ``"y[j]"`` when *j* lies beyond the
      entries described by ``component_orders``.
    * Any other kind: ``"y[j]"``.

    Args:
        j: Index into the flat ``y`` array.
        notation: Equation context.

    Returns:
        A string like ``"f"``, ``"f\u2032"``, ``"f\u2032\u2081"``, etc.
    """
    kind = notation.kind

    if kind in ("ode", "difference"):
        return "f" if notation.order == 1 else f"f{_prime_str(j)}"

    if kind == "pde":
        return "f" if j == 0 else f"y[{j}]"

    if kind != "vector_ode":
        return f"y[{j}]"

    orders = notation.component_orders
    if not orders:
        # Uniform layout: component-major blocks of length ``order``.
        component, derivative = divmod(j, notation.order)
        return f"f{_prime_str(derivative)}{_subscript(component)}"

    # Heterogeneous layout: walk the component blocks until one contains j.
    block_start = 0
    for component, width in enumerate(orders):
        block_end = block_start + width
        if j < block_end:
            return f"f{_prime_str(j - block_start)}{_subscript(component)}"
        block_start = block_end
    return f"y[{j}]"
def generate_derivative_labels(notation: FNotation) -> list[str]:
    """Generate labels for every entry in the flat state vector.

    Args:
        notation: Equation context.

    Returns:
        List of human-readable labels, length == ``notation.state_size()``,
        in flat-index order.
    """
    return [_flat_index_to_label(j, notation) for j in range(notation.state_size())]
def generate_phase_space_options(notation: FNotation) -> list[tuple[str, int | None]]:
    """Generate options for phase-space axis selectors.

    Each option is ``(label, flat_index)`` where flat_index is ``None`` for the
    independent variable, which is labelled ``"n"`` for difference equations
    and ``"x"`` otherwise.

    Args:
        notation: Equation context.

    Returns:
        List of ``(label, flat_index_or_None)`` pairs: the independent
        variable first, then one entry per flat state index in order.
    """
    x_label = "n" if notation.kind == "difference" else "x"
    options: list[tuple[str, int | None]] = [(x_label, None)]
    options.extend(
        (_flat_index_to_label(j, notation), j) for j in range(notation.state_size())
    )
    return options