# Source code for complex_problems.schrodinger_td.ui

"""UI dialog for configuring time-dependent Schrödinger simulations."""

from __future__ import annotations

import tkinter as tk
from tkinter import messagebox, ttk

from complex_problems.common import (
    add_how_to_config_section,
    compile_scalar_expression,
    parse_float,
    parse_positive_float,
    parse_positive_int,
    run_solver_with_loading,
)
from complex_problems.schrodinger_td.solver import solve_schrodinger_td
from config import get_env_from_schema
from frontend.theme import get_font
from frontend.ui_dialogs.scrollable_frame import ScrollableFrame
from frontend.ui_dialogs.tooltip import ToolTip
from frontend.window_utils import fit_and_center, make_modal

# Choice lists shown by the dialog's read-only comboboxes.

# Spatial dimensionality options.
_DIMENSIONS = (
    "1D",
    "2D",
)

# Boundary-condition options.
_BOUNDARIES = (
    "periodic",
    "absorbing",
)

# Potential options; "custom" reveals a free-form expression entry.
_POTENTIALS = (
    "free",
    "harmonic",
    "square_well",
    "barrier",
    "double_well",
    "lattice",
    "custom",
)

# Initial wave-packet options; "custom" reveals a free-form amplitude entry.
_PACKETS = (
    "gaussian",
    "superposition",
    "custom",
)


class SchrodingerTDDialog:
    """Configuration dialog for TDSE in 1D and 2D.

    The dialog opens a modal ``Toplevel`` with entries for the grid, time
    stepping, physical constants, potential and initial packet.  Pressing
    "Solve" validates the fields, closes the dialog and hands the collected
    parameters to ``solve_schrodinger_td`` through ``run_solver_with_loading``.

    Rows that only apply to some selections (y-domain, absorbing layer,
    custom expressions) are shown or hidden by ``_update_visibility``.
    """

    def __init__(self, parent: tk.Tk | tk.Toplevel) -> None:
        """Create the dialog window as a modal child of *parent*.

        Args:
            parent: Window that owns the dialog; also used later as the parent
                of the loading indicator and the result dialog.
        """
        self.parent = parent
        self.win = tk.Toplevel(parent)
        self.win.title("Schrodinger TD (1D/2D)")
        self.win.configure(bg=get_env_from_schema("UI_BACKGROUND"))

        self._build_ui()

        fit_and_center(self.win, min_width=1040, min_height=780, padding=32, resizable=True)
        self.win.minsize(960, 700)
        make_modal(self.win, parent)

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------

    def _build_ui(self) -> None:
        """Build every widget of the dialog, top to bottom.

        The section builders are called in display order because ``pack``
        lays widgets out in creation order.
        """
        pad = int(get_env_from_schema("UI_PADDING"))
        # Vertical padding shared by every form row, including rows that are
        # re-packed later by ``_update_visibility``.
        self._row_pady = pad // 2

        root = ttk.Frame(self.win, padding=pad * 2)
        root.pack(fill=tk.BOTH, expand=True)

        scroll = ScrollableFrame(root)
        scroll.apply_bg(get_env_from_schema("UI_BACKGROUND"))
        scroll.pack(fill=tk.BOTH, expand=True)

        body = scroll.inner
        body.configure(padding=pad)

        self._build_header(body, scroll, pad)
        self._build_grid_section(body)
        self._build_potential_section(body, pad)
        self._build_packet_section(body, pad)
        self._build_buttons(body, pad)

        # Hide the rows that do not apply to the default selections.
        self._update_visibility()
        scroll.bind_new_children()

    def _add_row(self, parent: tk.Misc) -> ttk.Frame:
        """Create a full-width row frame inside *parent* and pack it."""
        row = ttk.Frame(parent)
        row.pack(fill=tk.X, pady=self._row_pady)
        return row

    def _build_header(self, body: tk.Misc, scroll: ScrollableFrame, pad: int) -> None:
        """Add the title, the short description and the how-to section."""
        ttk.Label(body, text="Schrodinger Time Evolution", style="Title.TLabel").pack(anchor=tk.W)
        ttk.Label(
            body,
            text=(
                "Split-operator spectral solver for TDSE in 1D or 2D.\n"
                "Use periodic boundaries for conservation checks or "
                "absorbing edges for open-domain behavior."
            ),
            style="Small.TLabel",
            justify=tk.LEFT,
        ).pack(anchor=tk.W, pady=(0, pad))

        add_how_to_config_section(
            body,
            scroll,
            problem_id="schrodinger_td",
            pad=pad,
            wraplength=800,
        )

    def _build_grid_section(self, body: tk.Misc) -> None:
        """Add dimension/boundary selectors, domains, time and constants."""
        row = self._add_row(body)
        self._dimension_var = tk.StringVar(value="1D")
        self._boundary_var = tk.StringVar(value="periodic")
        dim_combo = self._make_combo(row, "Dimension", self._dimension_var, _DIMENSIONS, width=8)
        bnd_combo = self._make_combo(row, "Boundary", self._boundary_var, _BOUNDARIES, width=10)
        dim_combo.bind("<<ComboboxSelected>>", lambda _e: self._update_visibility())
        bnd_combo.bind("<<ComboboxSelected>>", lambda _e: self._update_visibility())

        # Always visible; the optional y-domain row is re-inserted after it.
        self._x_domain_row = self._add_row(body)
        self._x_min_var = tk.StringVar(value="-12.0")
        self._x_max_var = tk.StringVar(value="12.0")
        self._nx_var = tk.StringVar(value="256")
        self._make_entry(self._x_domain_row, "xₘᵢₙ", self._x_min_var, width=10)
        self._make_entry(self._x_domain_row, "xₘₐₓ", self._x_max_var, width=10)
        self._make_spinbox(self._x_domain_row, "Nₓ", self._nx_var, from_=32, to=16384, width=8)

        self._y_domain_row = self._add_row(body)
        self._y_min_var = tk.StringVar(value="-12.0")
        self._y_max_var = tk.StringVar(value="12.0")
        self._ny_var = tk.StringVar(value="128")
        self._make_entry(self._y_domain_row, "yₘᵢₙ", self._y_min_var, width=10)
        self._make_entry(self._y_domain_row, "yₘₐₓ", self._y_max_var, width=10)
        self._make_spinbox(self._y_domain_row, "Nᵧ", self._ny_var, from_=32, to=16384, width=8)

        row = self._add_row(body)
        self._t_min_var = tk.StringVar(value="0.0")
        self._t_max_var = tk.StringVar(value="8.0")
        self._dt_var = tk.StringVar(value="0.002")
        self._make_entry(row, "tₘᵢₙ", self._t_min_var, width=10)
        self._make_entry(row, "tₘₐₓ", self._t_max_var, width=10)
        self._make_entry(row, "Δt", self._dt_var, width=10)

        # Always visible; the optional absorbing-layer row follows it.
        self._constants_row = self._add_row(body)
        self._hbar_var = tk.StringVar(value="1.0")
        self._mass_var = tk.StringVar(value="1.0")
        self._make_entry(self._constants_row, "ħ", self._hbar_var, width=10)
        self._make_entry(self._constants_row, "Mass m", self._mass_var, width=10)

        self._absorb_row = self._add_row(body)
        self._absorb_ratio_var = tk.StringVar(value="0.10")
        self._absorb_strength_var = tk.StringVar(value="1.0")
        self._make_entry(self._absorb_row, "Absorb ratio", self._absorb_ratio_var, width=9)
        self._make_entry(self._absorb_row, "Absorb strength", self._absorb_strength_var, width=11)
        ToolTip(
            self._absorb_row,
            "Only used for absorbing boundaries. Ratio should be between 0 and 0.49.",
        )

    def _build_potential_section(self, body: tk.Misc, pad: int) -> None:
        """Add the potential selector, its parameters and the custom entry."""
        ttk.Separator(body).pack(fill=tk.X, pady=pad)
        ttk.Label(body, text="Potential", style="Small.TLabel").pack(anchor=tk.W)

        row = self._add_row(body)
        self._potential_var = tk.StringVar(value="free")
        pot_combo = self._make_combo(row, "Type", self._potential_var, _POTENTIALS, width=14)
        pot_combo.bind("<<ComboboxSelected>>", lambda _e: self._update_visibility())
        ToolTip(
            pot_combo,
            "free, harmonic, square_well, barrier, double_well, lattice, or custom expression.",
        )

        row = self._add_row(body)
        self._omega_var = tk.StringVar(value="1.0")
        self._v0_var = tk.StringVar(value="5.0")
        self._width_var = tk.StringVar(value="2.0")
        self._barrier_sigma_var = tk.StringVar(value="0.4")
        self._make_entry(row, "ω", self._omega_var, width=8)
        self._make_entry(row, "V₀", self._v0_var, width=8)
        self._make_entry(row, "Width", self._width_var, width=8)
        self._make_entry(row, "Barrier σ", self._barrier_sigma_var, width=9)

        # Always visible; the optional custom-potential row follows it.
        self._potential_params_row = self._add_row(body)
        self._lattice_k_var = tk.StringVar(value="2.0")
        self._a_dw_var = tk.StringVar(value="1.0")
        self._b_dw_var = tk.StringVar(value="1.0")
        self._make_entry(self._potential_params_row, "Lattice k", self._lattice_k_var, width=8)
        self._make_entry(self._potential_params_row, "a (double-well)", self._a_dw_var, width=12)
        self._make_entry(self._potential_params_row, "b (double-well)", self._b_dw_var, width=12)

        self._custom_potential_row = self._add_row(body)
        ttk.Label(self._custom_potential_row, text="V_custom =").pack(side=tk.LEFT, padx=(0, 4))
        self._custom_potential_var = tk.StringVar(value="0.5*(x**2 + y**2)")
        self._custom_potential_entry = ttk.Entry(
            self._custom_potential_row,
            textvariable=self._custom_potential_var,
            width=56,
            font=get_font(),
        )
        self._custom_potential_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)

    def _build_packet_section(self, body: tk.Misc, pad: int) -> None:
        """Add the initial-packet selector, its parameters and custom entry."""
        ttk.Separator(body).pack(fill=tk.X, pady=pad)
        ttk.Label(body, text="Initial packet", style="Small.TLabel").pack(anchor=tk.W)

        row = self._add_row(body)
        self._packet_var = tk.StringVar(value="gaussian")
        packet_combo = self._make_combo(row, "Type", self._packet_var, _PACKETS, width=14)
        packet_combo.bind("<<ComboboxSelected>>", lambda _e: self._update_visibility())

        # Always visible; the optional y-packet row is re-inserted after it.
        self._packet_row = self._add_row(body)
        self._sigma_var = tk.StringVar(value="0.8")
        self._x0_var = tk.StringVar(value="-3.0")
        self._k0x_var = tk.StringVar(value="2.0")
        self._separation_var = tk.StringVar(value="2.0")
        self._make_entry(self._packet_row, "σ", self._sigma_var, width=8)
        self._make_entry(self._packet_row, "x₀", self._x0_var, width=8)
        self._make_entry(self._packet_row, "k₀x", self._k0x_var, width=8)
        self._make_entry(self._packet_row, "Separation", self._separation_var, width=9)

        self._packet_y_row = self._add_row(body)
        self._y0_var = tk.StringVar(value="0.0")
        self._k0y_var = tk.StringVar(value="0.0")
        self._make_entry(self._packet_y_row, "y₀", self._y0_var, width=8)
        self._make_entry(self._packet_y_row, "k₀y", self._k0y_var, width=8)

        self._custom_packet_row = self._add_row(body)
        ttk.Label(self._custom_packet_row, text="ψ_custom amplitude =").pack(
            side=tk.LEFT, padx=(0, 4)
        )
        self._custom_packet_var = tk.StringVar(value="exp(-(x**2 + y**2))")
        self._custom_packet_entry = ttk.Entry(
            self._custom_packet_row,
            textvariable=self._custom_packet_var,
            width=46,
            font=get_font(),
        )
        self._custom_packet_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
        ToolTip(
            self._custom_packet_entry,
            "Custom real amplitude profile. The phase comes from k0x/k0y.",
        )

    def _build_buttons(self, body: tk.Misc, pad: int) -> None:
        """Add the Solve/Close button row at the bottom of the form."""
        self._btn_row = ttk.Frame(body)
        self._btn_row.pack(fill=tk.X, pady=(pad * 2, 0))
        ttk.Button(self._btn_row, text="Solve", command=self._on_solve).pack(
            side=tk.LEFT, padx=(0, pad)
        )
        ttk.Button(
            self._btn_row,
            text="Close",
            style="Cancel.TButton",
            command=self.win.destroy,
        ).pack(side=tk.LEFT)

    # ------------------------------------------------------------------
    # Labelled widget factories
    # ------------------------------------------------------------------

    def _make_entry(
        self, parent: ttk.Frame, label: str, var: tk.StringVar, *, width: int = 10
    ) -> ttk.Entry:
        """Pack a ``"<label>:"`` caption and an entry bound to *var*; return the entry."""
        ttk.Label(parent, text=f"{label}:").pack(side=tk.LEFT, padx=(0, 4))
        entry = ttk.Entry(parent, textvariable=var, width=width, font=get_font())
        entry.pack(side=tk.LEFT, padx=(0, 12))
        return entry

    def _make_spinbox(
        self,
        parent: ttk.Frame,
        label: str,
        var: tk.StringVar,
        *,
        from_: int,
        to: int,
        width: int = 8,
    ) -> ttk.Spinbox:
        """Pack a ``"<label>:"`` caption and a spinbox bound to *var*; return the spinbox."""
        ttk.Label(parent, text=f"{label}:").pack(side=tk.LEFT, padx=(0, 4))
        spin = ttk.Spinbox(
            parent,
            textvariable=var,
            from_=from_,
            to=to,
            width=width,
            font=get_font(),
        )
        spin.pack(side=tk.LEFT, padx=(0, 12))
        return spin

    def _make_combo(
        self,
        parent: ttk.Frame,
        label: str,
        var: tk.StringVar,
        values: tuple[str, ...],
        *,
        width: int = 12,
    ) -> ttk.Combobox:
        """Pack a ``"<label>:"`` caption and a read-only combobox; return the combobox."""
        ttk.Label(parent, text=f"{label}:").pack(side=tk.LEFT, padx=(0, 4))
        combo = ttk.Combobox(
            parent,
            textvariable=var,
            values=list(values),
            state="readonly",
            width=width,
            font=get_font(),
        )
        combo.pack(side=tk.LEFT, padx=(0, 12))
        return combo

    # ------------------------------------------------------------------
    # Dynamic visibility
    # ------------------------------------------------------------------

    def _set_row_visible(
        self,
        row: ttk.Frame,
        visible: bool,
        *,
        after: ttk.Frame | None = None,
        before: ttk.Frame | None = None,
    ) -> None:
        """Show *row* next to its anchor, or remove it from the layout.

        Exactly one of *after* / *before* is expected; *after* wins when both
        are given.
        """
        if not visible:
            row.pack_forget()
            return
        anchor = {"after": after} if after is not None else {"before": before}
        row.pack(fill=tk.X, pady=self._row_pady, **anchor)

    def _update_visibility(self) -> None:
        """Show only the optional rows that match the current selections.

        Each optional row is re-packed directly after the always-visible row
        it belongs to, so revealing a row puts it back in its original place
        instead of moving it down to the button row.
        """
        is_2d = self._dimension_var.get() == "2D"
        self._set_row_visible(self._y_domain_row, is_2d, after=self._x_domain_row)
        self._set_row_visible(self._packet_y_row, is_2d, after=self._packet_row)
        self._set_row_visible(
            self._absorb_row,
            self._boundary_var.get() == "absorbing",
            after=self._constants_row,
        )
        self._set_row_visible(
            self._custom_potential_row,
            self._potential_var.get() == "custom",
            after=self._potential_params_row,
        )
        # The custom packet row is the last form row, so anchoring it to the
        # button row keeps it below the (optional) y-packet row.
        self._set_row_visible(
            self._custom_packet_row,
            self._packet_var.get() == "custom",
            before=self._btn_row,
        )

    # ------------------------------------------------------------------
    # Input collection and solving
    # ------------------------------------------------------------------

    @staticmethod
    def _compile_custom_fns(
        expression: str, dimension: int
    ) -> tuple[object | None, object | None]:
        """Compile *expression* and wrap it for the selected dimension.

        Returns:
            A ``(fn_1d, fn_2d)`` pair in which only the entry matching
            *dimension* is a callable; the other one is ``None``.  The
            wrappers coerce their arguments to ``float`` before evaluating.
        """
        if dimension == 2:
            compiled_2d = compile_scalar_expression(expression, variables=("x", "y"))

            def fn_2d(x: float, y: float) -> float:
                return compiled_2d(x=float(x), y=float(y))

            return None, fn_2d

        compiled_1d = compile_scalar_expression(expression, variables=("x",))

        def fn_1d(x: float) -> float:
            return compiled_1d(x=float(x))

        return fn_1d, None

    def _collect_inputs(self) -> dict[str, object]:
        """Parse and validate every field into solver keyword arguments.

        All numeric fields are parsed, including those whose row is currently
        hidden; only the ``yₘₐₓ > yₘᵢₙ`` ordering check is limited to 2D.

        Returns:
            Mapping that ``_on_solve`` expands into ``solve_schrodinger_td``.

        Raises:
            ValueError: For the range checks performed here; ``_on_solve``
                also reports ``ValueError`` raised by the parsing helpers.
        """
        dimension = 2 if self._dimension_var.get() == "2D" else 1
        boundary = self._boundary_var.get().strip().lower()
        potential_type = self._potential_var.get().strip().lower()
        packet_type = self._packet_var.get().strip().lower()

        x_min = parse_float(self._x_min_var.get(), name="xₘᵢₙ")
        x_max = parse_float(self._x_max_var.get(), name="xₘₐₓ")
        if x_max <= x_min:
            raise ValueError("xₘₐₓ must be greater than xₘᵢₙ.")
        nx = parse_positive_int(self._nx_var.get(), name="Nₓ", min_value=32)

        y_min = parse_float(self._y_min_var.get(), name="yₘᵢₙ")
        y_max = parse_float(self._y_max_var.get(), name="yₘₐₓ")
        ny = parse_positive_int(self._ny_var.get(), name="Nᵧ", min_value=32)
        if dimension == 2 and y_max <= y_min:
            raise ValueError("yₘₐₓ must be greater than yₘᵢₙ.")

        t_min = parse_float(self._t_min_var.get(), name="tₘᵢₙ")
        t_max = parse_float(self._t_max_var.get(), name="tₘₐₓ")
        if t_max <= t_min:
            raise ValueError("tₘₐₓ must be greater than tₘᵢₙ.")
        dt = parse_positive_float(self._dt_var.get(), name="Δt")

        hbar = parse_positive_float(self._hbar_var.get(), name="ħ")
        mass = parse_positive_float(self._mass_var.get(), name="Mass m")

        absorb_ratio = parse_float(self._absorb_ratio_var.get(), name="Absorb ratio")
        if not (0.0 <= absorb_ratio < 0.5):
            raise ValueError("Absorb ratio must be in [0, 0.5).")
        absorb_strength = parse_positive_float(
            self._absorb_strength_var.get(),
            name="Absorb strength",
        )

        omega = parse_positive_float(self._omega_var.get(), name="ω")
        v0 = parse_float(self._v0_var.get(), name="V₀")
        width = parse_positive_float(self._width_var.get(), name="Width")
        barrier_sigma = parse_positive_float(self._barrier_sigma_var.get(), name="Barrier σ")
        lattice_k = parse_positive_float(self._lattice_k_var.get(), name="Lattice k")
        a_dw = parse_positive_float(self._a_dw_var.get(), name="a (double-well)")
        b_dw = parse_positive_float(self._b_dw_var.get(), name="b (double-well)")

        sigma = parse_positive_float(self._sigma_var.get(), name="σ")
        x0 = parse_float(self._x0_var.get(), name="x₀")
        y0 = parse_float(self._y0_var.get(), name="y₀")
        k0x = parse_float(self._k0x_var.get(), name="k₀x")
        k0y = parse_float(self._k0y_var.get(), name="k₀y")
        separation = parse_positive_float(self._separation_var.get(), name="Separation")

        # Custom expressions are compiled only when their type is selected.
        custom_potential_fn_1d = None
        custom_potential_fn_2d = None
        if potential_type == "custom":
            custom_potential_fn_1d, custom_potential_fn_2d = self._compile_custom_fns(
                self._custom_potential_var.get(), dimension
            )

        custom_packet_fn_1d = None
        custom_packet_fn_2d = None
        if packet_type == "custom":
            custom_packet_fn_1d, custom_packet_fn_2d = self._compile_custom_fns(
                self._custom_packet_var.get(), dimension
            )

        return {
            "dimension": dimension,
            "x_min": x_min,
            "x_max": x_max,
            "nx": nx,
            "y_min": y_min,
            "y_max": y_max,
            "ny": ny,
            "t_min": t_min,
            "t_max": t_max,
            "dt": dt,
            "hbar": hbar,
            "mass": mass,
            "boundary": boundary,
            "absorb_ratio": absorb_ratio,
            "absorb_strength": absorb_strength,
            "potential_type": potential_type,
            "omega": omega,
            "v0": v0,
            "width": width,
            "barrier_sigma": barrier_sigma,
            "lattice_k": lattice_k,
            "a_dw": a_dw,
            "b_dw": b_dw,
            "packet_type": packet_type,
            "sigma": sigma,
            "x0": x0,
            "y0": y0,
            "k0x": k0x,
            "k0y": k0y,
            "separation": separation,
            "custom_potential_fn_1d": custom_potential_fn_1d,
            "custom_potential_fn_2d": custom_potential_fn_2d,
            "custom_packet_fn_1d": custom_packet_fn_1d,
            "custom_packet_fn_2d": custom_packet_fn_2d,
        }

    def _on_solve(self) -> None:
        """Validate the form, close the dialog and start the solver.

        A ``ValueError`` from input collection is shown in an error box and
        the dialog stays open so the user can correct the field.
        """
        try:
            params = self._collect_inputs()
        except ValueError as exc:
            messagebox.showerror("Invalid input", str(exc), parent=self.win)
            return

        self.win.destroy()

        def _task():
            return solve_schrodinger_td(**params)

        def _on_success(result) -> None:
            # Imported at call time, as in the original code.
            from complex_problems.schrodinger_td.result_dialog import SchrodingerTDResultDialog

            SchrodingerTDResultDialog(self.parent, result=result)

        run_solver_with_loading(
            parent=self.parent,
            message="Solving Schrodinger TD...",
            task=_task,
            on_success=_on_success,
        )