"""Mathematical transforms: Fourier, Laplace, Taylor series."""
from __future__ import annotations
import math
from enum import Enum
from functools import lru_cache
from typing import Callable
import numpy as np
from numpy.polynomial import polynomial as P
from scipy import fft
from scipy.integrate import quad_vec
from utils import get_logger
# Module logger; used for debug messages when an optional numerical path
# (scipy.misc.derivative, quad_vec integration) is unavailable or fails.
logger = get_logger(__name__)
# Upper bound on the zero-padded FFT length in _refine_fft_spectrum_in_range;
# signals longer than this are decimated so they fit.
_MAX_FFT = 65536
def _nth_derivative(
    f: Callable[[float], float],
    x0: float,
    n: int,
    dx: float = 1e-6,
) -> float:
    """Compute the n-th derivative of f at x0 using central differences.

    ``scipy.misc.derivative`` is used when it exists (older SciPy releases
    only). Otherwise the n-th order central-difference operator is applied in
    closed form::

        f^(n)(x0) ~= (2*dx)**-n * sum_{k=0..n} (-1)**k * C(n, k) * f(x0 + (n - 2k)*dx)

    This is exactly what composing the first-order central difference
    ``(g(x + dx) - g(x - dx)) / (2*dx)`` n times produces. However, it needs
    only n + 1 evaluations of ``f`` instead of 2**n, and it attempts the SciPy
    import once instead of once per recursion level.

    Args:
        f: Scalar function.
        x0: Point of evaluation.
        n: Derivative order (must be >= 0).
        dx: Step size.

    Returns:
        Approximate n-th derivative.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"Derivative order must be non-negative, got {n}")
    if n == 0:
        return f(x0)
    try:
        from scipy.misc import derivative as scipy_deriv

        return scipy_deriv(f, x0, n=n, dx=dx, order=2 * n + 1)
    except (ImportError, AttributeError):
        logger.debug("scipy.misc.derivative unavailable, using central-difference fallback")
    total = 0.0
    binom = 1  # C(n, k), updated incrementally to avoid factorials
    for k in range(n + 1):
        term = binom * f(x0 + (n - 2 * k) * dx)
        total += -term if k % 2 else term
        # C(n, k+1) = C(n, k) * (n - k) / (k + 1); the division is exact.
        binom = binom * (n - k) // (k + 1)
    return total / (2.0 * dx) ** n
def _compute_taylor_coeffs(
    func: Callable[[np.ndarray], np.ndarray],
    center: float,
    order: int,
    x_min: float | None = None,
    x_max: float | None = None,
) -> np.ndarray:
    """Compute Taylor series coefficients around a center point.

    When a valid domain ``x_min < x_max`` is provided, fits
    ``f(x) ~= sum_k a_k (x - center)**k`` by least squares. This is more
    stable than numerical differentiation for high orders. Otherwise the
    coefficients come from numerical derivatives, ``a_k = f^(k)(center) / k!``.

    Args:
        func: Vectorized callable f(x) -> y. A scalar return value (e.g. from a
            constant expression) is broadcast over the sample points.
        center: Center point for Taylor expansion.
        order: Highest order of derivative (must be >= 0).
        x_min: Lower bound of domain (for polynomial fitting).
        x_max: Upper bound of domain (for polynomial fitting).

    Returns:
        Array of Taylor coefficients a_0, a_1, ..., a_order. Entries smaller
        than 1e-10 times the largest magnitude are set to zero.

    Raises:
        ValueError: If ``order`` is negative.
    """
    if order < 0:
        raise ValueError(f"Taylor order must be non-negative, got {order}")

    def evaluate(x: np.ndarray) -> np.ndarray:
        # Broadcast scalar results so every sample point gets a value.
        y = np.asarray(func(x))
        return np.full(np.shape(x), y) if y.ndim == 0 else y

    def drop_noise(coeffs: np.ndarray) -> np.ndarray:
        # Zero out coefficients that are negligible relative to the largest one.
        max_c = float(np.max(np.abs(coeffs)))
        if max_c > 0:
            coeffs = np.where(np.abs(coeffs) < 1e-10 * max_c, 0.0, coeffs)
        return coeffs

    if x_min is not None and x_max is not None and x_max > x_min:
        span = max(x_max - x_min, 1e-10)
        radius = min(1.0, span / 2.0)
        n_samples = max(order * 2 + 1, 50)
        # Intersect the sampling window with the domain rather than clipping the
        # samples. Clipping piles duplicate points onto the boundary when center
        # is near an edge, wasting up to half of the samples.
        lo = max(center - radius, x_min)
        hi = min(center + radius, x_max)
        if hi < lo:
            # center lies more than `radius` outside the domain: every sample lands
            # on the nearest boundary (what clipping would also produce).
            lo = hi = min(max(center, x_min), x_max)
        x_sample = np.linspace(lo, hi, n_samples)
        t = x_sample - center  # expand in powers of (x - center)
        vander = np.vander(t, order + 1, increasing=True)  # vander[i, k] = t[i]**k
        coeffs, *_ = np.linalg.lstsq(vander, evaluate(x_sample), rcond=None)
        return drop_noise(coeffs)

    # Fallback: derivative-based (unstable for high orders).
    eps = np.finfo(float).eps
    scale = max(1.0, abs(center))

    def scalar_f(t: float) -> float:
        return float(np.ravel(evaluate(np.array([t])))[0])

    coeffs = np.zeros(order + 1)
    for n in range(order + 1):
        # Roundoff in an n-th difference grows like eps / dx**n, so the step must
        # GROW with n; eps**(1/(n+2)) balances it against O(dx**2) truncation error.
        dx = max(1e-6, eps ** (1.0 / (n + 2)) * scale)
        coeffs[n] = _nth_derivative(scalar_f, center, n=n, dx=dx) / math.factorial(n)
    return drop_noise(coeffs)
@lru_cache(maxsize=64)
def _hilbert_filter_kernel(n: int) -> np.ndarray:
    """Build the frequency-domain Hilbert (analytic-signal) filter for FFT length n.

    The weights match the construction used by ``scipy.signal.hilbert``:
    - DC bin: 1.
    - Strictly positive-frequency bins: 2.
    - Nyquist bin (only when ``n`` is even): 1.
    - Negative-frequency bins: 0.

    Results are cached per ``n`` and the same array object is handed to every
    caller. It is therefore returned read-only, so an accidental in-place
    operation raises instead of silently corrupting later results.

    Args:
        n: FFT length. ``0`` yields an empty filter.

    Returns:
        Read-only complex array of length ``n``.
    """
    h = np.zeros(n, dtype=complex)
    if n > 0:
        h[0] = 1
        h[1 : (n + 1) // 2] = 2
        if n % 2 == 0:
            h[n // 2] = 1
    h.flags.writeable = False
    return h
def _trim_indices_by_amplitude(
    magnitudes: np.ndarray,
    threshold_fraction: float,
    use_nanmax: bool = False,
) -> tuple[int, int] | None:
    """Locate the first and last samples whose magnitude reaches a relative threshold.

    The threshold is ``threshold_fraction`` times the peak magnitude.

    Args:
        magnitudes: Array of magnitudes.
        threshold_fraction: Fraction of the peak amplitude used as the cutoff.
        use_nanmax: If True, the peak is taken with ``np.nanmax`` so NaN
            entries (possible in Laplace output) are ignored.

    Returns:
        ``(first, last)`` indices, or ``None`` in any of these cases:
        - the fraction is non-positive;
        - the input is empty;
        - the peak is non-positive or non-finite;
        - no sample reaches the cutoff.
    """
    if threshold_fraction <= 0 or len(magnitudes) == 0:
        return None
    peak_fn = np.nanmax if use_nanmax else np.max
    peak = float(peak_fn(magnitudes))
    # Rejects non-positive, infinite and NaN peaks alike.
    if not (peak > 0 and np.isfinite(peak)):
        return None
    hits = np.nonzero(magnitudes >= peak * threshold_fraction)[0]
    if hits.size == 0:
        return None
    first, last = hits[0], hits[-1]
    return int(first), int(last)
def _refine_fft_spectrum_in_range(
    y: np.ndarray,
    dx: float,
    f_low: float,
    f_high: float,
    n_target: int,
    magnitude_fn: Callable[[np.ndarray, int], np.ndarray],
    fallback: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None,
    max_fft: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Refine an FFT spectrum via zero-padding and extract the band [f_low, f_high].

    The zero-padded length is chosen so that roughly ``n_target`` bins fall in
    the requested band. It is never shorter than the signal and never longer
    than the FFT cap.

    If the signal itself is longer than the cap, it is decimated: every
    ``step``-th sample is kept and the sample spacing is scaled by ``step``.
    No anti-alias filter is applied, so content above the reduced Nyquist
    frequency aliases.

    Args:
        y: Real signal samples.
        dx: Sample spacing.
        f_low: Lower frequency bound.
        f_high: Upper frequency bound.
        n_target: Target number of points for refinement.
        magnitude_fn: Callable(fft_vals, n) -> magnitudes of length n//2.
        fallback: If no refined bin falls in the band, return this
            (freqs, mag, bin_indices) instead.
        max_fft: Maximum zero-padded FFT length; ``None`` means ``_MAX_FFT``.

    Returns:
        (freqs, magnitudes, bin_indices) in the requested range. The
        bin_indices index the refined FFT. If the band is empty and no
        fallback is given, the whole non-negative half spectrum is returned.

    Raises:
        ValueError: If the FFT cap is smaller than 1.
    """
    fft_cap = _MAX_FFT if max_fft is None else int(max_fft)
    if fft_cap < 1:
        raise ValueError(f"max_fft must be >= 1, got {fft_cap}")
    n_points = len(y)
    f_span = max(f_high - f_low, 1.0 / (n_points * dx))
    n_refined = max(int(np.ceil(n_target / (f_span * dx))), n_points)
    if n_refined > fft_cap:
        n_refined = fft_cap
    if n_points > n_refined:
        # ceil-division guarantees len(y[::step]) <= n_refined. A floored step
        # could leave up to ~2x too many samples and overflow the padded buffer.
        step = -(-n_points // n_refined)
        y = y[::step]
        dx = dx * step  # decimated samples are step * dx apart
        n_points = len(y)
    y_padded = np.zeros(n_refined, dtype=complex)
    y_padded[:n_points] = y
    fft_ref = fft.fft(y_padded)
    mag_ref = magnitude_fn(fft_ref, n_refined)
    freqs_ref = np.abs(fft.fftfreq(n_refined, dx)[: n_refined // 2])
    mask = (freqs_ref >= f_low) & (freqs_ref <= f_high)
    if np.any(mask):
        bin_indices = np.where(mask)[0]
        return freqs_ref[mask], mag_ref[mask], bin_indices
    if fallback is not None:
        return fallback
    return freqs_ref, mag_ref, np.arange(len(freqs_ref))
def _trim_and_refine_fft_spectrum(
    y: np.ndarray,
    dx: float,
    freqs: np.ndarray,
    magnitudes: np.ndarray,
    threshold_fraction: float,
    n_target: int,
    magnitude_fn: Callable[[np.ndarray, int], np.ndarray],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Restrict a spectrum to its significant band, then re-sample that band finely.

    The band runs from the first to the last bin whose magnitude reaches
    ``threshold_fraction`` of the peak. That frequency range is then
    recomputed from ``y`` with zero-padding.

    Args:
        y: Real signal samples.
        dx: Sample spacing.
        freqs: Frequency axis.
        magnitudes: Magnitude values.
        threshold_fraction: Relative amplitude threshold.
        n_target: Target number of points after refinement.
        magnitude_fn: Callable(fft_vals, n) -> magnitudes for refined FFT.

    Returns:
        (freqs, magnitudes, bin_indices).
        - No trimming applies: the inputs are returned with indices
          ``0..len(freqs)-1``.
        - Refinement succeeds: indices are bins of the refined FFT.
        - Refinement finds nothing in the band: the unrefined slice is
          returned with its original bin indices.
    """
    bounds = _trim_indices_by_amplitude(magnitudes, threshold_fraction)
    if bounds is None:
        return freqs, magnitudes, np.arange(len(freqs))
    first, last = bounds
    window = slice(first, last + 1)
    unrefined = (freqs[window], magnitudes[window], np.arange(first, last + 1))
    return _refine_fft_spectrum_in_range(
        y,
        dx,
        float(freqs[first]),
        float(freqs[last]),
        n_target,
        magnitude_fn,
        fallback=unrefined,
    )
def _trim_and_refine_laplace(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    s_vals: np.ndarray,
    laplace_vals: np.ndarray,
    threshold_fraction: float,
    n_target: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Narrow a Laplace spectrum to its significant band and re-integrate it densely.

    The band spans ``s_vals`` between the first and last samples whose
    ``|L(s)|`` reaches ``threshold_fraction`` of the NaN-ignoring peak.
    ``n_target`` evenly spaced s values inside that band are then
    recomputed.

    Args:
        func: Vectorized callable f(x) -> y.
        x_min: Lower integration bound.
        x_max: Upper integration bound.
        s_vals: Original s values.
        laplace_vals: Laplace transform values.
        threshold_fraction: Relative amplitude threshold.
        n_target: Target number of points.

    Returns:
        (s_vals, laplace_vals, sample_indices). Without trimming, the inputs
        come back with integer indices ``0..len(s_vals)-1``. Otherwise,
        sample_indices are evenly spaced (generally fractional) positions
        between the first and last kept original indices.
    """
    bounds = _trim_indices_by_amplitude(
        np.abs(laplace_vals), threshold_fraction, use_nanmax=True
    )
    if bounds is None:
        return s_vals, laplace_vals, np.arange(len(s_vals))
    first, last = bounds
    s_dense = np.linspace(float(s_vals[first]), float(s_vals[last]), n_target)
    dense_vals = _compute_laplace_samples(func, x_min, x_max, s_dense)
    # Map each dense point back onto the original index range first..last.
    positions = np.linspace(first, last, n_target)
    return s_dense, dense_vals, positions
def _compute_laplace_samples(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    s_vals: np.ndarray,
) -> np.ndarray:
    """Compute finite-interval Laplace transform samples over given s values.

    Evaluates ``L(s) = integral_{x_min}^{x_max} f(t) * exp(-s * t) dt`` for all
    s at once with ``quad_vec``, using a single vector-valued integration.

    A point where ``func`` raises ValueError, ZeroDivisionError or
    OverflowError contributes zero. If the integration as a whole fails, the
    failure is logged at debug level and NaNs are returned.

    Args:
        func: Vectorized callable f(x) -> y. A scalar return value (e.g. from
            a constant expression) is accepted.
        x_min: Lower bound of integration.
        x_max: Upper bound of integration.
        s_vals: Array of s values at which to evaluate the Laplace transform.

    Returns:
        Float array of Laplace transform values, one per s.
    """
    s_vals = np.asarray(s_vals, dtype=float)
    n_s = len(s_vals)

    def vector_integrand(t: float) -> np.ndarray:
        if t < x_min or t > x_max:
            return np.zeros(n_s, dtype=float)
        try:
            # ravel() so scalar-returning callables work. Indexing a bare
            # float raised TypeError, which escaped to quad_vec and turned the
            # whole result into NaN.
            ft = float(np.ravel(func(np.array([t])))[0])
            return ft * np.exp(-s_vals * t)
        except (ValueError, ZeroDivisionError, OverflowError):
            return np.zeros(n_s, dtype=float)

    try:
        result, _ = quad_vec(
            vector_integrand,
            x_min,
            x_max,
            epsabs=1e-10,
            epsrel=1e-8,
            limit=200,
        )
        return np.asarray(result, dtype=float)
    except Exception as exc:  # best-effort: caller receives NaNs instead of a crash
        logger.debug("Laplace quad_vec failed: %s", exc)
        return np.full(n_s, np.nan, dtype=float)
class DisplayMode(str, Enum):
    """How to display the transform result.

    Members subclass ``str``, so each member compares equal to its label
    string.
    """

    # Values are human-readable labels (presumably shown in a UI selector —
    # confirm against callers).
    CURVE = "Curve (f vs x)"
    COEFFICIENTS = "Coefficients (a\u1d62 vs i)"
def _fft_magnitude_fn(fv: np.ndarray, n: int) -> np.ndarray:
    """Return ``|X_k|`` for the first ``n // 2`` bins of an FFT.

    This is the plain magnitude spectrum (Fourier/Z-transform display).
    """
    half = n // 2
    return np.absolute(fv[:half])
def _hilbert_magnitude_fn(fv: np.ndarray, n: int) -> np.ndarray:
    """Hilbert-filtered magnitude for coefficient display.

    Weights the full length-``n`` FFT with the Hilbert filter kernel. It
    then returns the magnitudes of the first ``n // 2`` bins.
    """
    kernel = _hilbert_filter_kernel(n)
    filtered = fv * kernel
    half = n // 2
    return np.abs(filtered[:half])
def _compute_fft_spectrum(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    n_points: int,
    threshold: float,
    magnitude_fn: Callable[[np.ndarray, int], np.ndarray] = _fft_magnitude_fn,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Sample ``func``, take its FFT, and return the trimmed/refined spectrum.

    Trimming is decided on the plain FFT magnitude of the ``n_points``
    samples. ``magnitude_fn`` is only applied to the refined, zero-padded FFT.

    Returns:
        (freqs, magnitudes, bin_indices) after trimming and refinement.
    """
    _, samples = compute_function_samples(func, x_min, x_max, n_points)
    # Grid spacing of np.linspace; a single point gets a unit spacing.
    spacing = 1.0 if n_points <= 1 else (x_max - x_min) / (n_points - 1)
    half = n_points // 2
    spectrum = np.abs(fft.fft(samples)[:half])
    freq_axis = np.abs(fft.fftfreq(n_points, spacing)[:half])
    return _trim_and_refine_fft_spectrum(
        samples,
        spacing,
        freq_axis,
        spectrum,
        threshold,
        half,
        magnitude_fn=magnitude_fn,
    )
def _apply_fft_magnitude_spectrum(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    n_points: int,
    threshold: float,
    y_label: str,
) -> tuple[np.ndarray, np.ndarray, str, str]:
    """Build the FFT magnitude curve shared by the Fourier and Z-transform views.

    Returns:
        (freqs, magnitudes, x-axis label, ``y_label``). Bin indices are
        discarded.
    """
    freqs, magnitudes, _bin_indices = _compute_fft_spectrum(
        func, x_min, x_max, n_points, threshold
    )
    x_label = "ω/(2π)"
    return freqs, magnitudes, x_label, y_label
def _get_fft_coefficients(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    n_points: int,
    threshold: float,
    y_label: str,
    base_meta: dict[str, object],
) -> tuple[np.ndarray, np.ndarray, str, str, dict[str, object]]:
    """Build the FFT coefficient view shared by the Fourier and Z-transform modes.

    Returns:
        (freqs, coefficients, x-axis label, ``y_label``, meta). ``meta`` is a
        copy of ``base_meta`` with ``"amp_threshold"`` set to ``threshold``;
        ``base_meta`` itself is not modified.
    """
    # NOTE(review): this uses the plain FFT magnitude (the default magnitude_fn),
    # although _hilbert_magnitude_fn is described as the magnitude for
    # coefficient display — confirm which one is intended.
    freqs, coeffs, _bin_indices = _compute_fft_spectrum(
        func, x_min, x_max, n_points, threshold
    )
    meta = dict(base_meta)
    meta["amp_threshold"] = threshold
    x_label = "ω/(2π)"
    return freqs, coeffs, x_label, y_label, meta
def compute_function_samples(
    func: Callable[[np.ndarray], np.ndarray],
    x_min: float,
    x_max: float,
    n_points: int = 1024,
) -> tuple[np.ndarray, np.ndarray]:
    """Sample a function on an evenly spaced grid.

    Args:
        func: Vectorized callable f(x) -> y.
        x_min: Lower bound.
        x_max: Upper bound.
        n_points: Number of sample points.

    Returns:
        Tuple of (x, y) arrays, where ``x = np.linspace(x_min, x_max, n_points)``
        and ``y = func(x)``. A numeric scalar result (e.g. from a constant
        expression) is broadcast to one value per grid point, so it can be fed
        to FFT code like any other sample array.
    """
    x = np.linspace(x_min, x_max, n_points)
    y = func(x)
    if np.ndim(y) == 0 and np.issubdtype(np.asarray(y).dtype, np.number):
        y = np.full(x.shape, y)
    return x, y