|
| 1 | +import warnings |
| 2 | +from itertools import repeat, zip_longest |
| 3 | +from typing import Sequence |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from scipy.fftpack import fft, fftfreq, ifft, ifftshift |
| 7 | +from scipy.optimize import curve_fit |
| 8 | +from scipy.signal import fftconvolve, lfilter, lfiltic, tf2zpk, zpk2sos, zpk2tf |
| 9 | + |
| 10 | + |
| 11 | +def shift(signal: np.ndarray, delay: float, dt: float) -> np.ndarray: |
| 12 | + """ |
| 13 | + delay a signal |
| 14 | +
|
| 15 | + Args: |
| 16 | + signal (np.ndarray): input signal |
| 17 | + delay (float): delayed time |
| 18 | + dt (float): time step of signal samples |
| 19 | +
|
| 20 | + Returns: |
| 21 | + np.ndarray: delayed signal |
| 22 | + """ |
| 23 | + points = int(delay // dt) |
| 24 | + delta = delay / dt - points |
| 25 | + |
| 26 | + if delta > 0: |
| 27 | + ker = np.array([0, 1 - delta, delta]) |
| 28 | + signal = np.convolve(signal, ker, mode='same') |
| 29 | + |
| 30 | + if points == 0: |
| 31 | + return signal |
| 32 | + |
| 33 | + ret = np.zeros_like(signal) |
| 34 | + if points < 0: |
| 35 | + ret[:points] = signal[-points:] |
| 36 | + else: |
| 37 | + ret[points:] = signal[:-points] |
| 38 | + return ret |
| 39 | + |
| 40 | + |
| 41 | +def extractKernel(sig_in, sig_out, sample_rate, bw=None, skip=0): |
| 42 | + corr = fft(sig_in) / fft(sig_out) |
| 43 | + ker = np.real(ifftshift(ifft(corr))) |
| 44 | + if bw is not None and bw < 0.5 * sample_rate: |
| 45 | + k = np.exp(-0.5 * np.linspace(-3.0, 3.0, int(2 * sample_rate / bw))**2) |
| 46 | + ker = np.convolve(ker, k / k.sum(), mode='same') |
| 47 | + return ker[int(skip):len(ker) - int(skip)] |
| 48 | + |
| 49 | + |
| 50 | +def zDistortKernel(dt: float, params: Sequence[tuple]) -> np.ndarray: |
| 51 | + t = 3 * np.asarray(params)[:, 0].max() |
| 52 | + omega = 2 * np.pi * fftfreq(int(t / dt) + 1, dt) |
| 53 | + |
| 54 | + H = 1 |
| 55 | + for tau, A in params: |
| 56 | + H += (1j * A * omega * tau) / (1j * omega * tau + 1) |
| 57 | + |
| 58 | + ker = ifftshift(ifft(1 / H)).real |
| 59 | + return ker |
| 60 | + |
| 61 | + |
| 62 | +def high_pass_filter(tau, sample_rate): |
| 63 | + """ |
| 64 | + high pass filter |
| 65 | + """ |
| 66 | + k = 2.0 * tau * sample_rate |
| 67 | + a = [1.0, (1 - k) / (1 + k)] |
| 68 | + b = [k / (1 + k), -k / (1 + k)] |
| 69 | + return b, a |
| 70 | + |
| 71 | + |
| 72 | +def exp_decay_filter_old(amp, tau, sample_rate): |
| 73 | + """ |
| 74 | + exp decay filter |
| 75 | +
|
| 76 | + A |
| 77 | + H(w) = -------------------- |
| 78 | + 1 - 1j / (w * tau) |
| 79 | +
|
| 80 | + Args: |
| 81 | + amp (float): amplitude of the filter |
| 82 | + tau (float): decay time |
| 83 | + sample_rate (float): sampling rate |
| 84 | + """ |
| 85 | + |
| 86 | + alpha = 1 - np.exp(-1 / (abs(sample_rate * tau) * (1 + amp))) |
| 87 | + |
| 88 | + if amp >= 0: |
| 89 | + k = amp / (1 + amp - alpha) |
| 90 | + a = [(1 - k + k * alpha), -(1 - k) * (1 - alpha)] |
| 91 | + else: |
| 92 | + k = -amp / (1 + amp) / (1 - alpha) |
| 93 | + a = [(1 + k - k * alpha), -(1 + k) * (1 - alpha)] |
| 94 | + |
| 95 | + b = [1 / a[0], -(1 - alpha) / a[0]] |
| 96 | + a = [1, a[1] / a[0]] |
| 97 | + |
| 98 | + return b, a |
| 99 | + |
| 100 | + |
| 101 | +def exp_decay_filter(amp: float | Sequence[float], |
| 102 | + tau: float | Sequence[float], |
| 103 | + sample_rate: float, |
| 104 | + inv: bool = False, |
| 105 | + output='ba') -> tuple[np.ndarray, np.ndarray]: |
| 106 | + """ |
| 107 | + exp decay filter |
| 108 | +
|
| 109 | + Infinite impulse response as multiexponential decay. When input signal |
| 110 | + is the Heaviside theta function u(t), the output signal is: |
| 111 | + out(t) = u(t) * (1 - A_1 * exp(-t / tau_1) - A_2 * exp(-t / tau_2) ...) |
| 112 | + where A_i and tau_i are the amplitude and decay time of the i-th |
| 113 | + exponential decay. |
| 114 | +
|
| 115 | + The transfer function of the filter is: |
| 116 | +
|
| 117 | + H(w) = 1 - H_1(w) - H_2(w) - ... - H_n(w) |
| 118 | +
|
| 119 | + where |
| 120 | + A_i |
| 121 | + H_i(w) = -------------------------- |
| 122 | + 1 - 1 / (1j * w * tau_i) |
| 123 | +
|
| 124 | + Args: |
| 125 | + amp (float): amplitude of the filter |
| 126 | + tau (float): decay time |
| 127 | + sample_rate (float): sampling rate |
| 128 | + inv (bool): if True, the filter is inverted |
| 129 | + output (str): output type, 'ba' for numerator (b) and denominator (a) |
| 130 | + polynomials, 'sos' for second-order sections, 'zpk' for zeros (z), |
| 131 | + poles (p) and gain (k). See scipy.signal.lfilter for more. |
| 132 | +
|
| 133 | + Returns: |
| 134 | + tuple: (b, a) array like, numerator (b) and denominator (a) |
| 135 | + polynomials of the IIR filter. See scipy.signal.lfilter for more. |
| 136 | + """ |
| 137 | + |
| 138 | + if isinstance(amp, (int, float, complex)): |
| 139 | + amp = [amp] |
| 140 | + tau = [tau] |
| 141 | + numerator, denominator = np.poly1d([0.0]), np.poly1d([1.0]) |
| 142 | + for i, (A, t) in enumerate(zip(amp, tau)): |
| 143 | + denominator = denominator * np.poly1d([1, -1 / t]) |
| 144 | + n = np.poly1d([-A, 0.0]) |
| 145 | + for j, t_ in enumerate(tau): |
| 146 | + if j != i: |
| 147 | + n = n * np.poly1d([1, -1 / t_]) |
| 148 | + numerator = numerator + n |
| 149 | + numerator = numerator + denominator |
| 150 | + |
| 151 | + z = np.exp(-numerator.roots / sample_rate) |
| 152 | + p = np.exp(-denominator.roots / sample_rate) |
| 153 | + if inv: |
| 154 | + z, p = p, z |
| 155 | + k = numerator(0) / denominator(0) * np.prod(1 - p) / np.prod(1 - z) |
| 156 | + |
| 157 | + if output == 'sos': |
| 158 | + return zpk2sos(z, p, k) |
| 159 | + elif output == 'ba': |
| 160 | + return zpk2tf(z, p, k) |
| 161 | + elif output == 'zpk': |
| 162 | + return z, p, k |
| 163 | + |
| 164 | + |
| 165 | +def reflection_filter(f, A, tau): |
| 166 | + """ |
| 167 | + reflection filter |
| 168 | +
|
| 169 | + Infinite impulse response as reflection. When input signal |
| 170 | + is in(t), the output signal is: |
| 171 | + out(t) = in(t) + A * in(t - tau) + A^2 * in(t - 2 * tau) + ... |
| 172 | +
|
| 173 | + The transfer function of the filter is: |
| 174 | + 1 - A |
| 175 | + H(w) = ---------------------------- |
| 176 | + 1 - A * exp(- i * w * tau) |
| 177 | + Args: |
| 178 | + f (float): frequency |
| 179 | + A (float): amplitude of the reflection |
| 180 | + tau (float): delay time |
| 181 | + """ |
| 182 | + return (1 - A) / (1 - A * np.exp(-2j * np.pi * f * tau)) |
| 183 | + |
| 184 | + |
| 185 | +def reflection(sig, A, tau, sample_rate): |
| 186 | + freq = np.fft.fftfreq(len(sig), 1 / sample_rate) |
| 187 | + return np.fft.ifft(np.fft.fft(sig) * reflection_filter(freq, A, tau)).real |
| 188 | + |
| 189 | + |
| 190 | +def correct_reflection(sig, A, tau, sample_rate=None): |
| 191 | + from waveforms.waveform import Waveform |
| 192 | + |
| 193 | + if isinstance(sig, Waveform): |
| 194 | + return 1 / (1 - A) * sig - A / (1 - A) * (sig >> tau) |
| 195 | + if sample_rate is not None: |
| 196 | + freq = np.fft.fftfreq(len(sig), 1 / sample_rate) |
| 197 | + return np.fft.ifft(np.fft.fft(sig) / |
| 198 | + reflection_filter(freq, A, tau)).real |
| 199 | + else: |
| 200 | + raise ValueError('sample_rate is not given') |
| 201 | + |
| 202 | + |
| 203 | +def combine_filters( |
| 204 | + filters: list[tuple[np.ndarray, |
| 205 | + np.ndarray]]) -> tuple[np.ndarray, np.ndarray]: |
| 206 | + """ |
| 207 | + combine filters |
| 208 | +
|
| 209 | + Args: |
| 210 | + filters (list): list of (b, a) array like, numerator (b) and denominator |
| 211 | + (a) polynomials of the IIR filter. See scipy.signal.lfilter for more. |
| 212 | +
|
| 213 | + Returns: |
| 214 | + tuple: (b, a) array like, numerator (b) and denominator (a) |
| 215 | + polynomials of the combined filter. See scipy.signal.lfilter for more. |
| 216 | + """ |
| 217 | + b, a = np.poly1d([1.0]), np.poly1d([1.0]) |
| 218 | + for b_, a_ in filters: |
| 219 | + b = b * np.poly1d(b_) |
| 220 | + a = a * np.poly1d(a_) |
| 221 | + return b.coeffs, a.coeffs |
| 222 | + |
| 223 | + |
| 224 | +def factor_filter(b, a): |
| 225 | + """ |
| 226 | + factor filter |
| 227 | +
|
| 228 | + Args: |
| 229 | + b (array_like): numerator polynomial of the IIR filter. |
| 230 | + a (array_like): denominator polynomial of the IIR filter. |
| 231 | +
|
| 232 | + Returns: |
| 233 | + list: list of (b, a) array like, numerator (b) and denominator |
| 234 | + """ |
| 235 | + b, a = np.poly1d(b), np.poly1d(a) |
| 236 | + p = a.roots |
| 237 | + q = b.roots |
| 238 | + b_amp = (b[0] / a[0])**(1 / max(len(q), len(p))) |
| 239 | + filters = [] |
| 240 | + for a_, b_ in zip_longest(p, q, fillvalue=0): |
| 241 | + filters.append(([b_amp, -b_amp * b_], [1, -a_])) |
| 242 | + return filters |
| 243 | + |
| 244 | + |
| 245 | +def stable_filter(exp_decay_filters: list, sample_rate: float): |
| 246 | + """ |
| 247 | + check if the filter is stable |
| 248 | +
|
| 249 | + Args: |
| 250 | + exp_decay_filters (list): list of (amp, tau) pairs |
| 251 | + """ |
| 252 | + filters = [] |
| 253 | + for amp, tau in exp_decay_filters: |
| 254 | + a, b = exp_decay_filter(amp, tau, sample_rate) |
| 255 | + filters.append((b, a)) |
| 256 | + |
| 257 | + b, a = combine_filters(filters) |
| 258 | + z, p, k = tf2zpk(b, a) |
| 259 | + if np.all(np.abs(p) < 1): |
| 260 | + return True |
| 261 | + else: |
| 262 | + return False |
| 263 | + |
| 264 | + |
| 265 | +def predistort(sig: np.ndarray, |
| 266 | + filters: list = None, |
| 267 | + ker: np.ndarray = None, |
| 268 | + initial: float = 0.0, |
| 269 | + initial_x: np.ndarray | None = None, |
| 270 | + initial_y: np.ndarray | None = None, |
| 271 | + zi: np.ndarray | None = None, |
| 272 | + return_zf: bool = False) -> np.ndarray: |
| 273 | + if filters is not None: |
| 274 | + b, a = combine_filters(filters) |
| 275 | + z, p, k = tf2zpk(b, a) |
| 276 | + if np.all(np.abs(p) < 1): |
| 277 | + pass |
| 278 | + else: |
| 279 | + warnings.warn('Warning: filter is unstable') |
| 280 | + |
| 281 | + if zi is None: |
| 282 | + if initial_x is None: |
| 283 | + initial_x = np.full((len(b) - 1, ), initial) |
| 284 | + else: |
| 285 | + initial_x = np.asarray(initial_x)[:len(b) - 1] |
| 286 | + if initial_y is None: |
| 287 | + initial_y = np.full((len(a) - 1, ), initial) |
| 288 | + else: |
| 289 | + initial_y = np.asarray(initial_y)[:len(a) - 1] |
| 290 | + zi = lfiltic( |
| 291 | + b, |
| 292 | + a, |
| 293 | + initial_y, |
| 294 | + initial_x, |
| 295 | + ) |
| 296 | + sig, zf = lfilter(b, a, sig, zi=zi) |
| 297 | + |
| 298 | + if ker is None: |
| 299 | + if return_zf: |
| 300 | + return sig, zf |
| 301 | + else: |
| 302 | + return sig |
| 303 | + |
| 304 | + size = len(sig) |
| 305 | + sig = np.hstack((np.zeros_like(sig), sig, np.zeros_like(sig))) |
| 306 | + start = size + len(ker) // 2 |
| 307 | + stop = start + size |
| 308 | + points = fftconvolve(sig, ker, mode='full')[start:stop] |
| 309 | + if return_zf: |
| 310 | + return points, zf |
| 311 | + else: |
| 312 | + return points |
| 313 | + |
| 314 | + |
| 315 | +def distort(points, params, sample_rate, initial=0.0): |
| 316 | + filters = [] |
| 317 | + for amp, tau in np.asarray(params).reshape(-1, 2): |
| 318 | + b, a = exp_decay_filter(amp, abs(tau), sample_rate) |
| 319 | + filters.append((b, a)) |
| 320 | + return predistort(points, filters, initial=initial) |
| 321 | + |
| 322 | + |
| 323 | +def phase_curve(t, params, df_dphi, pulse_width, start, wav, sample_rate): |
| 324 | + lim = max(np.max(np.abs(t)), 20e-6) |
| 325 | + num = round(2 * lim * sample_rate) |
| 326 | + tlist = np.arange(num) / sample_rate - lim |
| 327 | + points = wav(tlist) |
| 328 | + |
| 329 | + pulse_points = round(pulse_width * sample_rate) |
| 330 | + start_points = round((start + pulse_width) * sample_rate) - 1 |
| 331 | + |
| 332 | + ker = np.hstack( |
| 333 | + [np.ones(pulse_points) / sample_rate, |
| 334 | + np.zeros(start_points)]) |
| 335 | + |
| 336 | + points = np.convolve(2 * np.pi * df_dphi * |
| 337 | + distort(points, params, sample_rate), |
| 338 | + ker, |
| 339 | + mode='same') |
| 340 | + return np.interp(t, tlist, points) |
| 341 | + |
| 342 | + |
| 343 | +if __name__ == '__main__': |
| 344 | + import matplotlib.pyplot as plt |
| 345 | + from waveforms import square |
| 346 | + |
| 347 | + data = np.load('Z_distortion.npz') |
| 348 | + |
| 349 | + x = data['time'] * 1e-6 |
| 350 | + y = data['phase'] |
| 351 | + df_dphi = 4343.313e6 |
| 352 | + |
| 353 | + sample_rate = 2e9 |
| 354 | + wav = 0.1 * (square(2e-6) << 1e-6) |
| 355 | + |
| 356 | + def f(t, *params): |
| 357 | + return phase_curve(t, params, df_dphi, 10e-9, 25e-9) |
| 358 | + |
| 359 | + params = [-0.03, 0.1e-6, 0.02, 0.3e-6] |
| 360 | + popt, pcov = curve_fit(f, x, y, p0=params) |
| 361 | + |
| 362 | + plt.plot(x / 1e-6, y, 'o') |
| 363 | + plt.semilogx( |
| 364 | + x / 1e-6, |
| 365 | + phase_curve(x, |
| 366 | + params, |
| 367 | + df_dphi, |
| 368 | + 10e-9, |
| 369 | + 0, |
| 370 | + wav=wav, |
| 371 | + sample_rate=sample_rate)) |
| 372 | + plt.plot(x / 1e-6, f(x, *popt)) |
| 373 | + |
| 374 | + plt.xlabel('delay [us]') |
| 375 | + plt.ylabel('phase') |
| 376 | + plt.show() |
0 commit comments