Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion .github/workflows/pr-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ name: Pull Request validation
on:
- pull_request

permissions:
contents: read

env:
IMAGE_NAME: ophiosdev/ocrpdf

Expand All @@ -17,12 +20,24 @@ jobs:

- name: Set up Python
id: setup-python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
cache: pip

- name: Create project venv and install dependencies
run: |
"${{ steps.setup-python.outputs.python-path }}" -m venv .venv
"$PWD/.venv/bin/python" -m pip install -r requirements.txt -r requirements-dev.txt
echo "$PWD/.venv/bin" >> "$GITHUB_PATH"

- name: Run pre-commit checks
id: pre-commit
uses: cloudposse/github-action-pre-commit@v4.0.0

- name: Run pytest
run: .venv/bin/pytest

- name: Build Docker image if Dockerfile changed
run: |
if git diff --name-only origin/${{ github.base_ref }} | grep -q '^Dockerfile$'; then
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ repos:
- id: ruff-format

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.408
rev: v1.1.409
hooks:
- id: pyright
25 changes: 14 additions & 11 deletions packages/smbmonitor/src/smbmonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _can_bind_args(sig: Signature, arg_count: int) -> bool:
self._enforce_encryption = enforce_encryption
self._server, self._share, self._watch_path = self._parse_unc_path(unc_path)
self._stop_event = asyncio.Event()
self._watcher_task = None
self._consumer_task = None
self._watcher_task: asyncio.Task[None] | None = None
self._consumer_task: asyncio.Task[None] | None = None
self._queue = asyncio.Queue()
self._port = port
self._uuid = uuid4()
Expand Down Expand Up @@ -271,17 +271,20 @@ async def stop(
log.info("Stopping SMB monitoring")
if graceful:
self._stop_event.set()
_, pending = await asyncio.wait(
(self._watcher_task, self._consumer_task), # pyright: ignore[reportArgumentType, reportCallIssue]
timeout=2.0,
return_when=asyncio.ALL_COMPLETED,
)
for task in pending:
task.cancel()
tasks = [self._watcher_task, self._consumer_task]
active_tasks = {t for t in tasks if t is not None}
if active_tasks:
_, pending = await asyncio.wait(
active_tasks,
timeout=2.0,
return_when=asyncio.ALL_COMPLETED,
)
for task in pending:
task.cancel()
else:
_ = self._watcher_task.cancel() if self._consumer_task else None
_ = self._watcher_task.cancel() if self._watcher_task else None
_ = self._consumer_task.cancel() if self._consumer_task else None

def is_running(self) -> bool:
"""Check if the monitoring tasks are currently running."""
return self._watcher_task and not self._watcher_task.done() # pyright: ignore[reportReturnType]
return self._watcher_task is not None and not self._watcher_task.done()
65 changes: 65 additions & 0 deletions packages/smbmonitor/tests/test_smbmonitor_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,68 @@ async def handler(action, server, share, watch_path, filename) -> None:
await asyncio.wait_for(consumer_task, timeout=2.0)

asyncio.run(_run())


def test_stop_graceful_waits_for_active_tasks() -> None:
async def _run() -> None:
monitor = SmbMonitor(_unc_path(), [handler_full])
release = asyncio.Event()

async def _task() -> None:
await release.wait()

monitor._watcher_task = asyncio.create_task(_task())
monitor._consumer_task = asyncio.create_task(_task())

stop_task = asyncio.create_task(monitor.stop(graceful=True))
await asyncio.sleep(0)
assert not stop_task.done()
assert monitor._stop_event.is_set()

release.set()
await asyncio.wait_for(stop_task, timeout=1.0)
assert monitor._watcher_task.done()
assert monitor._consumer_task.done()

asyncio.run(_run())


def test_stop_graceful_cancels_pending_tasks_after_timeout() -> None:
async def _run() -> None:
monitor = SmbMonitor(_unc_path(), [handler_full])
started = asyncio.Event()

async def _never_finishes() -> None:
started.set()
await asyncio.Future()

monitor._watcher_task = asyncio.create_task(_never_finishes())
await asyncio.wait_for(started.wait(), timeout=1.0)

await monitor.stop(graceful=True)

with pytest.raises(asyncio.CancelledError):
await monitor._watcher_task

assert monitor._watcher_task.cancelled()

asyncio.run(_run())


def test_stop_graceful_handles_missing_consumer_task() -> None:
async def _run() -> None:
monitor = SmbMonitor(_unc_path(), [handler_full])
finished = asyncio.Event()

async def _task() -> None:
finished.set()

monitor._watcher_task = asyncio.create_task(_task())
monitor._consumer_task = None

await monitor.stop(graceful=True)

assert monitor._stop_event.is_set()
await asyncio.wait_for(finished.wait(), timeout=1.0)

asyncio.run(_run())
Loading