From b538ba4d979bd63143e4e06d8f5baf42e146e510 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Sun, 1 Mar 2026 22:42:35 -0800 Subject: [PATCH 1/3] fix: remove CPython internals from PriorityThreadPoolExecutor (#4) Replace the _WorkItem/_work_queue subclass hack with a dispatcher-wrapper pattern: a PriorityQueue feeds work in priority order to a standard ThreadPoolExecutor via a daemon thread. This eliminates all CPython internal dependencies and fixes compatibility with Python 3.14+. Co-Authored-By: Claude Opus 4.6 --- data_diff/thread_utils.py | 61 ++++++++++------ tests/test_thread_utils.py | 146 +++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 22 deletions(-) create mode 100644 tests/test_thread_utils.py diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 8eb2187e..cbc7252b 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -1,42 +1,59 @@ import itertools +import threading from collections import deque from collections.abc import Callable, Iterable, Iterator -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures.thread import _WorkItem +from concurrent.futures import Future, ThreadPoolExecutor from queue import PriorityQueue from time import sleep from typing import Any import attrs +_SENTINEL = object() -class AutoPriorityQueue(PriorityQueue): - """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs - We also assign a unique id for each item, to avoid making comparisons on _WorkItem. - As a side effect, items with the same priority are returned FIFO. - """ +def _chain_future(source: Future, dest: Future) -> None: + """Propagate the result or exception from source to dest.""" + if source.cancelled(): + dest.cancel() + elif exc := source.exception(): + dest.set_exception(exc) + else: + dest.set_result(source.result()) - _counter = itertools.count().__next__ - def put(self, item: _WorkItem | None, block=True, timeout=None) -> None: - priority = item.kwargs.pop("priority") if item is not None else 0 - super().put((-priority, self._counter(), item), block, timeout) +class PriorityThreadPoolExecutor: + """Thread pool that executes tasks in priority order. - def get(self, block=True, timeout=None) -> _WorkItem | None: - _p, _c, work_item = super().get(block, timeout) - return work_item + Uses a dispatcher thread to pull work from a PriorityQueue and + submit it to a standard ThreadPoolExecutor. No CPython internals. + """ + def __init__(self, max_workers: int | None = None) -> None: + self._inner = ThreadPoolExecutor(max_workers=max_workers) + self._queue: PriorityQueue = PriorityQueue() + self._counter = itertools.count().__next__ + self._dispatcher = threading.Thread(target=self._dispatch, daemon=True) + self._dispatcher.start() -class PriorityThreadPoolExecutor(ThreadPoolExecutor): - """Overrides ThreadPoolExecutor to use AutoPriorityQueue + def _dispatch(self) -> None: + while True: + _priority, _count, item = self._queue.get() + if item is _SENTINEL: + break + fn, args, kwargs, proxy = item + inner_future = self._inner.submit(fn, *args, **kwargs) + inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p)) - XXX WARNING: Might break in future versions of Python - """ + def submit(self, fn, /, *args, priority: int = 0, **kwargs) -> Future: + proxy = Future() + self._queue.put((-priority, self._counter(), (fn, args, kwargs, proxy))) + return proxy - def __init__(self, *args) -> None: - super().__init__(*args) - self._work_queue = AutoPriorityQueue() + def shutdown(self, wait: bool = True) -> None: + self._queue.put((0, self._counter(), _SENTINEL)) + self._dispatcher.join() + self._inner.shutdown(wait=wait) @attrs.define(frozen=False, init=False) @@ -47,7 +64,7 @@ class ThreadedYielder(Iterable): Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first) """ - _pool: ThreadPoolExecutor + _pool: PriorityThreadPoolExecutor _futures: deque _yield: deque = attrs.field(alias="_yield") # Python keyword! _exception: None = None diff --git a/tests/test_thread_utils.py b/tests/test_thread_utils.py new file mode 100644 index 00000000..ceb0807a --- /dev/null +++ b/tests/test_thread_utils.py @@ -0,0 +1,146 @@ +import threading + +import pytest + +from data_diff.thread_utils import PriorityThreadPoolExecutor, ThreadedYielder + + +class TestPriorityThreadPoolExecutor: + def test_priority_ordering(self): + """Higher-priority tasks execute before lower-priority ones.""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + + # Block the single worker so tasks queue up + pool.submit(lambda: gate.wait(), priority=0) + + # Submit tasks with different priorities while worker is blocked + for p in [1, 3, 2]: + pool.submit(lambda p=p: results.append(p), priority=p) + + # Release the gate — queued tasks run in priority order + gate.set() + pool.shutdown(wait=True) + + assert results == [3, 2, 1] + + def test_fifo_within_same_priority(self): + """Equal-priority tasks run in submission order (FIFO).""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.submit(lambda: gate.wait(), priority=0) + + for i in range(5): + pool.submit(lambda i=i: results.append(i), priority=1) + + gate.set() + pool.shutdown(wait=True) + + assert results == [0, 1, 2, 3, 4] + + def test_submit_returns_future_with_result(self): + """submit() returns a Future that resolves to the function's return value.""" + pool = PriorityThreadPoolExecutor(max_workers=2) + future = pool.submit(lambda: 42) + assert future.result(timeout=5) == 42 + pool.shutdown() + + def test_submit_returns_future_with_exception(self): + """Exceptions in submitted functions propagate through the Future.""" + pool = PriorityThreadPoolExecutor(max_workers=2) + future = pool.submit(lambda: 1 / 0) + with pytest.raises(ZeroDivisionError): + future.result(timeout=5) + pool.shutdown() + + def test_concurrent_submit(self): + """Submitting from multiple threads is safe.""" + pool = PriorityThreadPoolExecutor(max_workers=4) + results = [] + lock = threading.Lock() + + def task(n): + with lock: + results.append(n) + + threads = [] + for i in range(20): + t = threading.Thread(target=lambda i=i: pool.submit(task, i, priority=0)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + pool.shutdown(wait=True) + assert sorted(results) == list(range(20)) + + def test_shutdown_with_pending_work(self): + """Shutdown completes all pending work before returning.""" + results = [] + pool = PriorityThreadPoolExecutor(max_workers=1) + + for i in range(10): + pool.submit(lambda i=i: results.append(i), priority=0) + + pool.shutdown(wait=True) + assert sorted(results) == list(range(10)) + + def test_no_cpython_internals_imported(self): + """Verify _WorkItem is not imported.""" + import data_diff.thread_utils as mod + + assert not hasattr(mod, "_WorkItem") + + +class TestThreadedYielder: + def test_basic_yield(self): + """ThreadedYielder collects results from submitted functions.""" + ty = ThreadedYielder(max_workers=2) + ty.submit(lambda: [1, 2, 3]) + ty.submit(lambda: [4, 5, 6]) + + result = list(ty) + assert sorted(result) == [1, 2, 3, 4, 5, 6] + + def test_priority_behavior(self): + """Higher-priority iterators get scheduled first.""" + gate = threading.Event() + ty = ThreadedYielder(max_workers=1) + + # Block the worker + def wait_gate(): + gate.wait() + return [] + + ty.submit(wait_gate, priority=0) + + # Queue tasks with different priorities + ty.submit(lambda: ["low"], priority=1) + ty.submit(lambda: ["high"], priority=3) + ty.submit(lambda: ["mid"], priority=2) + + gate.set() + result = list(ty) + # High-priority tasks should execute first + assert result == ["high", "mid", "low"] + + def test_yield_list_mode(self): + """yield_list=True appends entire results rather than extending.""" + ty = ThreadedYielder(max_workers=1, yield_list=True) + ty.submit(lambda: [1, 2, 3]) + + result = list(ty) + assert result == [[1, 2, 3]] + + def test_exception_propagation(self): + """Exceptions in submitted functions propagate through iteration.""" + ty = ThreadedYielder(max_workers=1) + ty.submit(lambda: (_ for _ in ()).throw(ValueError("boom"))) + + with pytest.raises(ValueError, match="boom"): + list(ty) From 11a9c52df2919a2991ca4e065640613429df17dd Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Sun, 1 Mar 2026 23:21:49 -0800 Subject: [PATCH 2/3] fix: harden PriorityThreadPoolExecutor error handling - Guard _chain_future against cancelled dest and internal exceptions - Add try/except in dispatcher to propagate errors to proxy futures - Reject submit() after shutdown() with RuntimeError - Use float('inf') sentinel priority to never preempt queued work - Add 30s timeout to dispatcher join to prevent deadlock on crash - Add tests for all new error paths Co-Authored-By: Claude Opus 4.6 --- data_diff/thread_utils.py | 51 +++++++++++++++++++++--------- tests/test_thread_utils.py | 63 +++++++++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 16 deletions(-) diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index cbc7252b..3aed62d5 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -13,13 +13,21 @@ def _chain_future(source: Future, dest: Future) -> None: - """Propagate the result or exception from source to dest.""" - if source.cancelled(): - dest.cancel() - elif exc := source.exception(): - dest.set_exception(exc) - else: - dest.set_result(source.result()) + """Propagate the outcome (result, exception, or cancellation) from source to dest.""" + if dest.cancelled(): + return + try: + if source.cancelled(): + dest.cancel() + elif exc := source.exception(): + dest.set_exception(exc) + else: + dest.set_result(source.result()) + except Exception as exc: + try: + dest.set_exception(exc) + except Exception: + pass class PriorityThreadPoolExecutor: @@ -33,26 +41,39 @@ def __init__(self, max_workers: int | None = None) -> None: self._inner = ThreadPoolExecutor(max_workers=max_workers) self._queue: PriorityQueue = PriorityQueue() self._counter = itertools.count().__next__ + self._shutdown = False self._dispatcher = threading.Thread(target=self._dispatch, daemon=True) self._dispatcher.start() def _dispatch(self) -> None: while True: - _priority, _count, item = self._queue.get() - if item is _SENTINEL: - break - fn, args, kwargs, proxy = item - inner_future = self._inner.submit(fn, *args, **kwargs) - inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p)) + try: + _priority, _count, item = self._queue.get() + if item is _SENTINEL: + break + fn, args, kwargs, proxy = item + inner_future = self._inner.submit(fn, *args, **kwargs) + inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p)) + except Exception as exc: + if "proxy" in dir() and not proxy.done(): + try: + proxy.set_exception(exc) + except Exception: + pass def submit(self, fn, /, *args, priority: int = 0, **kwargs) -> Future: + if self._shutdown: + raise RuntimeError("cannot submit after shutdown") proxy = Future() self._queue.put((-priority, self._counter(), (fn, args, kwargs, proxy))) return proxy def shutdown(self, wait: bool = True) -> None: - self._queue.put((0, self._counter(), _SENTINEL)) - self._dispatcher.join() + self._shutdown = True + self._queue.put((float("inf"), self._counter(), _SENTINEL)) + self._dispatcher.join(timeout=30) + if self._dispatcher.is_alive(): + raise RuntimeError("PriorityThreadPoolExecutor dispatcher did not shut down within 30s") self._inner.shutdown(wait=wait) diff --git a/tests/test_thread_utils.py b/tests/test_thread_utils.py index ceb0807a..91d72a16 100644 --- a/tests/test_thread_utils.py +++ b/tests/test_thread_utils.py @@ -1,8 +1,13 @@ import threading +from concurrent.futures import Future import pytest -from data_diff.thread_utils import PriorityThreadPoolExecutor, ThreadedYielder +from data_diff.thread_utils import ( + PriorityThreadPoolExecutor, + ThreadedYielder, + _chain_future, +) class TestPriorityThreadPoolExecutor: @@ -96,6 +101,62 @@ def test_no_cpython_internals_imported(self): assert not hasattr(mod, "_WorkItem") + def test_submit_forwards_args_and_kwargs(self): + """submit() correctly forwards positional and keyword arguments.""" + pool = PriorityThreadPoolExecutor(max_workers=1) + future = pool.submit(lambda a, b, c=None: (a, b, c), 1, 2, c=3) + assert future.result(timeout=5) == (1, 2, 3) + pool.shutdown() + + def test_submit_after_shutdown_raises(self): + """submit() raises RuntimeError after shutdown() is called.""" + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.shutdown() + with pytest.raises(RuntimeError, match="cannot submit after shutdown"): + pool.submit(lambda: None) + + def test_shutdown_drains_high_priority_work(self): + """Sentinel does not preempt queued higher-priority work.""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.submit(lambda: gate.wait(), priority=0) + + for i in range(5): + pool.submit(lambda i=i: results.append(i), priority=10) + + gate.set() + pool.shutdown(wait=True) + assert sorted(results) == list(range(5)) + + +class TestChainFuture: + def test_propagates_result(self): + """Chains result from source to dest.""" + source = Future() + dest = Future() + source.set_result(42) + _chain_future(source, dest) + assert dest.result() == 42 + + def test_propagates_exception(self): + """Chains exception from source to dest.""" + source = Future() + dest = Future() + source.set_exception(ValueError("oops")) + _chain_future(source, dest) + with pytest.raises(ValueError, match="oops"): + dest.result() + + def test_skips_cancelled_dest(self): + """Does not raise if dest was already cancelled.""" + source = Future() + dest = Future() + dest.cancel() + source.set_result(42) + _chain_future(source, dest) # should not raise + class TestThreadedYielder: def test_basic_yield(self): From 170846cf02c6cc9d1bedac25cc686b8ca625d2d6 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Sun, 1 Mar 2026 23:32:06 -0800 Subject: [PATCH 3/3] fix: replace fragile dir() check with explicit proxy init in dispatcher Initialize proxy = None before try block and check `is not None` instead of using `"proxy" in dir()` which doesn't reliably reflect local variables and retains stale references across loop iterations. Co-Authored-By: Claude Opus 4.6 --- data_diff/thread_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 3aed62d5..61e8ab2c 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -47,6 +47,7 @@ def __init__(self, max_workers: int | None = None) -> None: def _dispatch(self) -> None: while True: + proxy = None try: _priority, _count, item = self._queue.get() if item is _SENTINEL: @@ -55,7 +56,7 @@ def _dispatch(self) -> None: inner_future = self._inner.submit(fn, *args, **kwargs) inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p)) except Exception as exc: - if "proxy" in dir() and not proxy.done(): + if proxy is not None and not proxy.done(): try: proxy.set_exception(exc) except Exception: