55
66from pynumdiff .utils import utility
77
8+ #maria spectral diff below
89
9- def spectraldiff (x , dt , params = None , options = None , high_freq_cutoff = None , even_extension = True , pad_to_zero_dxdt = True ):
10+ def spectraldiff (x , dt , axis = 0 , params = None , options = None , high_freq_cutoff = None ,
11+ even_extension = True , pad_to_zero_dxdt = True ):
1012 """Take a derivative in the Fourier domain, with high frequency attentuation.
1113
1214 :param np.array[float] x: data to differentiate
@@ -18,59 +20,77 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
1820 and 1. Frequencies below this threshold will be kept, and above will be zeroed.
1921 :param bool even_extension: if True, extend the data with an even extension so signal starts and ends at the same value.
2022 :param bool pad_to_zero_dxdt: if True, extend the data with extra regions that smoothly force the derivative to
23+
2124 zero before taking FFT.
2225
2326 :return: - **x_hat** (np.array) -- estimated (smoothed) x
2427 - **dxdt_hat** (np.array) -- estimated derivative of x
2528 """
26- if params is not None : # Warning to support old interface for a while. Remove these lines along with params in a future release.
29+ if params is not None :
2730 warn ("`params` and `options` parameters will be removed in a future version. Use `high_freq_cutoff`, " +
28- "`even_extension`, and `pad_to_zero_dxdt` instead." , DeprecationWarning )
31+ "`even_extension`, and `pad_to_zero_dxdt` instead." , DeprecationWarning )
2932 high_freq_cutoff = params [0 ] if isinstance (params , list ) else params
3033 if options is not None :
3134 if 'even_extension' in options : even_extension = options ['even_extension' ]
3235 if 'pad_to_zero_dxdt' in options : pad_to_zero_dxdt = options ['pad_to_zero_dxdt' ]
3336 elif high_freq_cutoff is None :
3437 raise ValueError ("`high_freq_cutoff` must be given." )
3538
36- L = len (x )
39+ x = np .asarray (x )
40+ x0 = np .moveaxis (x , axis , 0 ) # move time axis to the front of the array
41+ # now x0 dims are (# of data points, # of signals)
42+ L = x0 .shape [0 ]
3743
3844 # make derivative go to zero at ends (optional)
3945 if pad_to_zero_dxdt :
4046 padding = 100
41- pre = getattr (x , 'values' , x )[0 ]* np .ones (padding ) # getattr to use .values if x is a pandas Series
42- post = getattr (x , 'values' , x )[- 1 ]* np .ones (padding )
43- x = np .hstack ((pre , x , post )) # extend the edges
47+
48+ # just pad first and last values x100
49+ first = x0 [0 :1 ]
50+ last = x0 [- 1 :]
51+ pre = np .repeat (first , padding , axis = 0 )
52+ post = np .repeat (last , padding , axis = 0 )
53+
54+ xpad = np .concatenate ((pre , x0 , post ), axis = 0 ) # i think hstack won't work with the correct axis
55+
4456 kernel = utility .mean_kernel (padding // 2 )
45- x_hat = utility .convolutional_smoother (x , kernel ) # smooth the edges in
46- x_hat [padding :- padding ] = x [padding :- padding ] # replace middle with original signal
47- x = x_hat
57+ x_hat0 = utility .convolutional_smoother (xpad , kernel , axis = 0 )
58+
59+ x_hat0 [padding :- padding ] = xpad [padding :- padding ]
60+ x0 = x_hat0
4861 else :
4962 padding = 0
5063
51- # Do even extension (optional)
64+ # Do even extension (optional):
5265 if even_extension is True :
53- x = np .hstack (( x , x [::- 1 ]) )
66+ x0 = np .concatenate (( x0 , x0 [::- 1 , ...]), axis = 0 )
5467
5568 # Form wavenumbers
56- N = len ( x )
69+ N = x0 . shape [ 0 ]
5770 k = np .concatenate ((np .arange (N // 2 + 1 ), np .arange (- N // 2 + 1 , 0 )))
58- if N % 2 == 0 : k [N // 2 ] = 0 # odd derivatives get the Nyquist element zeroed out
71+ if N % 2 == 0 : k [N // 2 ] = 0 # odd derivatives get the Nyquist element zeroed out
5972
6073 # Filter to zero out higher wavenumbers
61- discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
74+ discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
75+ filt = np .ones_like (k , dtype = float )
6276 filt = np .ones (k .shape ); filt [discrete_cutoff :N - discrete_cutoff ] = 0
77+ filt = filt .reshape ((N ,) + (1 ,)* (x0 .ndim - 1 ))
6378
64- # Smoothed signal
65- X = np .fft .fft (x )
66- x_hat = np .real (np .fft .ifft (filt * X ))
67- x_hat = x_hat [padding :L + padding ]
79+ # Smoothed signal
80+ X = np .fft .fft (x0 , axis = 0 )
6881
69- # Derivative = 90 deg phase shift
70- omega = 2 * np .pi / (dt * N ) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
71- dxdt_hat = np .real (np .fft .ifft (1j * k * omega * filt * X ))
72- dxdt_hat = dxdt_hat [padding :L + padding ]
82+ x_hat0 = np .real (np .fft .ifft (filt * X , axis = 0 ))
83+ x_hat0 = x_hat0 [padding :L + padding ]
7384
85+ # Derivative = 90 deg phase shift
86+ omega = 2 * np .pi / (dt * N )
87+ k0 = k .reshape ((N ,) + (1 ,)* (x0 .ndim - 1 ))
88+ dxdt0 = np .real (np .fft .ifft (1j * k0 * omega * filt * X , axis = 0 ))
89+ dxdt0 = dxdt0 [padding :L + padding ]
90+ # move back to original axis position
91+ x_hat = np .moveaxis (x_hat0 , 0 , axis )
92+ dxdt_hat = np .moveaxis (dxdt0 , 0 , axis )
93+
7494 return x_hat , dxdt_hat
7595
7696
@@ -82,7 +102,7 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
82102
83103 :param np.array[float] x: data to differentiate
84104 :param float or array[float] dt_or_t: This function supports variable step size. This parameter is either the constant
85- :math:`\\ Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
105+ :math:`\\ Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
86106 :param float sigma: controls width of radial basis functions
87107 :param float lmbd: controls smoothness
88108
0 commit comments