diff --git a/.github/workflows/pr-checks.yml b/.github/workflows/pr-checks.yml index cab2475..c659d54 100644 --- a/.github/workflows/pr-checks.yml +++ b/.github/workflows/pr-checks.yml @@ -4,6 +4,9 @@ name: Pull Request validation on: - pull_request +permissions: + contents: read + env: IMAGE_NAME: ophiosdev/ocrpdf @@ -17,12 +20,24 @@ jobs: - name: Set up Python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 + with: + python-version: '3.11' + cache: pip + + - name: Create project venv and install dependencies + run: | + "${{ steps.setup-python.outputs.python-path }}" -m venv .venv + "$PWD/.venv/bin/python" -m pip install -r requirements.txt -r requirements-dev.txt + echo "$PWD/.venv/bin" >> "$GITHUB_PATH" - name: Run pre-commit checks id: pre-commit uses: cloudposse/github-action-pre-commit@v4.0.0 + - name: Run pytest + run: .venv/bin/pytest + - name: Build Docker image if Dockerfile changed run: | if git diff --name-only origin/${{ github.base_ref }} | grep -q '^Dockerfile$'; then diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c79c37..8ade387 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,6 +63,6 @@ repos: - id: ruff-format - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.408 + rev: v1.1.409 hooks: - id: pyright diff --git a/packages/smbmonitor/src/smbmonitor.py b/packages/smbmonitor/src/smbmonitor.py index c42a899..e754ae4 100644 --- a/packages/smbmonitor/src/smbmonitor.py +++ b/packages/smbmonitor/src/smbmonitor.py @@ -100,8 +100,8 @@ def _can_bind_args(sig: Signature, arg_count: int) -> bool: self._enforce_encryption = enforce_encryption self._server, self._share, self._watch_path = self._parse_unc_path(unc_path) self._stop_event = asyncio.Event() - self._watcher_task = None - self._consumer_task = None + self._watcher_task: asyncio.Task[None] | None = None + self._consumer_task: asyncio.Task[None] | None = None self._queue = asyncio.Queue() self._port = port self._uuid = uuid4() @@ -271,17 +271,20 @@ async def stop( log.info("Stopping SMB monitoring") if graceful: self._stop_event.set() - _, pending = await asyncio.wait( - (self._watcher_task, self._consumer_task), # pyright: ignore[reportArgumentType, reportCallIssue] - timeout=2.0, - return_when=asyncio.ALL_COMPLETED, - ) - for task in pending: - task.cancel() + tasks = [self._watcher_task, self._consumer_task] + active_tasks = {t for t in tasks if t is not None} + if active_tasks: + _, pending = await asyncio.wait( + active_tasks, + timeout=2.0, + return_when=asyncio.ALL_COMPLETED, + ) + for task in pending: + task.cancel() else: - _ = self._watcher_task.cancel() if self._consumer_task else None + _ = self._watcher_task.cancel() if self._watcher_task else None _ = self._consumer_task.cancel() if self._consumer_task else None def is_running(self) -> bool: """Check if the monitoring tasks are currently running.""" - return self._watcher_task and not self._watcher_task.done() # pyright: ignore[reportReturnType] + return self._watcher_task is not None and not self._watcher_task.done() diff --git a/packages/smbmonitor/tests/test_smbmonitor_handlers.py b/packages/smbmonitor/tests/test_smbmonitor_handlers.py index 9fed3da..6750cfb 100644 --- a/packages/smbmonitor/tests/test_smbmonitor_handlers.py +++ b/packages/smbmonitor/tests/test_smbmonitor_handlers.py @@ -121,3 +121,68 @@ async def handler(action, server, share, watch_path, filename) -> None: await asyncio.wait_for(consumer_task, timeout=2.0) asyncio.run(_run()) + + +def test_stop_graceful_waits_for_active_tasks() -> None: + async def _run() -> None: + monitor = SmbMonitor(_unc_path(), [handler_full]) + release = asyncio.Event() + + async def _task() -> None: + await release.wait() + + monitor._watcher_task = asyncio.create_task(_task()) + monitor._consumer_task = asyncio.create_task(_task()) + + stop_task = asyncio.create_task(monitor.stop(graceful=True)) + await asyncio.sleep(0) + assert not stop_task.done() + assert monitor._stop_event.is_set() + + release.set() + await asyncio.wait_for(stop_task, timeout=1.0) + assert monitor._watcher_task.done() + assert monitor._consumer_task.done() + + asyncio.run(_run()) + + +def test_stop_graceful_cancels_pending_tasks_after_timeout() -> None: + async def _run() -> None: + monitor = SmbMonitor(_unc_path(), [handler_full]) + started = asyncio.Event() + + async def _never_finishes() -> None: + started.set() + await asyncio.Future() + + monitor._watcher_task = asyncio.create_task(_never_finishes()) + await asyncio.wait_for(started.wait(), timeout=1.0) + + await monitor.stop(graceful=True) + + with pytest.raises(asyncio.CancelledError): + await monitor._watcher_task + + assert monitor._watcher_task.cancelled() + + asyncio.run(_run()) + + +def test_stop_graceful_handles_missing_consumer_task() -> None: + async def _run() -> None: + monitor = SmbMonitor(_unc_path(), [handler_full]) + finished = asyncio.Event() + + async def _task() -> None: + finished.set() + + monitor._watcher_task = asyncio.create_task(_task()) + monitor._consumer_task = None + + await monitor.stop(graceful=True) + + assert monitor._stop_event.is_set() + await asyncio.wait_for(finished.wait(), timeout=1.0) + + asyncio.run(_run())