Skip to content

Commit f1008dd

Browse files
committed
perf: run tracking cross-correlation in parallel threads
1 parent 9952e24 commit f1008dd

1 file changed

Lines changed: 61 additions & 40 deletions

File tree

src/pybox/gui/tracking_plots.py

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class _TrackingResult:
2727
time: np.ndarray
2828
gyros: list[np.ndarray] # [roll, pitch, yaw]
2929
setpoints: list[np.ndarray] # [roll, pitch, yaw]
30-
delays: list[float] # ms per axis
30+
delays: list[float | None] # ms per axis
3131

3232

3333
AXIS_NAMES = ["Roll", "Pitch", "Yaw"]
@@ -37,43 +37,28 @@ class _TrackingWorker(QThread):
3737
"""Compute cross-correlation delays and downsample in a background thread."""
3838
finished = Signal(object) # _TrackingResult or None
3939

40-
def __init__(self, entry: LogEntry, parent=None):
40+
def __init__(self, entry: LogEntry, axis: int, setpoints: np.ndarray, gyros: np.ndarray, parent=None):
4141
super().__init__(parent)
4242
self._entry = entry
43+
self._axis = axis
44+
self._setpoints = setpoints
45+
self._gyros = gyros
4346

4447
def run(self):
4548
entry = self._entry
46-
time_s, gyro_r, gyro_p, gyro_y = entry.gyro_arrays()
47-
sp_r, sp_p, sp_y = entry.setpoint_arrays()
49+
time_s, _, _, _ = entry.gyro_arrays()
4850

4951
if len(time_s) == 0:
50-
self.finished.emit(None)
52+
self.finished.emit([self._axis, None])
5153
return
5254

5355
mask = entry.time_mask()
5456
t_full = time_s[mask]
55-
gyros_full = [gyro_r[mask], gyro_p[mask], gyro_y[mask]]
56-
setpoints_full = [sp_r[mask], sp_p[mask], sp_y[mask]]
5757

5858
dt = np.median(np.diff(t_full)) if len(t_full) > 2 else 0.001
59-
delays = []
60-
for axis in range(3):
61-
delays.append(TrackingPlots._xcorr_delay(setpoints_full[axis], gyros_full[axis], dt))
62-
63-
# Downsample for plotting
64-
t = t_full
65-
gyros = gyros_full
66-
setpoints = setpoints_full
67-
max_pts = 20000
68-
if len(t) > max_pts:
69-
step = len(t) // max_pts
70-
t = t[::step]
71-
gyros = [g[::step] for g in gyros]
72-
setpoints = [s[::step] for s in setpoints]
59+
delays = TrackingPlots._xcorr_delay(self._setpoints, self._gyros, dt)
7360

74-
self.finished.emit(_TrackingResult(
75-
time=t, gyros=gyros, setpoints=setpoints, delays=delays,
76-
))
61+
self.finished.emit([self._axis, delays])
7762

7863

7964
class TrackingPlots(QWidget):
@@ -112,7 +97,8 @@ def __init__(self, parent=None):
11297
self._curves: dict[tuple[str, int], pg.PlotDataItem] = {}
11398
# key: ("gyro"|"sp", axis)
11499
self._cache_key: tuple | None = None
115-
self._worker: _TrackingWorker | None = None
100+
self._workers: list[_TrackingWorker | None] | None = []
101+
self._results: _TrackingResult | None = None
116102

117103
for i, name in enumerate(AXIS_NAMES):
118104
if i > 0:
@@ -133,18 +119,42 @@ def __init__(self, parent=None):
133119
for i in range(1, 3):
134120
self._plots[i].setXLink(self._plots[0])
135121

122+
def _prepare_data(self, entry: LogEntry) -> tuple[list[np.ndarray], list[np.ndarray]]:
123+
mask = entry.time_mask()
124+
time_s, gyro_r, gyro_p, gyro_y = entry.gyro_arrays()
125+
sp_r, sp_p, sp_y = entry.setpoint_arrays()
126+
t_full = time_s[mask]
127+
gyros_full = [gyro_r[mask], gyro_p[mask], gyro_y[mask]]
128+
setpoints_full = [sp_r[mask], sp_p[mask], sp_y[mask]]
129+
130+
# Downsample for plotting
131+
t = t_full
132+
gyros = gyros_full
133+
setpoints = setpoints_full
134+
max_pts = 20000
135+
if len(t) > max_pts:
136+
step = len(t) // max_pts
137+
t = t[::step]
138+
gyros = [g[::step] for g in gyros]
139+
setpoints = [s[::step] for s in setpoints]
140+
141+
self._results = _TrackingResult(time=t, gyros=gyros, setpoints=setpoints, delays=[None, None, None])
142+
143+
return gyros_full, setpoints_full
144+
136145
@staticmethod
137146
def _entry_key(entry: LogEntry) -> tuple:
138147
return (entry.file_path, entry.log_index, entry.time_start_s, entry.time_end_s)
139148

140149
def update_plots(self, entry: LogEntry | None):
141150
"""Display setpoint/gyro tracking for the given (selected) log entry."""
142151
# Cancel any running worker
143-
if self._worker is not None and self._worker.isRunning():
144-
self._worker.finished.disconnect()
145-
self._worker.quit()
146-
self._worker.wait(500)
147-
self._worker = None
152+
for worker in self._workers:
153+
if worker is not None and worker.isRunning():
154+
worker.finished.disconnect()
155+
worker.quit()
156+
worker.wait(500)
157+
self._workers = []
148158

149159
if entry is not None:
150160
key = self._entry_key(entry)
@@ -159,11 +169,14 @@ def update_plots(self, entry: LogEntry | None):
159169
if entry is None:
160170
return
161171

172+
gyros_full, setpoints_full = self._prepare_data(entry)
173+
162174
# Show loading indicator and start background computation
163175
self._info_label.setText(self.tr("Computing tracking delays..."))
164-
self._worker = _TrackingWorker(entry, parent=self)
165-
self._worker.finished.connect(self._on_worker_finished)
166-
self._worker.start()
176+
for i in range(3):
177+
self._workers.append(_TrackingWorker(entry, i, setpoints_full[i], gyros_full[i], parent=self))
178+
self._workers[i].finished.connect(self._on_worker_finished)
179+
self._workers[i].start()
167180

168181
@staticmethod
169182
def _xcorr_delay(setpoint: np.ndarray, gyro: np.ndarray, dt: float) -> float:
@@ -183,28 +196,35 @@ def _xcorr_delay(setpoint: np.ndarray, gyro: np.ndarray, dt: float) -> float:
183196
lag = int(np.argmax(search))
184197
return lag * dt * 1000.0 # ms
185198

186-
def _on_worker_finished(self, result: _TrackingResult | None):
199+
def _on_worker_finished(self, result: tuple[int, list[float] | None]):
187200
"""Called on main thread when background computation is done."""
188-
self._worker = None
189-
if result is None:
201+
axis = result[0]
202+
self._workers[axis] = None
203+
if result[1] is None:
190204
self._info_label.setText("")
191205
return
192-
self._draw_result(result)
206+
self._results.delays[axis] = result[1]
207+
self._draw_result(self._results)
193208

194209
def _draw_result(self, r: _TrackingResult):
195210
"""Draw pre-computed tracking data (runs on main thread)."""
196211
axis_colors = ["#e6194b", "#3cb44b", "#4363d8"] # R, G, B
197212
t_theme = current_theme()
198213

199214
for axis in range(3):
200-
if len(r.gyros[axis]) < 10:
215+
# If this axis is already contained in self._curves, skip it
216+
if self._curves.get(("gyro", axis)) is not None:
217+
continue
218+
219+
if len(r.gyros[axis]) < 10 or r.delays[axis] is None:
201220
continue
202221

203222
color = axis_colors[axis]
204223

205224
# Update title with delay
225+
time_str = f"{r.delays[axis]:.1f} ms" if r.delays[axis] is not None else "N/A"
206226
self._plots[axis].setTitle(
207-
f"{AXIS_NAMES[axis]} – delay: {r.delays[axis]:.1f} ms",
227+
f"{AXIS_NAMES[axis]} – delay: {time_str}",
208228
color=t_theme.plot_title_color, size="12pt",
209229
)
210230

@@ -223,8 +243,9 @@ def _draw_result(self, r: _TrackingResult):
223243
self._curves[("gyro", axis)] = gyro_curve
224244

225245
# Update info label
246+
time_str = [f"{r.delays[axis]:.1f} ms" if r.delays [axis] is not None else "N/A" for axis in range(3)]
226247
self._info_label.setText(
227-
f"Delay – Roll: {r.delays[0]:.1f} ms | Pitch: {r.delays[1]:.1f} ms | Yaw: {r.delays[2]:.1f} ms"
248+
f"Delay – {AXIS_NAMES[0]}: {time_str[0]} | {AXIS_NAMES[1]}: {time_str[1]} | {AXIS_NAMES[2]}: {time_str[2]}"
228249
)
229250

230251
# Auto-range

0 commit comments

Comments
 (0)