diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 8eb2187e..61e8ab2c 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -1,42 +1,81 @@ import itertools +import threading from collections import deque from collections.abc import Callable, Iterable, Iterator -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures.thread import _WorkItem +from concurrent.futures import Future, ThreadPoolExecutor from queue import PriorityQueue from time import sleep from typing import Any import attrs - -class AutoPriorityQueue(PriorityQueue): - """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs - - We also assign a unique id for each item, to avoid making comparisons on _WorkItem. - As a side effect, items with the same priority are returned FIFO. - """ - - _counter = itertools.count().__next__ - - def put(self, item: _WorkItem | None, block=True, timeout=None) -> None: - priority = item.kwargs.pop("priority") if item is not None else 0 - super().put((-priority, self._counter(), item), block, timeout) - - def get(self, block=True, timeout=None) -> _WorkItem | None: - _p, _c, work_item = super().get(block, timeout) - return work_item +_SENTINEL = object() + + +def _chain_future(source: Future, dest: Future) -> None: + """Propagate the outcome (result, exception, or cancellation) from source to dest.""" + if dest.cancelled(): + return + try: + if source.cancelled(): + dest.cancel() + elif exc := source.exception(): + dest.set_exception(exc) + else: + dest.set_result(source.result()) + except Exception as exc: + try: + dest.set_exception(exc) + except Exception: + pass -class PriorityThreadPoolExecutor(ThreadPoolExecutor): - """Overrides ThreadPoolExecutor to use AutoPriorityQueue +class PriorityThreadPoolExecutor: + """Thread pool that executes tasks in priority order. - XXX WARNING: Might break in future versions of Python + Uses a dispatcher thread to pull work from a PriorityQueue and + submit it to a standard ThreadPoolExecutor. No CPython internals. """ - def __init__(self, *args) -> None: - super().__init__(*args) - self._work_queue = AutoPriorityQueue() + def __init__(self, max_workers: int | None = None) -> None: + self._inner = ThreadPoolExecutor(max_workers=max_workers) + self._queue: PriorityQueue = PriorityQueue() + self._counter = itertools.count().__next__ + self._shutdown = False + self._dispatcher = threading.Thread(target=self._dispatch, daemon=True) + self._dispatcher.start() + + def _dispatch(self) -> None: + while True: + proxy = None + try: + _priority, _count, item = self._queue.get() + if item is _SENTINEL: + break + fn, args, kwargs, proxy = item + inner_future = self._inner.submit(fn, *args, **kwargs) + inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p)) + except Exception as exc: + if proxy is not None and not proxy.done(): + try: + proxy.set_exception(exc) + except Exception: + pass + + def submit(self, fn, /, *args, priority: int = 0, **kwargs) -> Future: + if self._shutdown: + raise RuntimeError("cannot submit after shutdown") + proxy = Future() + self._queue.put((-priority, self._counter(), (fn, args, kwargs, proxy))) + return proxy + + def shutdown(self, wait: bool = True) -> None: + self._shutdown = True + self._queue.put((float("inf"), self._counter(), _SENTINEL)) + self._dispatcher.join(timeout=30) + if self._dispatcher.is_alive(): + raise RuntimeError("PriorityThreadPoolExecutor dispatcher did not shut down within 30s") + self._inner.shutdown(wait=wait) @attrs.define(frozen=False, init=False) @@ -47,7 +86,7 @@ class ThreadedYielder(Iterable): Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first) """ - _pool: ThreadPoolExecutor + _pool: PriorityThreadPoolExecutor _futures: deque _yield: deque = attrs.field(alias="_yield") # Python keyword! _exception: None = None diff --git a/tests/test_thread_utils.py b/tests/test_thread_utils.py new file mode 100644 index 00000000..91d72a16 --- /dev/null +++ b/tests/test_thread_utils.py @@ -0,0 +1,207 @@ +import threading +from concurrent.futures import Future + +import pytest + +from data_diff.thread_utils import ( + PriorityThreadPoolExecutor, + ThreadedYielder, + _chain_future, +) + + +class TestPriorityThreadPoolExecutor: + def test_priority_ordering(self): + """Higher-priority tasks execute before lower-priority ones.""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + + # Block the single worker so tasks queue up + pool.submit(lambda: gate.wait(), priority=0) + + # Submit tasks with different priorities while worker is blocked + for p in [1, 3, 2]: + pool.submit(lambda p=p: results.append(p), priority=p) + + # Release the gate — queued tasks run in priority order + gate.set() + pool.shutdown(wait=True) + + assert results == [3, 2, 1] + + def test_fifo_within_same_priority(self): + """Equal-priority tasks run in submission order (FIFO).""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.submit(lambda: gate.wait(), priority=0) + + for i in range(5): + pool.submit(lambda i=i: results.append(i), priority=1) + + gate.set() + pool.shutdown(wait=True) + + assert results == [0, 1, 2, 3, 4] + + def test_submit_returns_future_with_result(self): + """submit() returns a Future that resolves to the function's return value.""" + pool = PriorityThreadPoolExecutor(max_workers=2) + future = pool.submit(lambda: 42) + assert future.result(timeout=5) == 42 + pool.shutdown() + + def test_submit_returns_future_with_exception(self): + """Exceptions in submitted functions propagate through the Future.""" + pool = PriorityThreadPoolExecutor(max_workers=2) + future = pool.submit(lambda: 1 / 0) + with pytest.raises(ZeroDivisionError): + future.result(timeout=5) + pool.shutdown() + + def test_concurrent_submit(self): + """Submitting from multiple threads is safe.""" + pool = PriorityThreadPoolExecutor(max_workers=4) + results = [] + lock = threading.Lock() + + def task(n): + with lock: + results.append(n) + + threads = [] + for i in range(20): + t = threading.Thread(target=lambda i=i: pool.submit(task, i, priority=0)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + pool.shutdown(wait=True) + assert sorted(results) == list(range(20)) + + def test_shutdown_with_pending_work(self): + """Shutdown completes all pending work before returning.""" + results = [] + pool = PriorityThreadPoolExecutor(max_workers=1) + + for i in range(10): + pool.submit(lambda i=i: results.append(i), priority=0) + + pool.shutdown(wait=True) + assert sorted(results) == list(range(10)) + + def test_no_cpython_internals_imported(self): + """Verify _WorkItem is not imported.""" + import data_diff.thread_utils as mod + + assert not hasattr(mod, "_WorkItem") + + def test_submit_forwards_args_and_kwargs(self): + """submit() correctly forwards positional and keyword arguments.""" + pool = PriorityThreadPoolExecutor(max_workers=1) + future = pool.submit(lambda a, b, c=None: (a, b, c), 1, 2, c=3) + assert future.result(timeout=5) == (1, 2, 3) + pool.shutdown() + + def test_submit_after_shutdown_raises(self): + """submit() raises RuntimeError after shutdown() is called.""" + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.shutdown() + with pytest.raises(RuntimeError, match="cannot submit after shutdown"): + pool.submit(lambda: None) + + def test_shutdown_drains_high_priority_work(self): + """Sentinel does not preempt queued higher-priority work.""" + gate = threading.Event() + results = [] + + pool = PriorityThreadPoolExecutor(max_workers=1) + pool.submit(lambda: gate.wait(), priority=0) + + for i in range(5): + pool.submit(lambda i=i: results.append(i), priority=10) + + gate.set() + pool.shutdown(wait=True) + assert sorted(results) == list(range(5)) + + +class TestChainFuture: + def test_propagates_result(self): + """Chains result from source to dest.""" + source = Future() + dest = Future() + source.set_result(42) + _chain_future(source, dest) + assert dest.result() == 42 + + def test_propagates_exception(self): + """Chains exception from source to dest.""" + source = Future() + dest = Future() + source.set_exception(ValueError("oops")) + _chain_future(source, dest) + with pytest.raises(ValueError, match="oops"): + dest.result() + + def test_skips_cancelled_dest(self): + """Does not raise if dest was already cancelled.""" + source = Future() + dest = Future() + dest.cancel() + source.set_result(42) + _chain_future(source, dest) # should not raise + + +class TestThreadedYielder: + def test_basic_yield(self): + """ThreadedYielder collects results from submitted functions.""" + ty = ThreadedYielder(max_workers=2) + ty.submit(lambda: [1, 2, 3]) + ty.submit(lambda: [4, 5, 6]) + + result = list(ty) + assert sorted(result) == [1, 2, 3, 4, 5, 6] + + def test_priority_behavior(self): + """Higher-priority iterators get scheduled first.""" + gate = threading.Event() + ty = ThreadedYielder(max_workers=1) + + # Block the worker + def wait_gate(): + gate.wait() + return [] + + ty.submit(wait_gate, priority=0) + + # Queue tasks with different priorities + ty.submit(lambda: ["low"], priority=1) + ty.submit(lambda: ["high"], priority=3) + ty.submit(lambda: ["mid"], priority=2) + + gate.set() + result = list(ty) + # High-priority tasks should execute first + assert result == ["high", "mid", "low"] + + def test_yield_list_mode(self): + """yield_list=True appends entire results rather than extending.""" + ty = ThreadedYielder(max_workers=1, yield_list=True) + ty.submit(lambda: [1, 2, 3]) + + result = list(ty) + assert result == [[1, 2, 3]] + + def test_exception_propagation(self): + """Exceptions in submitted functions propagate through iteration.""" + ty = ThreadedYielder(max_workers=1) + ty.submit(lambda: (_ for _ in ()).throw(ValueError("boom"))) + + with pytest.raises(ValueError, match="boom"): + list(ty)