# Source code for pipeline

"""Solver pipeline — orchestrates validation, solving, and statistics."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np

from config import get_env_from_schema
from solver import (
    FNotation,
    ODESolution,
    compute_ode_residual_error,
    compute_statistics,
    compute_statistics_2d,
    get_difference_function,
    get_ode_function,
    get_vector_ode_function,
    is_multivariate,
    parse_pde_rhs_expression,
    solve_difference,
    solve_multipoint,
    solve_ode,
    solve_pde_2d,
    validate_all_inputs,
)
from solver.pde_solver import BC_DIRICHLET, BC_NEUMANN
from solver.predefined import EquationType
from utils import ValidationError, build_eval_namespace, get_logger, safe_eval

logger = get_logger(__name__)


@dataclass
class _DispatchResult:
    """Intermediate result from a solver dispatch function.

    Attributes:
        x: Independent variable or grid values.
        y: Solution array.
        success: Whether the solver converged.
        message: Solver status message.
        n_eval: Number of function evaluations.
        error_metrics: Residual error metrics (ODE only).
        solver_quality: Solver parameters and metadata.
        y_grid: For 2D PDE, the y-axis grid. None otherwise.
        ode_func: ODE function used (for residual computation). None for PDE/difference.
    """

    x: np.ndarray
    y: np.ndarray
    success: bool
    message: str
    n_eval: int
    error_metrics: dict[str, float] = field(default_factory=dict)
    solver_quality: dict[str, Any] = field(default_factory=dict)
    y_grid: np.ndarray | None = None
    ode_func: Callable | None = None


def _build_solver_quality(solution: ODESolution) -> dict[str, Any]:
    """Collect solver settings and diagnostics for a finished ODE solve.

    Tolerances are read from the environment schema (``SOLVER_RTOL`` and
    ``SOLVER_ATOL``). A Jacobian evaluation count is added only when the
    solution carries a raw backend result exposing a non-``None`` ``njev``.

    Args:
        solution: Result of :func:`solve_ode` or :func:`solve_multipoint`.

    Returns:
        Mapping with ``"rtol"`` and ``"atol"``, plus ``"n_jacobian_evals"``
        when the backend reported one.
    """
    rtol = get_env_from_schema("SOLVER_RTOL")
    atol = get_env_from_schema("SOLVER_ATOL")
    quality: dict[str, Any] = {"rtol": rtol, "atol": atol}

    raw_result = solution.raw
    # getattr with a default: not every backend result carries ``njev``.
    jac_evals = None if raw_result is None else getattr(raw_result, "njev", None)
    if jac_evals is not None:
        quality["n_jacobian_evals"] = int(jac_evals)
    return quality


def _build_bc_array(
    bc_expressions: list[str],
    variables: list[str],
    parameters: dict[str, float],
    x_grid: np.ndarray,
    y_grid: np.ndarray,
    nx: int,
    ny: int,
) -> np.ndarray:
    """Build a (ny, nx) boundary values array from function expressions.

    Order of ``bc_expressions``: [bottom(y=y_min), top(y=y_max),
    left(x=x_min), right(x=x_max)]. Horizontal boundaries (bottom/top)
    are written first. Vertical boundaries (left/right) overwrite corner
    values where they overlap.

    Args:
        bc_expressions: List of 4 expressions for boundary values.
        variables: Independent variable names (e.g. ``["x", "y"]``).
        parameters: Named parameter values for expression evaluation.
        x_grid: 1D array of x values.
        y_grid: 1D array of y values.
        nx: Number of x grid points.
        ny: Number of y grid points.

    Returns:
        Boundary values array of shape (ny, nx).
    """
    bc = np.zeros((ny, nx))

    # Bottom (row 0): f(x) along x at y = y_min
    if bc_expressions[0].strip() not in ("0", ""):
        func = parse_pde_rhs_expression(bc_expressions[0], [variables[0]], parameters)
        for i in range(nx):
            bc[0, i] = func(x_grid[i])

    # Top (row ny-1): f(x) along x at y = y_max
    if len(bc_expressions) > 1 and bc_expressions[1].strip() not in ("0", ""):
        func = parse_pde_rhs_expression(bc_expressions[1], [variables[0]], parameters)
        for i in range(nx):
            bc[ny - 1, i] = func(x_grid[i])

    # Left (col 0): f(y) along y at x = x_min
    if len(bc_expressions) > 2 and bc_expressions[2].strip() not in ("0", ""):
        func = parse_pde_rhs_expression(bc_expressions[2], [variables[1]], parameters)
        for j in range(ny):
            bc[j, 0] = func(y_grid[j])

    # Right (col nx-1): f(y) along y at x = x_max
    if len(bc_expressions) > 3 and bc_expressions[3].strip() not in ("0", ""):
        func = parse_pde_rhs_expression(bc_expressions[3], [variables[1]], parameters)
        for j in range(ny):
            bc[j, nx - 1] = func(y_grid[j])

    return bc


def _build_mask(
    mask_expression: str | None,
    x_grid: np.ndarray,
    y_grid: np.ndarray,
    parameters: dict[str, float],
) -> np.ndarray | None:
    """Evaluate a mask expression on the grid, or return None for rectangular.

    Args:
        mask_expression: Python expression using x, y (or X, Y for meshgrid).
        x_grid: 1D array of x values.
        y_grid: 1D array of y values.
        parameters: Named parameters for expression evaluation.

    Returns:
        Boolean array (ny, nx) where True = inside domain, or None if
        no mask (rectangular domain).
    """
    if not mask_expression or not mask_expression.strip():
        return None
    X, Y = np.meshgrid(x_grid, y_grid)
    ns = build_eval_namespace(parameters)
    ns.update({"x": X, "y": Y, "X": X, "Y": Y})
    compiled = compile(mask_expression.strip(), "<mask>", "eval")
    result = safe_eval(compiled, ns)
    return np.asarray(result, dtype=bool)


def _build_bc_type_array(
    bc_types: list[str] | None,
    nx: int,
    ny: int,
    mask: np.ndarray | None,
    contour_bc_type: str | None,
) -> np.ndarray | None:
    """Build a (ny, nx) BC type array from per-edge types or contour type.

    Order of bc_types: [bottom, top, left, right] — each "dirichlet" or
    "neumann". For custom contour domains, contour_bc_type is used for
    all boundary points.

    Args:
        bc_types: Per-edge BC types [bottom, top, left, right].
        nx: Number of x grid points.
        ny: Number of y grid points.
        mask: Boolean mask (ny, nx). None for rectangular.
        contour_bc_type: BC type for custom contour (all boundary).

    Returns:
        String array (ny, nx) with "dirichlet" or "neumann", or None.
    """
    if bc_types is None and contour_bc_type is None:
        return None

    arr = np.full((ny, nx), BC_DIRICHLET, dtype=object)

    if mask is not None and contour_bc_type:
        # Custom contour: uniform BC type on all boundary
        arr[:] = contour_bc_type
    elif bc_types:
        # Rectangular: per-edge types
        types = bc_types + [BC_DIRICHLET] * (4 - len(bc_types))
        arr[0, :] = types[0]  # bottom
        arr[ny - 1, :] = types[1]  # top
        arr[:, 0] = types[2]  # left
        arr[:, nx - 1] = types[3]  # right

    return arr


def _build_neumann_array(
    bc_types: list[str] | None,
    bc_expressions: list[str] | None,
    variables: list[str],
    parameters: dict[str, float],
    x_grid: np.ndarray,
    y_grid: np.ndarray,
    nx: int,
    ny: int,
    mask: np.ndarray | None,
    contour_bc_type: str | None,
    contour_bc_expression: str | None,
) -> np.ndarray | None:
    """Assemble prescribed derivative values for Neumann boundaries.

    Returns None unless some edge in ``bc_types`` or the contour is of kind
    ``BC_NEUMANN``. On a masked domain with a Neumann contour and a contour
    expression, that (x, y) expression is sampled at every grid node.
    Otherwise each Neumann edge with a matching expression is filled
    (bottom/top as f(x), left/right as f(y)); all other entries stay zero.

    Args:
        bc_types: Per-edge BC kinds [bottom, top, left, right].
        bc_expressions: Per-edge value expressions in the same order.
        variables: Independent variable names (``[x_name, y_name]``).
        parameters: Parameter values for expression evaluation.
        x_grid: 1D x grid.
        y_grid: 1D y grid.
        nx: Number of x points.
        ny: Number of y points.
        mask: Domain mask (ny, nx) or None.
        contour_bc_type: BC kind for a custom contour.
        contour_bc_expression: (x, y) expression for contour Neumann values.

    Returns:
        Array (ny, nx) of Neumann values, or None when no Neumann BC exists.
    """
    edge_neumann = bool(bc_types) and any(t == BC_NEUMANN for t in bc_types)
    if not (edge_neumann or contour_bc_type == BC_NEUMANN):
        return None

    values = np.zeros((ny, nx))

    if mask is not None and contour_bc_type == BC_NEUMANN and contour_bc_expression:
        func = parse_pde_rhs_expression(contour_bc_expression, variables, parameters)
        # np.ndindex walks (j, i) in row-major order, like the nested loops.
        for j, i in np.ndindex(ny, nx):
            values[j, i] = func(x_grid[i], y_grid[j])
        return values

    if not (bc_types and bc_expressions):
        return values

    padded = bc_types + [BC_DIRICHLET] * (4 - len(bc_types))
    # (edge index, variable index, fixed row/column, edge runs along x?)
    edges = (
        (0, 0, 0, True),  # bottom row
        (1, 0, ny - 1, True),  # top row
        (2, 1, 0, False),  # left column
        (3, 1, nx - 1, False),  # right column
    )
    for edge, var_idx, fixed, along_x in edges:
        if padded[edge] != BC_NEUMANN or len(bc_expressions) <= edge:
            continue
        func = parse_pde_rhs_expression(
            bc_expressions[edge], [variables[var_idx]], parameters
        )
        if along_x:
            for i in range(nx):
                values[fixed, i] = func(x_grid[i])
        else:
            for j in range(ny):
                values[j, fixed] = func(y_grid[j])
    return values


def _dispatch_2d_pde(
    *,
    expression: str | None,
    vars_list: list[str],
    parameters: dict[str, float],
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
    n_points: int,
    n_points_y: int | None,
    pde_operator: str,
    bc_expressions: list[str] | None,
    bc_types: list[str] | None = None,
    mask_expression: str | None = None,
    contour_bc_expression: str | None = None,
    contour_bc_type: str | None = None,
) -> _DispatchResult:
    """Dispatch a 2D PDE solve.

    Returns:
        :class:`_DispatchResult` with grid, solution, and metadata.
    """
    ny = n_points_y if n_points_y is not None else n_points
    rhs_func = parse_pde_rhs_expression(expression or "0", vars_list, parameters)

    def residual(
        x: float,
        y: float,
        f: float,
        fx: float,
        fy: float,
        fxx: float,
        fxy: float,
        fyy: float,
        **kw: Any,
    ) -> float:
        rhs = rhs_func(
            x,
            y,
            f=f,
            fx=fx,
            fy=fy,
            fxx=fxx,
            fxy=fxy,
            fyy=fyy,
            **kw,
        )
        if pde_operator == "neg_laplacian":
            return -fxx - fyy - rhs
        if pde_operator == "laplacian":
            return fxx + fyy - rhs
        if pde_operator == "fxx":
            return fxx - rhs
        if pde_operator == "fyy":
            return fyy - rhs
        if pde_operator == "fx":
            return fx - rhs
        if pde_operator == "fy":
            return fy - rhs
        if pde_operator == "fxy":
            return fxy - rhs
        return -fxx - fyy - rhs

    x_grid = np.linspace(x_min, x_max, n_points)
    y_grid_bc = np.linspace(y_min, y_max, ny)

    # Build mask
    mask = _build_mask(mask_expression, x_grid, y_grid_bc, parameters)

    # Build Dirichlet BC values
    bc_values: np.ndarray | None = None
    if mask is not None and contour_bc_type != BC_NEUMANN and contour_bc_expression:
        # Custom contour with Dirichlet BC: evaluate expression on grid
        bc_values = np.zeros((ny, n_points))
        func = parse_pde_rhs_expression(contour_bc_expression, vars_list, parameters)
        for j in range(ny):
            for i in range(n_points):
                bc_values[j, i] = func(x_grid[i], y_grid_bc[j])
    elif bc_expressions and any(e.strip() not in ("0", "") for e in bc_expressions):
        bc_values = _build_bc_array(
            bc_expressions,
            vars_list,
            parameters,
            x_grid,
            y_grid_bc,
            n_points,
            ny,
        )

    # Build BC type and Neumann arrays
    bc_type_arr = _build_bc_type_array(bc_types, n_points, ny, mask, contour_bc_type)
    neumann_arr = _build_neumann_array(
        bc_types,
        bc_expressions,
        vars_list,
        parameters,
        x_grid,
        y_grid_bc,
        n_points,
        ny,
        mask,
        contour_bc_type,
        contour_bc_expression,
    )

    pde_sol = solve_pde_2d(
        residual,
        x_min,
        x_max,
        y_min,
        y_max,
        n_points,
        ny,
        bc_values=bc_values,
        parameters=parameters,
        mask=mask,
        bc_type=bc_type_arr,
        bc_neumann_value=neumann_arr,
    )
    return _DispatchResult(
        x=pde_sol.grid[0],
        y=pde_sol.u,
        success=pde_sol.success,
        message=pde_sol.message,
        n_eval=pde_sol.n_eval,
        y_grid=pde_sol.grid[1],
    )


def _dispatch_difference(
    *,
    expression: str | None,
    function_name: str | None,
    order: int,
    parameters: dict[str, float],
    x_min: float,
    x_max: float,
    y0: list[float],
) -> _DispatchResult:
    """Run a difference-equation (recurrence) solve.

    The domain bounds are truncated with ``int()`` to obtain the integer
    index range. A failed solve is logged and raised as
    ``SolverFailedError`` instead of being returned.

    Returns:
        :class:`_DispatchResult` with the index values as ``x``, the
        sequence as ``y`` and ``n_eval`` fixed at 0.

    Raises:
        SolverFailedError: If :func:`solve_difference` reports failure.
    """
    recurrence = get_difference_function(
        expression=expression,
        function_name=function_name,
        order=order,
        parameters=parameters,
    )
    n_start, n_end = int(x_min), int(x_max)
    outcome = solve_difference(recurrence, n_start, n_end, y0, order)

    if outcome.success:
        return _DispatchResult(
            x=outcome.n,
            y=outcome.y,
            success=outcome.success,
            message=outcome.message,
            n_eval=0,
        )

    from utils import SolverFailedError

    logger.error("Difference equation solver failed: %s", outcome.message)
    raise SolverFailedError(outcome.message)


def _dispatch_vector_ode(
    *,
    vector_expressions: list[str] | None,
    function_name: str | None,
    order: int,
    vector_components: int,
    parameters: dict[str, float],
    x_min: float,
    x_max: float,
    y0: list[float],
    n_points: int,
    method: str,
) -> _DispatchResult:
    """Integrate a system (vector) ODE as an initial value problem.

    Per-component expressions take precedence: when a non-empty
    ``vector_expressions`` list is supplied, ``function_name`` is not passed
    on. The solution is sampled on ``n_points`` evenly spaced points.

    Returns:
        :class:`_DispatchResult` with solution, residual error metrics,
        solver metadata and the ODE function used.
    """
    expressions = vector_expressions or None
    named_function = None if expressions else function_name

    rhs = get_vector_ode_function(
        vector_expressions=expressions,
        function_name=named_function,
        order=order,
        vector_components=vector_components,
        parameters=parameters,
    )
    sample_points = np.linspace(x_min, x_max, n_points)
    solution = solve_ode(rhs, (x_min, x_max), y0, method=method, t_eval=sample_points)

    metrics = compute_ode_residual_error(rhs, solution.x, solution.y)
    quality = _build_solver_quality(solution)
    return _DispatchResult(
        x=solution.x,
        y=solution.y,
        success=solution.success,
        message=solution.message,
        n_eval=solution.n_eval,
        error_metrics=metrics,
        solver_quality=quality,
        ode_func=rhs,
    )


def _dispatch_scalar_ode(
    *,
    expression: str | None,
    function_name: str | None,
    order: int,
    parameters: dict[str, float],
    x_min: float,
    x_max: float,
    y0: list[float],
    n_points: int,
    method: str,
    x0_list: list[float] | None,
) -> _DispatchResult:
    """Solve a single (scalar) ODE, as an IVP or as a multipoint problem.

    A multipoint solve is used when ``x0_list`` is given and at least one of
    its points lies more than ``1e-12`` away from ``x_min``; otherwise every
    condition sits at ``x_min`` and a plain IVP solve is performed.

    Returns:
        :class:`_DispatchResult` with solution, residual error metrics,
        solver metadata and the ODE function used.
    """
    rhs = get_ode_function(
        expression=expression,
        function_name=function_name,
        order=order,
        parameters=parameters,
    )
    sample_points = np.linspace(x_min, x_max, n_points)
    needs_multipoint = x0_list is not None and any(
        abs(xi - x_min) > 1e-12 for xi in x0_list
    )

    if not needs_multipoint:
        solution = solve_ode(rhs, (x_min, x_max), y0, method=method, t_eval=sample_points)
    else:
        # One (derivative index, location, value) triple per condition.
        conditions = [
            (deriv, point, value)
            for deriv, (point, value) in enumerate(zip(x0_list, y0))  # type: ignore[arg-type]
        ]
        solution = solve_multipoint(
            rhs,
            conditions=conditions,
            order=order,
            x_min=x_min,
            x_max=x_max,
            method=method,
            t_eval=sample_points,
        )

    metrics = compute_ode_residual_error(rhs, solution.x, solution.y)
    quality = _build_solver_quality(solution)
    return _DispatchResult(
        x=solution.x,
        y=solution.y,
        success=solution.success,
        message=solution.message,
        n_eval=solution.n_eval,
        error_metrics=metrics,
        solver_quality=quality,
        ode_func=rhs,
    )


@dataclass
class SolverResult:
    """Data-only bundle produced by a solver run (no pre-generated plots).

    Attributes:
        x: Independent variable values (1D) or x grid for 2D PDE.
        y: Solution array — shape ``(n_vars, n_points)`` or ``(ny, nx)`` for 2D.
        statistics: Computed statistics dict.
        metadata: Equation info, solver parameters, domain, etc.
        equation_type: ``"ode"``, ``"difference"``, ``"pde"``, or ``"vector_ode"``.
        y_grid: For 2D PDE, the y-axis grid. ``None`` otherwise.
        is_vector: Whether the equation is a vector ODE.
        vector_components: Number of components for vector ODE.
        vector_order: Display order (derivatives per component).
        notation: F-notation context for labels.
    """

    x: np.ndarray
    y: np.ndarray
    statistics: dict[str, Any]
    metadata: dict[str, Any]
    equation_type: str = "ode"
    y_grid: np.ndarray | None = None  # For 2D PDE: y-axis grid
    is_vector: bool = False
    vector_components: int = 1
    vector_order: int = 1
    notation: FNotation | None = None
[docs] def run_solver_pipeline( *, expression: str | None = None, function_name: str | None = None, order: int, parameters: dict[str, float], equation_name: str, x_min: float, x_max: float, y0: list[float], n_points: int, method: str, selected_stats: set[str], x0_list: list[float] | None = None, equation_type: EquationType = "ode", variables: list[str] | None = None, y_min: float | None = None, y_max: float | None = None, n_points_y: int | None = None, vector_expressions: list[str] | None = None, vector_components: int = 1, pde_operator: str = "neg_laplacian", component_orders: tuple[int, ...] | None = None, bc_expressions: list[str] | None = None, bc_types: list[str] | None = None, mask_expression: str | None = None, contour_bc_expression: str | None = None, contour_bc_type: str | None = None, ) -> SolverResult: """Execute the full solve workflow and return data results. Stages: validate → resolve function → solve → statistics. Plot generation is deferred to the ResultDialog for interactive control. Args: expression: ODE expression string (optional). function_name: Name of function in config.equations (optional). order: ODE order. parameters: Named parameter values. equation_name: Display name for plots/metadata. x_min: Domain start. x_max: Domain end. y0: Initial condition values ``[f(x₀), f'(x₁), …]``. n_points: Number of evaluation points. method: Solver method name. selected_stats: Set of statistic keys to compute. x0_list: Per-derivative condition points ``[x₀, x₁, …]``. If ``None`` or all equal to ``x_min``, uses standard IVP. equation_type: ``"ode"``, ``"difference"``, ``"pde"``, or ``"vector_ode"``. variables: Independent variable names (e.g. ``["x"]`` or ``["x", "y"]``). y_min: For 2D PDE, domain y start. y_max: For 2D PDE, domain y end. n_points_y: For 2D PDE, number of y grid points. vector_expressions: For vector ODE, list of expressions per component. vector_components: Number of components for vector ODE. pde_operator: PDE operator type (e.g. 
``"neg_laplacian"``). component_orders: For vector ODE, order per component (optional). Returns: A :class:`SolverResult` with solution data, statistics, and metadata. Raises: ValidationError: If inputs fail validation. EquationParseError: If the expression cannot be parsed or function not found. DifferentialLabError: If the solver fails. """ vars_list = variables if variables else ["x"] is_pde = equation_type == "pde" or is_multivariate(vars_list) is_2d_pde = is_pde and len(vars_list) >= 2 is_vector = ( vector_expressions is not None and len(vector_expressions) > 0 ) or equation_type == "vector_ode" # ── Validate ────────────────────────────────────────────────────── if not is_pde: errors = validate_all_inputs( expression=expression if not is_vector else None, function_name=function_name, order=order, x_min=x_min, x_max=x_max, y0=y0, num_points=n_points, method=method, params=parameters, x0_list=x0_list, equation_type=equation_type, vector_expressions=vector_expressions if is_vector else None, vector_components=vector_components if is_vector else 1, ) if errors: msg = "\n".join(errors) logger.warning("Validation failed: %s", msg) raise ValidationError(msg) # ── Dispatch to equation-type-specific solver ───────────────────── if is_2d_pde: if y_min is None or y_max is None: logger.warning("PDE validation failed: y_min and y_max required") raise ValidationError("PDE requires y_min and y_max for 2D domain") dr = _dispatch_2d_pde( expression=expression, vars_list=vars_list, parameters=parameters, x_min=x_min, x_max=x_max, y_min=float(y_min), y_max=float(y_max), n_points=n_points, n_points_y=n_points_y, pde_operator=pde_operator, bc_expressions=bc_expressions, bc_types=bc_types, mask_expression=mask_expression, contour_bc_expression=contour_bc_expression, contour_bc_type=contour_bc_type, ) elif equation_type == "difference": dr = _dispatch_difference( expression=expression, function_name=function_name, order=order, parameters=parameters, x_min=x_min, x_max=x_max, y0=y0, 
) elif is_vector: dr = _dispatch_vector_ode( vector_expressions=vector_expressions, function_name=function_name, order=order, vector_components=vector_components, parameters=parameters, x_min=x_min, x_max=x_max, y0=y0, n_points=n_points, method=method, ) else: dr = _dispatch_scalar_ode( expression=expression, function_name=function_name, order=order, parameters=parameters, x_min=x_min, x_max=x_max, y0=y0, n_points=n_points, method=method, x0_list=x0_list, ) solution_x = dr.x solution_y = dr.y # ── Compute highest derivative and augment y ────────────────────── display_order = order if not is_2d_pde and equation_type != "difference" and dr.ode_func is not None: try: y_2d = np.atleast_2d(solution_y) if y_2d.shape[1] != len(solution_x): y_2d = y_2d.T n_pts = len(solution_x) n_comp = vector_components if is_vector else 1 dydt_all = np.column_stack( [dr.ode_func(solution_x[j], y_2d[:, j]) for j in range(n_pts)] ) new_order = order + 1 augmented = np.empty((n_comp * new_order, n_pts)) for comp_i in range(n_comp): for k in range(order): augmented[comp_i * new_order + k] = y_2d[comp_i * order + k] augmented[comp_i * new_order + order] = dydt_all[comp_i * order + order - 1] solution_y = augmented display_order = new_order except Exception: logger.debug("Could not compute highest derivative; using raw y", exc_info=True) # ── Statistics ──────────────────────────────────────────────────── if is_2d_pde: stats = compute_statistics_2d(solution_x, dr.y_grid, solution_y, selected_stats) else: stats = compute_statistics(solution_x, solution_y, selected_stats) # ── Metadata ────────────────────────────────────────────────────── metadata: dict[str, Any] = { "equation_name": equation_name, "equation_type": equation_type, "expression": expression if expression else f"<function:{function_name}>", "order": order, "parameters": parameters, "domain": ( [x_min, x_max, y_min, y_max] if (is_pde and y_min is not None and y_max is not None) else [x_min, x_max] ), "initial_conditions": y0, 
"ic_points": x0_list if x0_list is not None else [x_min] * order, "method": method if equation_type == "ode" else "fdm", "num_points": n_points, "solver_success": dr.success, "solver_message": dr.message, "n_evaluations": dr.n_eval, "rtol": dr.solver_quality.get("rtol"), "atol": dr.solver_quality.get("atol"), "residual_max": dr.error_metrics.get("residual_max"), "residual_mean": dr.error_metrics.get("residual_mean"), "residual_rms": dr.error_metrics.get("residual_rms"), "n_jacobian_evals": dr.solver_quality.get("n_jacobian_evals"), "variables": vars_list, "boundary_conditions": bc_expressions, } # ── Notation ────────────────────────────────────────────────────── if equation_type == "difference": notation = FNotation(kind="difference", order=order) elif is_vector: display_comp_orders: tuple[int, ...] | None = None if component_orders: display_comp_orders = tuple(co + 1 for co in component_orders) notation = FNotation( kind="vector_ode", n_components=vector_components, order=display_order, component_orders=display_comp_orders or (), ) elif is_pde: notation = FNotation(kind="pde", n_independent_vars=len(vars_list), order=order) else: notation = FNotation(kind="ode", order=display_order) logger.info("Pipeline complete for '%s'", equation_name) return SolverResult( x=solution_x, y=solution_y, statistics=stats, metadata=metadata, equation_type=equation_type, y_grid=dr.y_grid if is_2d_pde else None, is_vector=is_vector, vector_components=vector_components if is_vector else 1, vector_order=display_order, notation=notation, )