Skip to content

Commit ef6e347

Browse files
committed
add distortion
1 parent 79532d6 commit ef6e347

2 files changed

Lines changed: 377 additions & 1 deletion

File tree

waveforms/distortion.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
import warnings
2+
from itertools import repeat, zip_longest
3+
from typing import Sequence
4+
5+
import numpy as np
6+
from scipy.fftpack import fft, fftfreq, ifft, ifftshift
7+
from scipy.optimize import curve_fit
8+
from scipy.signal import fftconvolve, lfilter, lfiltic, tf2zpk, zpk2sos, zpk2tf
9+
10+
11+
def shift(signal: np.ndarray, delay: float, dt: float) -> np.ndarray:
12+
"""
13+
delay a signal
14+
15+
Args:
16+
signal (np.ndarray): input signal
17+
delay (float): delayed time
18+
dt (float): time step of signal samples
19+
20+
Returns:
21+
np.ndarray: delayed signal
22+
"""
23+
points = int(delay // dt)
24+
delta = delay / dt - points
25+
26+
if delta > 0:
27+
ker = np.array([0, 1 - delta, delta])
28+
signal = np.convolve(signal, ker, mode='same')
29+
30+
if points == 0:
31+
return signal
32+
33+
ret = np.zeros_like(signal)
34+
if points < 0:
35+
ret[:points] = signal[-points:]
36+
else:
37+
ret[points:] = signal[:-points]
38+
return ret
39+
40+
41+
def extractKernel(sig_in, sig_out, sample_rate, bw=None, skip=0):
42+
corr = fft(sig_in) / fft(sig_out)
43+
ker = np.real(ifftshift(ifft(corr)))
44+
if bw is not None and bw < 0.5 * sample_rate:
45+
k = np.exp(-0.5 * np.linspace(-3.0, 3.0, int(2 * sample_rate / bw))**2)
46+
ker = np.convolve(ker, k / k.sum(), mode='same')
47+
return ker[int(skip):len(ker) - int(skip)]
48+
49+
50+
def zDistortKernel(dt: float, params: Sequence[tuple]) -> np.ndarray:
51+
t = 3 * np.asarray(params)[:, 0].max()
52+
omega = 2 * np.pi * fftfreq(int(t / dt) + 1, dt)
53+
54+
H = 1
55+
for tau, A in params:
56+
H += (1j * A * omega * tau) / (1j * omega * tau + 1)
57+
58+
ker = ifftshift(ifft(1 / H)).real
59+
return ker
60+
61+
62+
def high_pass_filter(tau, sample_rate):
63+
"""
64+
high pass filter
65+
"""
66+
k = 2.0 * tau * sample_rate
67+
a = [1.0, (1 - k) / (1 + k)]
68+
b = [k / (1 + k), -k / (1 + k)]
69+
return b, a
70+
71+
72+
def exp_decay_filter_old(amp, tau, sample_rate):
73+
"""
74+
exp decay filter
75+
76+
A
77+
H(w) = --------------------
78+
1 - 1j / (w * tau)
79+
80+
Args:
81+
amp (float): amplitude of the filter
82+
tau (float): decay time
83+
sample_rate (float): sampling rate
84+
"""
85+
86+
alpha = 1 - np.exp(-1 / (abs(sample_rate * tau) * (1 + amp)))
87+
88+
if amp >= 0:
89+
k = amp / (1 + amp - alpha)
90+
a = [(1 - k + k * alpha), -(1 - k) * (1 - alpha)]
91+
else:
92+
k = -amp / (1 + amp) / (1 - alpha)
93+
a = [(1 + k - k * alpha), -(1 + k) * (1 - alpha)]
94+
95+
b = [1 / a[0], -(1 - alpha) / a[0]]
96+
a = [1, a[1] / a[0]]
97+
98+
return b, a
99+
100+
101+
def exp_decay_filter(amp: float | Sequence[float],
102+
tau: float | Sequence[float],
103+
sample_rate: float,
104+
inv: bool = False,
105+
output='ba') -> tuple[np.ndarray, np.ndarray]:
106+
"""
107+
exp decay filter
108+
109+
Infinite impulse response as multiexponential decay. When input signal
110+
is the Heaviside theta function u(t), the output signal is:
111+
out(t) = u(t) * (1 - A_1 * exp(-t / tau_1) - A_2 * exp(-t / tau_2) ...)
112+
where A_i and tau_i are the amplitude and decay time of the i-th
113+
exponential decay.
114+
115+
The transfer function of the filter is:
116+
117+
H(w) = 1 - H_1(w) - H_2(w) - ... - H_n(w)
118+
119+
where
120+
A_i
121+
H_i(w) = --------------------------
122+
1 - 1 / (1j * w * tau_i)
123+
124+
Args:
125+
amp (float): amplitude of the filter
126+
tau (float): decay time
127+
sample_rate (float): sampling rate
128+
inv (bool): if True, the filter is inverted
129+
output (str): output type, 'ba' for numerator (b) and denominator (a)
130+
polynomials, 'sos' for second-order sections, 'zpk' for zeros (z),
131+
poles (p) and gain (k). See scipy.signal.lfilter for more.
132+
133+
Returns:
134+
tuple: (b, a) array like, numerator (b) and denominator (a)
135+
polynomials of the IIR filter. See scipy.signal.lfilter for more.
136+
"""
137+
138+
if isinstance(amp, (int, float, complex)):
139+
amp = [amp]
140+
tau = [tau]
141+
numerator, denominator = np.poly1d([0.0]), np.poly1d([1.0])
142+
for i, (A, t) in enumerate(zip(amp, tau)):
143+
denominator = denominator * np.poly1d([1, -1 / t])
144+
n = np.poly1d([-A, 0.0])
145+
for j, t_ in enumerate(tau):
146+
if j != i:
147+
n = n * np.poly1d([1, -1 / t_])
148+
numerator = numerator + n
149+
numerator = numerator + denominator
150+
151+
z = np.exp(-numerator.roots / sample_rate)
152+
p = np.exp(-denominator.roots / sample_rate)
153+
if inv:
154+
z, p = p, z
155+
k = numerator(0) / denominator(0) * np.prod(1 - p) / np.prod(1 - z)
156+
157+
if output == 'sos':
158+
return zpk2sos(z, p, k)
159+
elif output == 'ba':
160+
return zpk2tf(z, p, k)
161+
elif output == 'zpk':
162+
return z, p, k
163+
164+
165+
def reflection_filter(f, A, tau):
166+
"""
167+
reflection filter
168+
169+
Infinite impulse response as reflection. When input signal
170+
is in(t), the output signal is:
171+
out(t) = in(t) + A * in(t - tau) + A^2 * in(t - 2 * tau) + ...
172+
173+
The transfer function of the filter is:
174+
1 - A
175+
H(w) = ----------------------------
176+
1 - A * exp(- i * w * tau)
177+
Args:
178+
f (float): frequency
179+
A (float): amplitude of the reflection
180+
tau (float): delay time
181+
"""
182+
return (1 - A) / (1 - A * np.exp(-2j * np.pi * f * tau))
183+
184+
185+
def reflection(sig, A, tau, sample_rate):
186+
freq = np.fft.fftfreq(len(sig), 1 / sample_rate)
187+
return np.fft.ifft(np.fft.fft(sig) * reflection_filter(freq, A, tau)).real
188+
189+
190+
def correct_reflection(sig, A, tau, sample_rate=None):
191+
from waveforms.waveform import Waveform
192+
193+
if isinstance(sig, Waveform):
194+
return 1 / (1 - A) * sig - A / (1 - A) * (sig >> tau)
195+
if sample_rate is not None:
196+
freq = np.fft.fftfreq(len(sig), 1 / sample_rate)
197+
return np.fft.ifft(np.fft.fft(sig) /
198+
reflection_filter(freq, A, tau)).real
199+
else:
200+
raise ValueError('sample_rate is not given')
201+
202+
203+
def combine_filters(
204+
filters: list[tuple[np.ndarray,
205+
np.ndarray]]) -> tuple[np.ndarray, np.ndarray]:
206+
"""
207+
combine filters
208+
209+
Args:
210+
filters (list): list of (b, a) array like, numerator (b) and denominator
211+
(a) polynomials of the IIR filter. See scipy.signal.lfilter for more.
212+
213+
Returns:
214+
tuple: (b, a) array like, numerator (b) and denominator (a)
215+
polynomials of the combined filter. See scipy.signal.lfilter for more.
216+
"""
217+
b, a = np.poly1d([1.0]), np.poly1d([1.0])
218+
for b_, a_ in filters:
219+
b = b * np.poly1d(b_)
220+
a = a * np.poly1d(a_)
221+
return b.coeffs, a.coeffs
222+
223+
224+
def factor_filter(b, a):
225+
"""
226+
factor filter
227+
228+
Args:
229+
b (array_like): numerator polynomial of the IIR filter.
230+
a (array_like): denominator polynomial of the IIR filter.
231+
232+
Returns:
233+
list: list of (b, a) array like, numerator (b) and denominator
234+
"""
235+
b, a = np.poly1d(b), np.poly1d(a)
236+
p = a.roots
237+
q = b.roots
238+
b_amp = (b[0] / a[0])**(1 / max(len(q), len(p)))
239+
filters = []
240+
for a_, b_ in zip_longest(p, q, fillvalue=0):
241+
filters.append(([b_amp, -b_amp * b_], [1, -a_]))
242+
return filters
243+
244+
245+
def stable_filter(exp_decay_filters: list, sample_rate: float):
246+
"""
247+
check if the filter is stable
248+
249+
Args:
250+
exp_decay_filters (list): list of (amp, tau) pairs
251+
"""
252+
filters = []
253+
for amp, tau in exp_decay_filters:
254+
a, b = exp_decay_filter(amp, tau, sample_rate)
255+
filters.append((b, a))
256+
257+
b, a = combine_filters(filters)
258+
z, p, k = tf2zpk(b, a)
259+
if np.all(np.abs(p) < 1):
260+
return True
261+
else:
262+
return False
263+
264+
265+
def predistort(sig: np.ndarray,
266+
filters: list = None,
267+
ker: np.ndarray = None,
268+
initial: float = 0.0,
269+
initial_x: np.ndarray | None = None,
270+
initial_y: np.ndarray | None = None,
271+
zi: np.ndarray | None = None,
272+
return_zf: bool = False) -> np.ndarray:
273+
if filters is not None:
274+
b, a = combine_filters(filters)
275+
z, p, k = tf2zpk(b, a)
276+
if np.all(np.abs(p) < 1):
277+
pass
278+
else:
279+
warnings.warn('Warning: filter is unstable')
280+
281+
if zi is None:
282+
if initial_x is None:
283+
initial_x = np.full((len(b) - 1, ), initial)
284+
else:
285+
initial_x = np.asarray(initial_x)[:len(b) - 1]
286+
if initial_y is None:
287+
initial_y = np.full((len(a) - 1, ), initial)
288+
else:
289+
initial_y = np.asarray(initial_y)[:len(a) - 1]
290+
zi = lfiltic(
291+
b,
292+
a,
293+
initial_y,
294+
initial_x,
295+
)
296+
sig, zf = lfilter(b, a, sig, zi=zi)
297+
298+
if ker is None:
299+
if return_zf:
300+
return sig, zf
301+
else:
302+
return sig
303+
304+
size = len(sig)
305+
sig = np.hstack((np.zeros_like(sig), sig, np.zeros_like(sig)))
306+
start = size + len(ker) // 2
307+
stop = start + size
308+
points = fftconvolve(sig, ker, mode='full')[start:stop]
309+
if return_zf:
310+
return points, zf
311+
else:
312+
return points
313+
314+
315+
def distort(points, params, sample_rate, initial=0.0):
316+
filters = []
317+
for amp, tau in np.asarray(params).reshape(-1, 2):
318+
b, a = exp_decay_filter(amp, abs(tau), sample_rate)
319+
filters.append((b, a))
320+
return predistort(points, filters, initial=initial)
321+
322+
323+
def phase_curve(t, params, df_dphi, pulse_width, start, wav, sample_rate):
324+
lim = max(np.max(np.abs(t)), 20e-6)
325+
num = round(2 * lim * sample_rate)
326+
tlist = np.arange(num) / sample_rate - lim
327+
points = wav(tlist)
328+
329+
pulse_points = round(pulse_width * sample_rate)
330+
start_points = round((start + pulse_width) * sample_rate) - 1
331+
332+
ker = np.hstack(
333+
[np.ones(pulse_points) / sample_rate,
334+
np.zeros(start_points)])
335+
336+
points = np.convolve(2 * np.pi * df_dphi *
337+
distort(points, params, sample_rate),
338+
ker,
339+
mode='same')
340+
return np.interp(t, tlist, points)
341+
342+
343+
if __name__ == '__main__':
344+
import matplotlib.pyplot as plt
345+
from waveforms import square
346+
347+
data = np.load('Z_distortion.npz')
348+
349+
x = data['time'] * 1e-6
350+
y = data['phase']
351+
df_dphi = 4343.313e6
352+
353+
sample_rate = 2e9
354+
wav = 0.1 * (square(2e-6) << 1e-6)
355+
356+
def f(t, *params):
357+
return phase_curve(t, params, df_dphi, 10e-9, 25e-9)
358+
359+
params = [-0.03, 0.1e-6, 0.02, 0.3e-6]
360+
popt, pcov = curve_fit(f, x, y, p0=params)
361+
362+
plt.plot(x / 1e-6, y, 'o')
363+
plt.semilogx(
364+
x / 1e-6,
365+
phase_curve(x,
366+
params,
367+
df_dphi,
368+
10e-9,
369+
0,
370+
wav=wav,
371+
sample_rate=sample_rate))
372+
plt.plot(x / 1e-6, f(x, *popt))
373+
374+
plt.xlabel('delay [us]')
375+
plt.ylabel('phase')
376+
plt.show()

waveforms/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Define version number here and read it from setup.py automatically"""
2-
__version__ = "2.0.1"
2+
__version__ = "2.0.2"

0 commit comments

Comments
 (0)