Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pynumdiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
from .finite_difference import finitediff, first_order, second_order, fourth_order
from .smooth_finite_difference import kerneldiff, meandiff, mediandiff, gaussiandiff, friedrichsdiff, butterdiff
from .polynomial_fit import splinediff, polydiff, savgoldiff
from .basis_fit import spectraldiff, rbfdiff
from .basis_fit import spectraldiff, rbfdiff, waveletdiff
from .total_variation_regularization import iterative_velocity
from .kalman_smooth import kalman_filter, rts_smooth, rtsdiff, constant_velocity, constant_acceleration, constant_jerk
109 changes: 109 additions & 0 deletions pynumdiff/basis_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from warnings import warn
import numpy as np
from scipy import sparse
import pywt

from pynumdiff.utils import utility

Expand Down Expand Up @@ -133,3 +134,111 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
dxdt_hat_flattened = drbfdt @ alpha

return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)


def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='periodization'):
"""Smooth and differentiate noisy data in a wavelet basis.

Three steps: (1) decompose x with the DWT and soft-threshold the detail
coefficients to denoise (Donoho-Johnstone universal threshold), reconstructing
a smoothed x_hat; (2) extend x_hat antisymmetrically so the periodic derivative
operator stays accurate at the edges; (3) recover the wavelet scaling
coefficients of x_hat and apply the analytic derivative of the wavelet basis.

The derivative differentiates the basis functions themselves rather than
finite-differencing the signal. PyWavelets treats the samples as finest-level
scaling coefficients, so x_hat is the interpolant x(t) = sum_n a_n phi(t/dt - n)
for the scaling function phi. Sampling x and its analytic derivative on the grid
gives two convolutions against phi and phi' evaluated at *integers*,

x_hat = Phi @ a and x' = Phi_prime @ a,

so x' = Phi_prime @ Phi^-1 @ x_hat, exact for signals the basis can represent.
The integer samples phi(p), phi'(p) are the eigenvalue-1 and eigenvalue-1/2
eigenvectors of the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k)
(the "connection coefficients"), normalized to reproduce constants and ramps.

Because the DWT requires uniform spacing, this method only accepts a scalar

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Informative mention.

time step dt (not a vector of sample times). For non-uniformly sampled data,
use :func:`rbfdiff` or :func:`splinediff` instead.

:param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
:param float dt: uniform time step between samples.
:param str wavelet: PyWavelets wavelet name. Must have a differentiable scaling
function, so smoother wavelets give better derivatives: 'db8' (default) and
'sym8' are best for noisy data; 'db4', 'sym4', and 'coif2' also work well.
:param int level: decomposition depth. None (default) resolves to
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
:param float threshold: soft-thresholding scale factor in [0, inf).
:param int axis: axis along which to differentiate (default 0).
:param str mode: PyWavelets signal extension mode for the denoising transform.
'periodization' keeps coefficient arrays compact. The derivative operator is
periodic, so x_hat is antisymmetrically extended before it is applied (see below).
:return: - **x_hat** (np.array) -- estimated (smoothed) x
- **dxdt_hat** (np.array) -- estimated derivative of x
"""
if not np.isscalar(dt):
raise ValueError("`dt` must be a scalar. The DWT requires uniformly sampled data. "
"For variable step sizes, use rbfdiff or splinediff instead.")

# The Haar scaling function is a step, so it has no pointwise derivative and the
# connection-coefficient operator below is undefined for it. Haar/db1 is the only
# orthonormal wavelet with a 2-tap filter, so dec_len identifies it.
if pywt.Wavelet(wavelet).dec_len == 2:
raise ValueError("The Haar/db1 wavelet has a discontinuous (piecewise-constant) scaling "
"function with no derivative, so it cannot be used to differentiate. Pick a smoother "
"wavelet such as 'db4', 'sym4', or 'coif2'.")

N = x.shape[axis]
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0)) # differentiation axis to front
shape = x_work.shape # remember it to restore the input's dimensionality
x_flat = x_work.reshape(N, -1) # rest of the dims flattened into columns
Ne = 3 * N - 2 # length after the antisymmetric extension in step 2

# Build the wavelet-basis derivative operator (depends only on the grid and wavelet).
# Sampling the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k) at integers makes
# phi(p) the eigenvalue-1 and phi'(p) the eigenvalue-1/2 eigenvector of T[p,q] = sqrt2 h_{2p-q}.
h = np.array(pywt.Wavelet(wavelet).rec_lo); h = h / h.sum() * np.sqrt(2) # refinement filter, integral of phi = 1
L = len(h); p = np.arange(L) # phi is supported on the integers [0, L-1]
shift = 2 * p[:, None] - p[None, :]
T = np.where((shift >= 0) & (shift < L), np.sqrt(2) * h[np.clip(shift, 0, L - 1)], 0.0)
evals, evecs = np.linalg.eig(T)
phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))]); phi /= phi.sum() # sum_p phi(p) = 1
dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))]); dphi /= np.dot(p, dphi)*-1 # sum_p p*phi'(p) = -1
# Phi and Phi_prime hold circulant samples of phi and phi'/dt on the extended grid; both
# share a common shift that cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
rows, cols, phi_vals, dphi_vals = [], [], [], []
m = np.arange(Ne)
for offset, phi_p, dphi_p in zip(p, phi, dphi / dt):
rows.extend(m); cols.extend((m - offset) % Ne); phi_vals.extend([phi_p]*Ne); dphi_vals.extend([dphi_p]*Ne)
Phi = sparse.csr_matrix((phi_vals, (rows, cols)), shape=(Ne, Ne)).tocsc() # to invert
Phi_prime = sparse.csr_matrix((dphi_vals, (rows, cols)), shape=(Ne, Ne)) # to apply

if level is None:
level = min(pywt.dwt_max_level(N, wavelet), 5)

# 1. Denoise: DWT all columns at once, then soft-threshold the detail bands. The
# noise level is estimated robustly per column from the finest details (coeffs[-1]).
coeffs = pywt.wavedec(x_flat, wavelet, level=level, mode=mode, axis=0)
sigma = np.maximum(np.median(np.abs(coeffs[-1]), axis=0) / 0.6745, 1e-10)
thresh = threshold * sigma * np.sqrt(2 * np.log(N))
coeffs = [coeffs[0]] + [pywt.threshold(c, thresh[np.newaxis, :], mode='soft') for c in coeffs[1:]]
x_hat = pywt.waverec(coeffs, wavelet, mode=mode, axis=0)[:N]

# 2. The derivative operator is periodic, but x_hat usually isn't. Extend it
# antisymmetrically (reflect through each endpoint: x[-1-k] -> 2*x[0]-x[1+k]) so the
# periodic wrap is continuous in both value and slope, which keeps the derivative
# accurate at the edges instead of spiking there. This is the odd-symmetry analog of
# spectraldiff's even extension; a ramp extends to a ramp, so slopes survive exactly.
left = 2 * x_hat[0] - x_hat[1:][::-1]
right = 2 * x_hat[-1] - x_hat[:-1][::-1]
x_ext = np.concatenate([left, x_hat, right], axis=0) # length 3N-2, original at [N-1:2N-1]

# 3. Differentiate the basis: recover the scaling coefficients a = Phi^-1 @ x_ext, then
# apply the analytic basis derivative dxdt = Phi_prime @ a, and crop back to the original.
a = sparse.linalg.spsolve(Phi, x_ext)
dxdt_flat = (Phi_prime @ a.reshape(Ne, -1))[N - 1:2 * N - 1]

x_hat = np.moveaxis(x_hat.reshape(shape), 0, axis)
dxdt_hat = np.moveaxis(dxdt_flat.reshape(shape), 0, axis)
return x_hat, dxdt_hat
11 changes: 10 additions & 1 deletion pynumdiff/tests/test_diff_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..smooth_finite_difference import kerneldiff, mediandiff, meandiff, gaussiandiff, friedrichsdiff, butterdiff
from ..finite_difference import finitediff, first_order, second_order, fourth_order
from ..polynomial_fit import polydiff, savgoldiff, splinediff
from ..basis_fit import spectraldiff, rbfdiff
from ..basis_fit import spectraldiff, rbfdiff, waveletdiff
from ..total_variation_regularization import velocity, acceleration, jerk, iterative_velocity, smooth_acceleration, tvrdiff
from ..kalman_smooth import rtsdiff, constant_velocity, constant_acceleration, constant_jerk, robustdiff
from ..linear_model import lineardiff
Expand Down Expand Up @@ -51,6 +51,7 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
(spline_irreg_step, {'degree':5, 's':2}),
(spectraldiff, {'high_freq_cutoff':0.2}), (spectraldiff, [0.2]),
(rbfdiff, {'sigma':0.5, 'lmbd':0.001}),
(waveletdiff, {'wavelet':'db8', 'threshold':1.0}),
(constant_velocity, {'r':1e-2, 'q':1e3}), (constant_velocity, [1e-2, 1e3]),
(constant_acceleration, {'r':1e-3, 'q':1e4}), (constant_acceleration, [1e-3, 1e4]),
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
Expand Down Expand Up @@ -173,6 +174,12 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
[(-2, -2), (0, 0), (0, -1), (0, 0)],
[(0, 0), (2, 2), (0, 0), (2, 2)],
[(1, 1), (3, 3), (1, 1), (3, 3)]],
waveletdiff: [[(-15, -15), (-13, -13), (0, -1), (1, 0)],
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
[(-3, -3), (-1, -1), (0, 0), (1, 1)],
[(0, -1), (2, 2), (0, 0), (2, 2)],
[(0, -1), (3, 3), (0, 0), (3, 3)]],
velocity: [[(-25, -25), (-18, -19), (0, -1), (1, 0)],
[(-12, -12), (-11, -12), (-1, -1), (-1, -2)],
[(0, -1), (1, 0), (0, -1), (1, 0)],
Expand Down Expand Up @@ -327,6 +334,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
(finitediff, {}),
(polydiff, {'degree': 2, 'window_size': 5}),
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
(waveletdiff, {'wavelet': 'db8', 'threshold': 1.0}),
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
Expand All @@ -343,6 +351,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
kerneldiff: [(2, 1), (3, 2)],
butterdiff: [(0, -1), (1, -1)],
finitediff: [(0, -1), (1, -1)],
waveletdiff: [(1, 0), (2, 2)],
polydiff: [(1, -1), (1, 0)],
savgoldiff: [(0, -1), (1, 1)],
rtsdiff: [(1, -1), (1, 0)],
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ classifiers = [
dependencies = [
"numpy",
"scipy",
"matplotlib"
"matplotlib",
"pywavelets"
]

[project.urls]
Expand Down
Loading