Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 65 additions & 26 deletions data_diff/thread_utils.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,81 @@
import itertools
import os
import threading
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures.thread import _WorkItem
from queue import PriorityQueue
from time import sleep
from typing import Any

import attrs


class AutoPriorityQueue(PriorityQueue):
"""Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs

We also assign a unique id for each item, to avoid making comparisons on _WorkItem.
As a side effect, items with the same priority are returned FIFO.
"""

_counter = itertools.count().__next__

def put(self, item: _WorkItem | None, block=True, timeout=None) -> None:
priority = item.kwargs.pop("priority") if item is not None else 0
super().put((-priority, self._counter(), item), block, timeout)

def get(self, block=True, timeout=None) -> _WorkItem | None:
_p, _c, work_item = super().get(block, timeout)
return work_item
_SENTINEL = object()


def _chain_future(source: Future, dest: Future) -> None:
"""Propagate the outcome (result, exception, or cancellation) from source to dest."""
if dest.cancelled():
return
try:
if source.cancelled():
dest.cancel()
elif exc := source.exception():
dest.set_exception(exc)
else:
dest.set_result(source.result())
except Exception as exc:
try:
dest.set_exception(exc)
except Exception:
pass


class PriorityThreadPoolExecutor(ThreadPoolExecutor):
"""Overrides ThreadPoolExecutor to use AutoPriorityQueue
class PriorityThreadPoolExecutor:
"""Thread pool that executes tasks in priority order.

XXX WARNING: Might break in future versions of Python
Uses a dispatcher thread to pull work from a PriorityQueue and
submit it to a standard ThreadPoolExecutor. No CPython internals.
"""

def __init__(self, *args) -> None:
super().__init__(*args)
self._work_queue = AutoPriorityQueue()
def __init__(self, max_workers: int | None = None) -> None:
self._inner = ThreadPoolExecutor(max_workers=max_workers)
self._queue: PriorityQueue = PriorityQueue()
self._counter = itertools.count().__next__
self._shutdown = False
self._dispatcher = threading.Thread(target=self._dispatch, daemon=True)
self._dispatcher.start()

def _dispatch(self) -> None:
while True:
proxy = None
try:
_priority, _count, item = self._queue.get()
if item is _SENTINEL:
break
fn, args, kwargs, proxy = item
inner_future = self._inner.submit(fn, *args, **kwargs)
inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p))
except Exception as exc:
if proxy is not None and not proxy.done():
try:
proxy.set_exception(exc)
except Exception:
pass

def submit(self, fn, /, *args, priority: int = 0, **kwargs) -> Future:
if self._shutdown:
raise RuntimeError("cannot submit after shutdown")
proxy = Future()
self._queue.put((-priority, self._counter(), (fn, args, kwargs, proxy)))
return proxy
Comment on lines +65 to +70

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Reject submit calls after shutdown

After shutdown() stops the dispatcher thread, submit() still enqueues work unconditionally and returns a fresh Future. Because no thread consumes _queue anymore, that future can remain pending forever and callers waiting on result() will hang. The previous ThreadPoolExecutor-based behavior raised on post-shutdown submissions, so this is a regression for any late producer thread.

Useful? React with 👍 / 👎.


def shutdown(self, wait: bool = True) -> None:
self._shutdown = True
self._queue.put((float("inf"), self._counter(), _SENTINEL))
self._dispatcher.join(timeout=30)
if self._dispatcher.is_alive():
raise RuntimeError("PriorityThreadPoolExecutor dispatcher did not shut down within 30s")
self._inner.shutdown(wait=wait)


@attrs.define(frozen=False, init=False)
Expand All @@ -47,7 +86,7 @@ class ThreadedYielder(Iterable):
Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)
"""

_pool: ThreadPoolExecutor
_pool: PriorityThreadPoolExecutor
_futures: deque
_yield: deque = attrs.field(alias="_yield") # Python keyword!
_exception: None = None
Expand Down
207 changes: 207 additions & 0 deletions tests/test_thread_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import threading
from concurrent.futures import Future

import pytest

from data_diff.thread_utils import (
PriorityThreadPoolExecutor,
ThreadedYielder,
_chain_future,
)


class TestPriorityThreadPoolExecutor:
    """Behavioral tests for PriorityThreadPoolExecutor: ordering, future
    semantics, concurrency safety, and shutdown."""

    def test_priority_ordering(self):
        """Higher-priority tasks execute before lower-priority ones."""
        gate = threading.Event()
        results = []

        pool = PriorityThreadPoolExecutor(max_workers=1)

        # Block the single worker so tasks queue up
        pool.submit(lambda: gate.wait(), priority=0)

        # Submit tasks with different priorities while worker is blocked
        # NOTE(review): this assumes the gate task has already reached the
        # worker and that queued tasks are held back until a worker frees;
        # if the dispatcher forwards eagerly into the inner FIFO queue,
        # the observed order is timing-dependent — confirm.
        for p in [1, 3, 2]:
            pool.submit(lambda p=p: results.append(p), priority=p)

        # Release the gate — queued tasks run in priority order
        gate.set()
        pool.shutdown(wait=True)

        assert results == [3, 2, 1]

    def test_fifo_within_same_priority(self):
        """Equal-priority tasks run in submission order (FIFO)."""
        gate = threading.Event()
        results = []

        pool = PriorityThreadPoolExecutor(max_workers=1)
        pool.submit(lambda: gate.wait(), priority=0)

        # All five share priority 1, so only the monotonic submission
        # counter breaks ties — expected order is submission order.
        for i in range(5):
            pool.submit(lambda i=i: results.append(i), priority=1)

        gate.set()
        pool.shutdown(wait=True)

        assert results == [0, 1, 2, 3, 4]

    def test_submit_returns_future_with_result(self):
        """submit() returns a Future that resolves to the function's return value."""
        pool = PriorityThreadPoolExecutor(max_workers=2)
        future = pool.submit(lambda: 42)
        assert future.result(timeout=5) == 42
        pool.shutdown()

    def test_submit_returns_future_with_exception(self):
        """Exceptions in submitted functions propagate through the Future."""
        pool = PriorityThreadPoolExecutor(max_workers=2)
        future = pool.submit(lambda: 1 / 0)
        with pytest.raises(ZeroDivisionError):
            future.result(timeout=5)
        pool.shutdown()

    def test_concurrent_submit(self):
        """Submitting from multiple threads is safe."""
        pool = PriorityThreadPoolExecutor(max_workers=4)
        results = []
        lock = threading.Lock()

        def task(n):
            # Lock guards the shared list; order is irrelevant, only membership.
            with lock:
                results.append(n)

        # Each submission happens on its own thread to exercise the
        # executor's internal synchronization.
        threads = []
        for i in range(20):
            t = threading.Thread(target=lambda i=i: pool.submit(task, i, priority=0))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        pool.shutdown(wait=True)
        assert sorted(results) == list(range(20))

    def test_shutdown_with_pending_work(self):
        """Shutdown completes all pending work before returning."""
        results = []
        pool = PriorityThreadPoolExecutor(max_workers=1)

        for i in range(10):
            pool.submit(lambda i=i: results.append(i), priority=0)

        # shutdown(wait=True) must drain the queue, not drop queued tasks.
        pool.shutdown(wait=True)
        assert sorted(results) == list(range(10))

    def test_no_cpython_internals_imported(self):
        """Verify _WorkItem is not imported."""
        # Guards against regressing to the old implementation that reached
        # into concurrent.futures.thread._WorkItem.
        import data_diff.thread_utils as mod

        assert not hasattr(mod, "_WorkItem")

    def test_submit_forwards_args_and_kwargs(self):
        """submit() correctly forwards positional and keyword arguments."""
        pool = PriorityThreadPoolExecutor(max_workers=1)
        future = pool.submit(lambda a, b, c=None: (a, b, c), 1, 2, c=3)
        assert future.result(timeout=5) == (1, 2, 3)
        pool.shutdown()

    def test_submit_after_shutdown_raises(self):
        """submit() raises RuntimeError after shutdown() is called."""
        pool = PriorityThreadPoolExecutor(max_workers=1)
        pool.shutdown()
        with pytest.raises(RuntimeError, match="cannot submit after shutdown"):
            pool.submit(lambda: None)

    def test_shutdown_drains_high_priority_work(self):
        """Sentinel does not preempt queued higher-priority work."""
        gate = threading.Event()
        results = []

        pool = PriorityThreadPoolExecutor(max_workers=1)
        pool.submit(lambda: gate.wait(), priority=0)

        # Priority-10 tasks must sort before the shutdown sentinel, so all
        # five run even though shutdown() is called while they are queued.
        for i in range(5):
            pool.submit(lambda i=i: results.append(i), priority=10)

        gate.set()
        pool.shutdown(wait=True)
        assert sorted(results) == list(range(5))


class TestChainFuture:
    """Unit tests for the _chain_future outcome-propagation helper."""

    def test_propagates_result(self):
        """Chains result from source to dest."""
        src = Future()
        dst = Future()
        src.set_result(42)
        _chain_future(src, dst)
        assert dst.result() == 42

    def test_propagates_exception(self):
        """Chains exception from source to dest."""
        src = Future()
        dst = Future()
        src.set_exception(ValueError("oops"))
        _chain_future(src, dst)
        with pytest.raises(ValueError, match="oops"):
            dst.result()

    def test_skips_cancelled_dest(self):
        """Does not raise if dest was already cancelled."""
        src = Future()
        dst = Future()
        dst.cancel()
        src.set_result(42)
        # Must be a no-op: a cancelled destination is left untouched.
        _chain_future(src, dst)


class TestThreadedYielder:
    """Integration tests for ThreadedYielder on top of the priority pool."""

    def test_basic_yield(self):
        """ThreadedYielder collects results from submitted functions."""
        ty = ThreadedYielder(max_workers=2)
        ty.submit(lambda: [1, 2, 3])
        ty.submit(lambda: [4, 5, 6])

        # Two workers: completion order is nondeterministic, so only
        # membership is asserted.
        result = list(ty)
        assert sorted(result) == [1, 2, 3, 4, 5, 6]

    def test_priority_behavior(self):
        """Higher-priority iterators get scheduled first."""
        gate = threading.Event()
        ty = ThreadedYielder(max_workers=1)

        # Block the worker
        def wait_gate():
            gate.wait()
            return []

        ty.submit(wait_gate, priority=0)

        # Queue tasks with different priorities
        # NOTE(review): assumes queued tasks are held back until the worker
        # frees and that yielded order mirrors execution order —
        # timing-dependent; confirm against ThreadedYielder's semantics.
        ty.submit(lambda: ["low"], priority=1)
        ty.submit(lambda: ["high"], priority=3)
        ty.submit(lambda: ["mid"], priority=2)

        gate.set()
        result = list(ty)
        # High-priority tasks should execute first
        assert result == ["high", "mid", "low"]

    def test_yield_list_mode(self):
        """yield_list=True appends entire results rather than extending."""
        ty = ThreadedYielder(max_workers=1, yield_list=True)
        ty.submit(lambda: [1, 2, 3])

        # The whole list is one yielded item, not three.
        result = list(ty)
        assert result == [[1, 2, 3]]

    def test_exception_propagation(self):
        """Exceptions in submitted functions propagate through iteration."""
        ty = ThreadedYielder(max_workers=1)
        # Lambda-compatible raise: an empty generator whose .throw() raises
        # ValueError as soon as the submitted callable is invoked.
        ty.submit(lambda: (_ for _ in ()).throw(ValueError("boom")))

        with pytest.raises(ValueError, match="boom"):
            list(ty)
Loading