Skip to content

Commit b2eae85

Browse files
committed
I made spectraldiff capable of handling multidimensional data. The 4 test cases for multidimensionality pass now, but when we test it in the multidimensionality_demo it is very sensitive. Also, we made sure that spectraldiff is now in the multidimensionality_demo.
1 parent 49631ee commit b2eae85

2 files changed

Lines changed: 55 additions & 33 deletions

File tree

notebooks/6_multidimensionality_demo.ipynb

Lines changed: 11 additions & 9 deletions
Large diffs are not rendered by default.

pynumdiff/basis_fit.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
from pynumdiff.utils import utility
77

8+
#maria spectral diff below
89

9-
def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_extension=True, pad_to_zero_dxdt=True):
10+
def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None,
11+
even_extension=True, pad_to_zero_dxdt=True):
1012
"""Take a derivative in the Fourier domain, with high frequency attentuation.
1113
1214
:param np.array[float] x: data to differentiate
@@ -18,59 +20,77 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
1820
and 1. Frequencies below this threshold will be kept, and above will be zeroed.
1921
:param bool even_extension: if True, extend the data with an even extension so signal starts and ends at the same value.
2022
:param bool pad_to_zero_dxdt: if True, extend the data with extra regions that smoothly force the derivative to
23+
2124
zero before taking FFT.
2225
2326
:return: - **x_hat** (np.array) -- estimated (smoothed) x
2427
- **dxdt_hat** (np.array) -- estimated derivative of x
2528
"""
26-
if params is not None: # Warning to support old interface for a while. Remove these lines along with params in a future release.
29+
if params is not None:
2730
warn("`params` and `options` parameters will be removed in a future version. Use `high_freq_cutoff`, " +
28-
"`even_extension`, and `pad_to_zero_dxdt` instead.", DeprecationWarning)
31+
"`even_extension`, and `pad_to_zero_dxdt` instead.", DeprecationWarning)
2932
high_freq_cutoff = params[0] if isinstance(params, list) else params
3033
if options is not None:
3134
if 'even_extension' in options: even_extension = options['even_extension']
3235
if 'pad_to_zero_dxdt' in options: pad_to_zero_dxdt = options['pad_to_zero_dxdt']
3336
elif high_freq_cutoff is None:
3437
raise ValueError("`high_freq_cutoff` must be given.")
3538

36-
L = len(x)
39+
x = np.asarray(x)
40+
x0 = np.moveaxis(x, axis, 0) # move time axis to the front of the array
41+
# now x0 dims are (# of data points, # of signals)
42+
L = x0.shape[0]
3743

3844
# make derivative go to zero at ends (optional)
3945
if pad_to_zero_dxdt:
4046
padding = 100
41-
pre = getattr(x, 'values', x)[0]*np.ones(padding) # getattr to use .values if x is a pandas Series
42-
post = getattr(x, 'values', x)[-1]*np.ones(padding)
43-
x = np.hstack((pre, x, post)) # extend the edges
47+
48+
# just pad first and last values x100
49+
first = x0[0:1]
50+
last = x0[-1:]
51+
pre = np.repeat(first, padding, axis=0)
52+
post = np.repeat(last, padding, axis=0)
53+
54+
xpad = np.concatenate((pre, x0, post), axis=0) # i think hstack won't work with the correct axis
55+
4456
kernel = utility.mean_kernel(padding//2)
45-
x_hat = utility.convolutional_smoother(x, kernel) # smooth the edges in
46-
x_hat[padding:-padding] = x[padding:-padding] # replace middle with original signal
47-
x = x_hat
57+
x_hat0 = utility.convolutional_smoother(xpad, kernel, axis=0)
58+
59+
x_hat0[padding:-padding] = xpad[padding:-padding]
60+
x0 = x_hat0
4861
else:
4962
padding = 0
5063

51-
# Do even extension (optional)
64+
# Do even extension (optional):
5265
if even_extension is True:
53-
x = np.hstack((x, x[::-1]))
66+
x0 = np.concatenate((x0, x0[::-1, ...]), axis=0)
5467

5568
# Form wavenumbers
56-
N = len(x)
69+
N = x0.shape[0]
5770
k = np.concatenate((np.arange(N//2 + 1), np.arange(-N//2 + 1, 0)))
58-
if N % 2 == 0: k[N//2] = 0 # odd derivatives get the Nyquist element zeroed out
71+
if N % 2 == 0: k[N//2] = 0 # odd derivatives get the Nyquist element zeroed out
5972

6073
# Filter to zero out higher wavenumbers
61-
discrete_cutoff = int(high_freq_cutoff*N/2) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
74+
discrete_cutoff = int(high_freq_cutoff * N / 2) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
75+
filt = np.ones_like(k, dtype=float)
6276
filt = np.ones(k.shape); filt[discrete_cutoff:N-discrete_cutoff] = 0
77+
filt = filt.reshape((N,) + (1,)*(x0.ndim-1))
6378

64-
# Smoothed signal
65-
X = np.fft.fft(x)
66-
x_hat = np.real(np.fft.ifft(filt * X))
67-
x_hat = x_hat[padding:L+padding]
79+
# Smoothed signal
80+
X = np.fft.fft(x0, axis=0)
6881

69-
# Derivative = 90 deg phase shift
70-
omega = 2*np.pi/(dt*N) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
71-
dxdt_hat = np.real(np.fft.ifft(1j * k * omega * filt * X))
72-
dxdt_hat = dxdt_hat[padding:L+padding]
82+
x_hat0 = np.real(np.fft.ifft(filt * X, axis=0))
83+
x_hat0 = x_hat0[padding:L+padding]
7384

85+
# Derivative = 90 deg phase shift
86+
omega = 2*np.pi/(dt*N)
87+
k0 = k.reshape((N,) + (1,)*(x0.ndim-1))
88+
dxdt0 = np.real(np.fft.ifft(1j * k0 * omega * filt * X, axis=0))
89+
dxdt0 = dxdt0[padding:L+padding]
90+
# move back to original axis position
91+
x_hat = np.moveaxis(x_hat0, 0, axis)
92+
dxdt_hat = np.moveaxis(dxdt0, 0, axis)
93+
7494
return x_hat, dxdt_hat
7595

7696

@@ -82,7 +102,7 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
82102
83103
:param np.array[float] x: data to differentiate
84104
:param float or array[float] dt_or_t: This function supports variable step size. This parameter is either the constant
85-
:math:`\\Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
105+
:math:`\\Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
86106
:param float sigma: controls width of radial basis functions
87107
:param float lmbd: controls smoothness
88108

0 commit comments

Comments
 (0)