forked from farr/LineCleaner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathline_cleaner.py
More file actions
109 lines (82 loc) · 3.92 KB
/
line_cleaner.py
File metadata and controls
109 lines (82 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import scipy.signal as sig
def line_model_freq(freqs, d_re, d_im, f_line_guess, f_prior_width, gamma_min, gamma_max, estimate_line=False):
    """NumPyro model of a single Lorentzian line plus white jitter in the frequency domain.

    Sampled parameters:
      * ``f0``     -- line frequency; normal prior centred on ``f_line_guess``
                      with scale ``f_prior_width``, truncated at +/- 5 widths.
      * ``gamma``  -- line width parameter; uniform on ``[gamma_min, gamma_max]``.
      * ``sigma``  -- line amplitude scale; ``Exponential(1)`` prior.
      * ``jitter`` -- standard deviation of the frequency-independent noise
                      term; ``Exponential(1)`` prior.

    Recorded deterministic sites: ``tau`` (``1/gamma``) and ``lor_var`` (the
    Lorentzian variance evaluated at every entry of ``freqs``).

    The real and imaginary parts of the data (``d_re``, ``d_im``) are each
    modelled as independent zero-mean normals whose variance is
    ``lor_var + jitter**2``.

    When ``estimate_line`` is true, two additional sites ``line_re`` and
    ``line_im`` are sampled.  Each is normal with mean equal to the data
    scaled by ``lor_var / (lor_var + jitter**2)`` and variance
    ``1 / (1/jitter**2 + 1/lor_var)``.

    The function returns ``None``; its results are the NumPyro sites.
    """
    # NOTE: the order of the numpyro.sample calls is kept fixed on purpose,
    # since the seeding handler hands out RNG keys in site order.
    prior_half_span = 5*f_prior_width
    f0_prior = dist.TruncatedNormal(
        f_line_guess,
        f_prior_width,
        low=f_line_guess - prior_half_span,
        high=f_line_guess + prior_half_span,
    )
    f0 = numpyro.sample('f0', f0_prior)
    gamma = numpyro.sample('gamma', dist.Uniform(gamma_min, gamma_max))
    sigma = numpyro.sample('sigma', dist.Exponential(1))
    jitter = numpyro.sample('jitter', dist.Exponential(1))

    numpyro.deterministic('tau', 1/gamma)

    # Lorentzian variance profile, written in angular frequency.
    numerator = 2.0*jnp.pi*jnp.square(f0)*gamma*jnp.square(sigma)
    omega = 2*jnp.pi*freqs
    omega0 = 2*jnp.pi*f0
    detuning_sq = jnp.square(jnp.square(omega0) - jnp.square(omega))
    damping_sq = 4*jnp.square(omega0)*jnp.square(gamma)
    lor_var = numpyro.deterministic('lor_var', numerator / (detuning_sq + damping_sq))

    # Line and jitter contributions add in variance.
    jitter_var = jnp.square(jitter)
    total_sd = jnp.sqrt(lor_var + jitter_var)
    for site, observed in (('obs_re', d_re), ('obs_im', d_im)):
        numpyro.sample(site, dist.Normal(0, total_sd), obs=observed)

    if not estimate_line:
        return

    # Conditional distribution of the line component given the data.
    line_weight = 1 / (1 + jitter_var/lor_var)
    line_sd = jnp.sqrt(1 / (1/jitter_var + 1/lor_var))
    for site, observed in (('line_re', d_re), ('line_im', d_im)):
        numpyro.sample(site, dist.Normal(observed*line_weight, line_sd))
def clean_strain(times, data, srate, f0s, bandwidths, Twindow, mcmc_seed=None, resample_seed=None, return_mcmcs=False):
    """Subtract fitted narrow lines from a time series.

    The data are Tukey-windowed and Fourier transformed.  For every
    ``(f0, bandwidth)`` pair, the bins with ``|f - f0| < bandwidth/2`` are
    fitted with ``line_model_freq`` using NUTS (4 chains, 1000 warmup and
    1000 kept samples each).  One random posterior-predictive draw of the
    line is subtracted from those bins, and the result is transformed back
    to the time domain.

    Parameters
    ----------
    times : array
        Sample times.  Must support boolean-mask indexing; ``times[-1] -
        times[0]`` sets the lower bound of the ``gamma`` prior.
    data : array
        Samples, same length as ``times``.
    srate : float
        Sample rate used to normalise the forward and inverse transforms.
    f0s, bandwidths : iterables
        Guessed line frequencies and the width of the band fitted around
        each.  They are zipped, so the shorter one determines how many lines
        are processed.  The ``f0`` prior width is ``bandwidth/10`` and the
        upper bound of the ``gamma`` prior is ``f0/2``.
    Twindow : float
        Sets the Tukey window shape via ``alpha = Twindow*srate/len(data)``,
        i.e. ``Twindow*srate`` is the number of samples inside the tapered
        region.
    mcmc_seed, resample_seed : int, optional
        Seeds for the MCMC and for the predictive draw / draw selection.
        Random seeds are generated when omitted.
    return_mcmcs : bool
        Also return diagnostics (see Returns).

    Returns
    -------
    (times_residual, data_residual)
        Cleaned samples and their times, restricted to where the window is
        exactly 1 (the untapered part).
    If ``return_mcmcs`` is true, the tuple is extended with
    ``mcmcs, pred_sampless, line_model_im, line_model_re, sel, data_freq_re``:
    the per-line MCMC objects and predictive samples, followed by the
    subtracted line model and bin selection of the *last* line processed
    (``None`` if no lines were given) and the real part of the windowed
    data spectrum.
    """
    if mcmc_seed is None:
        mcmc_seed = np.random.randint(1<<32)
    if resample_seed is None:
        resample_seed = np.random.randint(1<<32)

    mcmc_rng_key = random.PRNGKey(mcmc_seed)
    resample_rng_key = random.PRNGKey(resample_seed)

    n_data = len(data)
    window = sig.windows.tukey(n_data, alpha=Twindow*srate/n_data)

    data_freq = np.fft.rfft(data*window)/srate
    data_freq_residual = np.copy(data_freq)
    data_freq_re = np.real(data_freq)
    data_freq_im = np.imag(data_freq)
    freqs = np.fft.rfftfreq(n_data, 1/srate)

    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Lower bound on the line width; the same for every line.
    gamma_min = 1/(times[-1]-times[0])

    mcmcs = []
    pred_sampless = []
    # Pre-bind the "last line" diagnostics so that the return_mcmcs return
    # is well defined even when f0s/bandwidths are empty.
    line_model_re = line_model_im = sel = None

    for f0, bandwidth in zip(f0s, bandwidths):
        sel = np.abs(freqs - f0) < bandwidth/2

        # Normalise the band so that the Exponential(1) priors on the
        # amplitude parameters are on a sensible scale.
        scale_factor = np.sqrt(integrate(np.square(np.abs(data_freq[sel])), freqs[sel]))
        model_args = (
            freqs[sel],
            data_freq_re[sel]/scale_factor,
            data_freq_im[sel]/scale_factor,
            f0,
            bandwidth/10,
            gamma_min,
            f0/2,
        )

        mcmc = MCMC(
            NUTS(line_model_freq, dense_mass=True),
            num_warmup=1000,
            num_samples=1000,
            num_chains=4,
        )
        mcmc_rng_key, mk = random.split(mcmc_rng_key)
        mcmc.run(mk, *model_args)

        resample_rng_key, rk = random.split(resample_rng_key)
        pred = Predictive(line_model_freq, posterior_samples=mcmc.get_samples())
        pred_samples = pred(rk, *model_args, estimate_line=True)

        # Subtract a single, randomly chosen predictive draw of the line.
        resample_rng_key, rk = random.split(resample_rng_key)
        ind = random.randint(rk, (1,), 0, pred_samples['line_re'].shape[0])[0]

        line_model_re = scale_factor*pred_samples['line_re'][ind,:]
        line_model_im = scale_factor*pred_samples['line_im'][ind,:]

        data_freq_residual[sel] = data_freq_re[sel] - line_model_re + 1j*(data_freq_im[sel] - line_model_im)

        if return_mcmcs:
            mcmcs.append(mcmc)
            pred_sampless.append(pred_samples)

    # Invert the forward transform exactly: pass n so odd-length inputs come
    # back at full length, and multiply by srate to undo the earlier /srate.
    data_residual = np.fft.irfft(data_freq_residual, n=n_data)*srate

    # Keep only the samples untouched by the window taper.
    untapered = window == 1
    data_residual = data_residual[untapered]
    times_residual = times[untapered]

    if return_mcmcs:
        return (times_residual, data_residual, mcmcs, pred_sampless, line_model_im, line_model_re, sel, data_freq_re)
    else:
        return (times_residual, data_residual)