-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathasr.py
More file actions
executable file
·457 lines (379 loc) · 16.3 KB
/
asr.py
File metadata and controls
executable file
·457 lines (379 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
"""Utils for ASR functions."""
import numpy as np
from numpy import linalg
from scipy import signal
from scipy.linalg import toeplitz
from scipy.spatial.distance import cdist, euclidean
from scipy.special import gamma, gammaincinv
# Grid of generalized-Gaussian shape parameters (beta) searched by
# fit_eeg_distribution (1.7 ... 3.5 in 13 steps).
SHAPE_RANGE = np.linspace(1.7, 3.5, 13)


def fit_eeg_distribution(X, min_clean_fraction=0.25, max_dropout_fraction=0.1,
                         fit_quantiles=(0.022, 0.6), step_sizes=(0.0220, 0.6000),
                         shape_range=SHAPE_RANGE):
    """Estimate the mean and SD of clean EEG from contaminated data.

    This function estimates the mean and standard deviation of clean EEG from a
    sample of amplitude values (that have preferably been computed over short
    windows) that may include a large fraction of contaminated samples. The
    clean EEG is assumed to represent a generalized Gaussian component in a
    mixture with near-arbitrary artifact components. By default, at least 25%
    (``min_clean_fraction``) of the data must be clean EEG, and the rest can be
    contaminated. No more than 10% (``max_dropout_fraction``) of the data is
    allowed to come from contaminations that cause lower-than-EEG amplitudes
    (e.g., sensor unplugged). There are no restrictions on artifacts causing
    larger-than-EEG amplitudes, i.e., virtually anything is handled (with the
    exception of a very unlikely type of distribution that combines with the
    clean EEG samples into a larger symmetric generalized Gaussian peak and
    thereby "fools" the estimator). The default parameters should work for a
    wide range of applications but may be adapted to accommodate special
    circumstances.

    The method works by fitting a truncated generalized Gaussian whose
    parameters are constrained by ``min_clean_fraction``,
    ``max_dropout_fraction``, ``fit_quantiles``, and ``shape_range``. The fit
    is performed by a grid search that always finds a close-to-optimal solution
    if the above assumptions are fulfilled.

    Parameters
    ----------
    X : array, shape=(n_samples,)
        Sorted-able vector of amplitude values, possibly containing artifacts.
    min_clean_fraction : float
        Minimum fraction that needs to be clean. This is the minimum fraction
        of time windows that need to contain essentially uncontaminated EEG
        (default=0.25).
    max_dropout_fraction : float
        Maximum fraction that can have dropouts. This is the maximum fraction
        of time windows that may have arbitrarily low amplitude (e.g., due to
        the sensors being unplugged) (default=0.1).
    fit_quantiles : 2-tuple
        Quantile range [lower, upper] of the truncated generalized Gaussian
        distribution that shall be fit to the EEG contents
        (default=(0.022, 0.6)).
    step_sizes : 2-tuple
        Step size of the grid search; the first value is the stepping of the
        lower bound (which essentially steps over any dropout samples), and the
        second value is the stepping over possible scales (i.e., clean-data
        quantiles) (default=(0.022, 0.6)).
    shape_range : array
        Range that the clean EEG distribution's shape parameter beta may take.

    Returns
    -------
    mu : float
        Estimated mean of the clean EEG distribution.
    sig : float
        Estimated standard deviation of the clean EEG distribution.
    alpha : float
        Estimated scale parameter of the generalized Gaussian clean EEG
        distribution.
    beta : float
        Estimated shape parameter of the generalized Gaussian clean EEG
        distribution.
    """
    # sort data so we can access quantiles directly
    X = np.sort(X)
    n = len(X)

    # compute z bounds for the truncated standard generalized Gaussian pdf
    # and the pdf rescaler, one pair per candidate shape value
    quants = np.array(fit_quantiles)
    zbounds = []
    rescale = []
    for b in range(len(shape_range)):
        gam = gammaincinv(
            1 / shape_range[b], np.sign(quants - 1 / 2) * (2 * quants - 1))
        zbounds.append(np.sign(quants - 1 / 2) * gam ** (1 / shape_range[b]))
        rescale.append(shape_range[b] / (2 * gamma(1 / shape_range[b])))

    # determine the quantile-dependent limits for the grid search;
    # we can generally skip the tail below the lower quantile
    lower_min = np.min(quants)
    # maximum width is the fit interval if all data is clean
    max_width = np.diff(quants)[0]
    # minimum width of the fit interval, as fraction of data
    min_width = min_clean_fraction * max_width

    # build quantile interval matrix: one column per candidate lower bound
    cols = np.arange(lower_min,
                     lower_min + max_dropout_fraction + step_sizes[0] * 1e-9,
                     step_sizes[0])
    cols = np.round(n * cols).astype(int)
    rows = np.arange(0, int(np.round(n * max_width)))
    newX = np.zeros((len(rows), len(cols)))
    for i in rows:
        newX[i] = X[i + cols]

    # subtract baseline value for each interval
    X1 = newX[0, :]
    newX = newX - X1

    opt_val = np.inf
    opt_lu = np.inf
    opt_bounds = np.inf
    opt_beta = np.inf
    gridsearch = np.round(n * np.arange(max_width, min_width, -step_sizes[1]))
    for m in gridsearch.astype(int):
        mcurr = m - 1
        nbins = int(np.round(3 * np.log2(1 + m / 2)))
        # scale factor mapping each interval onto [0, nbins]
        cols = nbins / newX[mcurr]
        H = newX[:m] * cols

        # histogram each candidate interval over the common bin grid
        hist_all = []
        for ih in range(len(cols)):
            histcurr = np.histogram(H[:, ih], bins=np.arange(0, nbins + 1))
            hist_all.append(histcurr[0])
        hist_all = np.array(hist_all, dtype=int).T
        hist_all = np.vstack((hist_all, np.zeros(len(cols), dtype=int)))
        logq = np.log(hist_all + 0.01)  # +0.01 avoids log(0) on empty bins

        # for each shape value...
        for k, b in enumerate(shape_range):
            bounds = zbounds[k]
            # truncated generalized-Gaussian pdf evaluated at bin centers
            x = bounds[0] + np.arange(0.5, nbins + 0.5) / nbins * np.diff(bounds)  # noqa:E501
            p = np.exp(-np.abs(x) ** b) * rescale[k]
            p = p / np.sum(p)

            # calc KL divergences between model pdf and empirical histogram
            kl = np.sum(p * (np.log(p) - logq[:-1, :].T), axis=1) + np.log(m)

            # update optimal parameters
            min_val = np.min(kl)
            idx = np.argmin(kl)
            if min_val < opt_val:
                opt_val = min_val
                opt_beta = shape_range[k]
                opt_bounds = bounds
                opt_lu = [X1[idx], X1[idx] + newX[m - 1, idx]]

    # recover distribution parameters at optimum
    alpha = (opt_lu[1] - opt_lu[0]) / np.diff(opt_bounds)
    mu = opt_lu[0] - opt_bounds[0] * alpha
    beta = opt_beta

    # calculate the distribution's standard deviation from alpha and beta
    sig = np.sqrt((alpha ** 2) * gamma(3 / beta) / gamma(1 / beta))

    # ensure scalar values are returned (extract from arrays if needed)
    alpha = float(np.asarray(alpha).squeeze())
    mu = float(np.asarray(mu).squeeze())
    sig = float(np.asarray(sig).squeeze())
    return mu, sig, alpha, beta
def yulewalk(order, F, M):
    """Recursive filter design using a least-squares method.

    [B,A] = YULEWALK(N,F,M) finds the N-th order recursive filter
    coefficients B and A such that the filter::

        B(z)   b(1) + b(2)z^-1 + .... + b(n)z^-(n-1)
        ---- = -------------------------------------
        A(z)       1 + a(1)z^-1 + .... + a(n)z^-(n-1)

    matches the magnitude frequency response given by vectors F and M.

    The YULEWALK function performs a least squares fit in the time domain. The
    denominator coefficients {a(1),...,a(NA)} are computed by the so called
    "modified Yule Walker" equations, using NR correlation coefficients
    computed by inverse Fourier transformation of the specified frequency
    response H.

    The numerator is computed by a four step procedure. First, a numerator
    polynomial corresponding to an additive decomposition of the power
    frequency response is computed. Next, the complete frequency response
    corresponding to the numerator and denominator polynomials is evaluated.
    Then a spectral factorization technique is used to obtain the impulse
    response of the filter. Finally, the numerator polynomial is obtained by a
    least squares fit to this impulse response. For a more detailed explanation
    of the algorithm see [1]_.

    Parameters
    ----------
    order : int
        Filter order.
    F : array
        Normalised frequency breakpoints for the filter. The frequencies in F
        must be between 0.0 and 1.0, with 1.0 corresponding to half the sample
        rate. They must be in increasing order and start with 0.0 and end with
        1.0.
    M : array
        Magnitude breakpoints for the filter such that PLOT(F,M) would show a
        plot of the desired frequency response.

    Returns
    -------
    B, A : array
        Numerator and denominator coefficients of the fitted IIR filter
        (both of length ``order + 1``).

    References
    ----------
    .. [1] B. Friedlander and B. Porat, "The Modified Yule-Walker Method of
       ARMA Spectral Estimation," IEEE Transactions on Aerospace Electronic
       Systems, Vol. AES-20, No. 2, pp. 158-173, March 1984.

    Examples
    --------
    Design an 8th-order lowpass filter and overplot the desired
    frequency response with the actual frequency response:

    >>> f = [0, .6, .6, 1]         # Frequency breakpoints
    >>> m = [1, 1, 0, 0]           # Magnitude breakpoints
    >>> [b, a] = yulewalk(8, f, m) # Filter design using a least-squares method
    """
    F = np.asarray(F)
    M = np.asarray(M)

    npt = 512
    lap = np.fix(npt / 25).astype(int)  # overlap used for coincident breakpoints
    mf = F.size
    npt = npt + 1  # For [dc 1 2 ... nyquist].
    Ht = np.array(np.zeros((1, npt)))
    nint = mf - 1
    df = np.diff(F)

    # linearly interpolate the breakpoint spec onto a dense half-spectrum grid
    nb = 0
    Ht[0][0] = M[0]
    for i in range(nint):
        if df[i] == 0:
            # coincident frequencies: make a (near-)jump over `lap` points
            nb = nb - int(lap / 2)
            ne = nb + lap
        else:
            ne = int(np.fix(F[i + 1] * npt)) - 1
        j = np.arange(nb, ne + 1)
        if ne == nb:
            inc = 0
        else:
            inc = (j - nb) / (ne - nb)
        Ht[0][nb:ne + 1] = np.array(inc * M[i + 1] + (1 - inc) * M[i])
        nb = ne + 1

    # mirror to obtain the full (two-sided) magnitude response
    Ht = np.concatenate((Ht, Ht[0][-2:0:-1]), axis=None)
    n = Ht.size
    n2 = np.fix((n + 1) / 2)
    nb = order
    nr = 4 * order
    nt = np.arange(0, nr)

    # compute correlation function of magnitude squared response
    R = np.real(np.fft.ifft(Ht * Ht))
    R = R[0:nr] * (0.54 + 0.46 * np.cos(np.pi * nt / (nr - 1)))  # pick NR correlations # noqa

    # Form window to be used in extracting the right "wing" of two-sided
    # covariance sequence
    Rwindow = np.concatenate(
        (1 / 2, np.ones((1, int(n2 - 1))), np.zeros((1, int(n - n2)))),
        axis=None)

    A = polystab(denf(R, order))  # compute denominator

    # compute additive decomposition
    Qh = numf(np.concatenate((R[0] / 2, R[1:nr]), axis=None), A, order)

    # compute impulse response. NOTE(fix): the original code did
    # `_, Ss = 2 * np.real(signal.freqz(...))`, performing arithmetic on the
    # whole (w, h) tuple and needlessly scaling the discarded frequency
    # vector; unpack first, then scale only the response.
    _, h = signal.freqz(Qh, A, worN=n, whole=True)
    Ss = 2 * np.real(h)
    hh = np.fft.ifft(
        np.exp(np.fft.fft(Rwindow * np.fft.ifft(np.log(Ss, dtype=complex))))
    )
    B = np.real(numf(hh[0:nr], A, nb))

    return B, A
def yulewalk_filter(X, sfreq, zi=None, ab=None, axis=-1):
    """Yulewalk filter.

    Applies an IIR spectral-shaping filter to multichannel data, returning
    both the filtered signal and the final filter state.

    Parameters
    ----------
    X : array, shape=(n_channels, n_samples)
        Data to filter.
    sfreq : float
        Sampling frequency.
    zi : array, shape=(n_channels, filter_order) | None
        Initial conditions. If None, initial conditions are derived from
        ``scipy.signal.lfilter_zi`` scaled by each channel's first sample.
    ab : 2-tuple (A, B) | None
        Coefficients (denominator A first, numerator B second) of an IIR
        filter that is used to shape the spectrum of the signal when
        calculating artifact statistics. The output signal does not go
        through this filter. This is an optional way to tune the sensitivity
        of the algorithm to each frequency component of the signal. The
        default filter is less sensitive at alpha and beta frequencies and
        more sensitive at delta (blinks) and gamma (muscle) frequencies.
    axis : int
        Axis to filter on (default=-1, corresponding to samples).

    Returns
    -------
    out : array
        Filtered data.
    zf : array, shape=(n_channels, filter_order)
        Output filter state.
    """
    if ab is None:
        # default spectral-shaping filter: 8th-order yulewalk design,
        # emphasizing delta/gamma and de-emphasizing alpha/beta bands
        F = np.array([0, 2, 3, 13, 16, 40, np.minimum(
            80.0, (sfreq / 2.0) - 1.0), sfreq / 2.0]) * 2.0 / sfreq
        M = np.array([3, 0.75, 0.33, 0.33, 1, 1, 3, 3])
        B, A = yulewalk(8, F, M)
    else:
        A, B = ab

    # initialize the IIR filter state from the first sample of each channel
    # (original code duplicated the lfilter call in both branches)
    if zi is None:
        zi = signal.lfilter_zi(B, A)
        zi = np.transpose(X[:, 0] * zi[:, None])
    out, zf = signal.lfilter(B, A, X, zi=zi, axis=axis)

    return out, zf
def geometric_median(X, tol=1e-5, max_iter=500):
    """Geometric median.

    This code is adapted from [2]_ using the Vardi and Zhang algorithm
    described in [1]_.

    Parameters
    ----------
    X : array, shape=(n_observations, n_variables)
        The data.
    tol : float
        Tolerance (default=1.e-5).
    max_iter : int
        Max number of iterations (default=500).

    Returns
    -------
    y1 : array, shape=(n_variables,)
        Geometric median over X. If the iteration does not converge within
        ``max_iter`` steps, the last iterate is returned (with a warning
        printed).

    References
    ----------
    .. [1] Vardi, Y., & Zhang, C. H. (2000). The multivariate L1-median and
       associated data depth. Proceedings of the National Academy of Sciences,
       97(4), 1423-1426. https://doi.org/10.1073/pnas.97.4.1423
    .. [2] https://stackoverflow.com/questions/30299267/
    """
    y = np.mean(X, 0)  # initial value: component-wise mean

    for _ in range(max_iter):
        D = cdist(X, [y])
        nonzeros = (D != 0)[:, 0]  # points not coincident with current y
        Dinv = 1. / D[nonzeros]
        Dinvs = np.sum(Dinv)
        W = Dinv / Dinvs
        T = np.sum(W * X[nonzeros], 0)  # weighted mean of non-coincident points
        num_zeros = len(X) - np.sum(nonzeros)
        if num_zeros == 0:
            y1 = T
        elif num_zeros == len(X):
            # every point coincides with y: y is the median
            return y
        else:
            # Vardi-Zhang correction for data points coinciding with y
            R = (T - y) * Dinvs
            r = np.linalg.norm(R)
            rinv = 0 if r == 0 else num_zeros / r
            y1 = max(0, 1 - rinv) * T + min(1, rinv) * y
        if euclidean(y, y1) < tol:
            return y1
        y = y1

    # BUG FIX: original printed "could converge" and fell off the end,
    # returning None; return the last iterate instead.
    print(f"Geometric median could not converge in {max_iter} iterations "
          f"with a tolerance of {tol}")
    return y
def polystab(a):
    """Polynomial stabilization.

    POLYSTAB(A), where A is a vector of polynomial coefficients,
    stabilizes the polynomial with respect to the unit circle;
    roots whose magnitudes are greater than one are reflected
    inside the unit circle.

    Parameters
    ----------
    a : array
        Polynomial coefficients, highest power first.

    Returns
    -------
    b : array
        Stabilized polynomial coefficients; real if the input coefficients
        have no imaginary part.

    Examples
    --------
    Convert a linear-phase filter into a minimum-phase filter with the same
    magnitude response.

    >>> h = fir1(25,0.4); # Window-based FIR filter design
    >>> flag_linphase = islinphase(h) # Determines if filter is linear phase
    >>> hmin = polystab(h) * norm(h)/norm(polystab(h));
    >>> flag_minphase = isminphase(hmin)# Determines if filter is minimum phase
    """
    v = np.roots(a)
    i = np.where(v != 0)
    # indicator: 1 where |root| > 1 (outside unit circle), else 0
    vs = 0.5 * (np.sign(np.abs(v[i]) - 1) + 1)
    # reflect outside roots to 1/conj(root); inside roots unchanged
    v[i] = (1 - vs) * v[i] + vs / np.conj(v[i])
    # rescale by the leading nonzero coefficient of the input
    ind = np.where(a != 0)
    b = a[ind[0][0]] * np.poly(v)

    # Return only real coefficients if input was real.
    # BUG FIX: the original tested `np.sum(np.imag(a))`, which wrongly
    # real-casts complex inputs whose imaginary parts cancel to zero
    # (MATLAB source uses `~any(imag(a))`).
    if not np.any(np.imag(a)):
        b = np.real(b)

    return b
def numf(h, a, nb):
    """Find numerator B given impulse-response h of B/A and denominator A.

    NB is the numerator order. This function is used by YULEWALK.
    """
    n_taps = h.size
    # impulse response of the all-pole filter 1/A, same length as h
    impulse = np.concatenate(([1.0], np.zeros(n_taps - 1)))
    impr = signal.lfilter(np.array([1.0]), a, impulse)
    # least-squares fit of b (length nb+1) so that conv(b, impr) ~ h
    first_row = np.concatenate(([1.0], np.zeros(nb)))
    solution = linalg.lstsq(toeplitz(impr, first_row), h.T, rcond=None)[0]
    # ensure a 1D array is returned
    return np.atleast_1d(solution.T.squeeze())
def denf(R, na):
    """Compute denominator from covariances.

    A = DENF(R,NA) computes order NA denominator A from covariances
    R(0)...R(nr) using the Modified Yule-Walker method. This function is used
    by YULEWALK.
    """
    n_cov = np.max(np.size(R))
    # modified Yule-Walker system: lhs @ a_tail ~ -R[na+1:]
    lhs = toeplitz(R[na:n_cov - 1], R[na:0:-1])
    rhs = -R[na + 1:n_cov]
    a_tail = linalg.lstsq(lhs, rhs.T, rcond=None)[0].T
    # denominator is monic: prepend the leading 1
    return np.concatenate((1, a_tail), axis=None)