-
Notifications
You must be signed in to change notification settings - Fork 4
DBA 150 Implement cli progress feedback #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fb736ce
3e96dbe
c5376a9
7c22050
c00a5f4
ba697ed
b519009
c865596
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,14 @@ | ||
| from databao_context_engine import BuildDatasourceResult, DatabaoContextDomainManager | ||
|
|
||
| from databao_cli.shared.log.cli_progress import cli_progress | ||
| from databao_cli.shared.project.layout import ProjectLayout | ||
|
|
||
|
|
||
| def build_impl(project_layout: ProjectLayout, domain: str, should_index: bool) -> list[BuildDatasourceResult]: | ||
| dce_project_dir = project_layout.domains_dir / domain | ||
| results: list[BuildDatasourceResult] = DatabaoContextDomainManager(domain_dir=dce_project_dir).build_context( | ||
| datasource_ids=None, should_index=should_index | ||
| ) | ||
| manager = DatabaoContextDomainManager(domain_dir=dce_project_dir) | ||
|
|
||
| datasources = manager.get_configured_datasource_list() | ||
| with cli_progress(total=len(datasources), label="Building datasources"): | ||
| results: list[BuildDatasourceResult] = manager.build_context(datasource_ids=None, should_index=should_index) | ||
| return results | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from databao_context_engine import DatabaoContextDomainManager, DatasourceId, IndexDatasourceResult | ||
|
|
||
| from databao_cli.shared.log.cli_progress import cli_progress | ||
| from databao_cli.shared.project.layout import ProjectLayout | ||
|
|
||
|
|
||
|
|
@@ -10,8 +11,11 @@ def index_impl( | |
|
|
||
| datasource_ids = [DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None | ||
|
|
||
| results: list[IndexDatasourceResult] = DatabaoContextDomainManager(domain_dir=dce_project_dir).index_built_contexts( | ||
| datasource_ids=datasource_ids | ||
| ) | ||
| manager = DatabaoContextDomainManager(domain_dir=dce_project_dir) | ||
|
|
||
| total = len(datasource_ids) if datasource_ids is not None else len(manager.get_configured_datasource_list()) | ||
|
|
||
| with cli_progress(total=total, label="Indexing datasources"): | ||
| results: list[IndexDatasourceResult] = manager.index_built_contexts(datasource_ids=datasource_ids) | ||
|
Comment on lines
+14
to
+19
|
||
|
|
||
| return results | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import re | ||
| import sys | ||
| from collections.abc import Iterator | ||
| from contextlib import contextmanager | ||
| from typing import Any | ||
|
|
||
| # Log patterns emitted by databao_context_engine.build_sources.build_runner | ||
| _BUILD_START_RE = re.compile(r'^Found datasource of type ".*" with name (.+)$') | ||
| _INDEX_START_RE = re.compile(r"^Indexing datasource (.+)$") | ||
| _ENRICH_START_RE = re.compile(r"^Enriching context for datasource (.+)$") | ||
| _SKIP_RE = re.compile(r"^Skipping disabled datasource (.+)$") | ||
| _FAIL_RE = re.compile(r"^Failed to build source at \((.+?)\)") | ||
| _FAIL_ENRICH_RE = re.compile(r"^Failed to enrich context for datasource \((.+?)\)") | ||
|
gasparian marked this conversation as resolved.
|
||
|
|
||
|
Comment on lines
+10
to
+17
|
||
|
|
||
class _ProgressTrackingHandler(logging.Handler):
    """Drive a progress display from ``databao_context_engine`` log messages.

    Each record's message is matched against the module-level patterns:

    * a "start" message (build, index or enrich) marks a datasource as active
      and shows its name on the per-datasource task;
    * a "skipping" message advances the overall task immediately;
    * a "failed" message advances the overall task if a datasource was active.

    An active datasource is counted as done when a *different* datasource
    starts, when it fails, or when :meth:`finish` is called. Consecutive start
    messages that carry the same name are treated as further phases of the
    same datasource and do not advance the overall task again.
    """

    def __init__(
        self,
        progress: Any,
        overall_task: Any,
        datasource_task: Any,
    ) -> None:
        """Store the progress object and the two task handles it updates.

        Args:
            progress: Object exposing ``advance(task)`` and
                ``update(task, description=...)``.
            overall_task: Task handle advanced once per finished datasource.
            datasource_task: Task handle whose description shows the current name.
        """
        super().__init__()
        self._progress = progress
        self._overall_task = overall_task
        self._datasource_task = datasource_task
        self._has_active = False  # whether a datasource is currently being processed
        self._active_name: str | None = None  # name captured from the last start message

    def emit(self, record: logging.LogRecord) -> None:
        """Update the progress display for ``record``; never raise into the logger."""
        try:
            self._track(record.getMessage())
        except Exception:
            # Standard logging.Handler convention: a failure while handling a
            # record must not propagate into the code that emitted the log.
            self.handleError(record)

    def _track(self, msg: str) -> None:
        """Apply a single log message to the progress state."""
        # Datasource processing started (build, index or enrich phase).
        started = _BUILD_START_RE.match(msg) or _INDEX_START_RE.match(msg) or _ENRICH_START_RE.match(msg)
        if started:
            name = started.group(1)
            # Only a change of name means the previous datasource finished.
            # NOTE(review): presumes a repeated name is another phase of the
            # same datasource - confirm against the engine's log output.
            if self._has_active and name != self._active_name:
                self._progress.advance(self._overall_task)
            self._has_active = True
            self._active_name = name
            self._progress.update(self._datasource_task, description=f" [dim]{name}[/dim]")
            return

        # Datasource skipped: it is done immediately.
        if _SKIP_RE.match(msg):
            self._progress.advance(self._overall_task)
            return

        # Datasource failed: count it only if a start message was seen for it.
        if _FAIL_RE.match(msg) or _FAIL_ENRICH_RE.match(msg):
            if self._has_active:
                self._progress.advance(self._overall_task)
            self._has_active = False
            self._active_name = None

    def finish(self) -> None:
        """Advance for the last datasource that was still being processed."""
        if self._has_active:
            self._progress.advance(self._overall_task)
        self._has_active = False
        self._active_name = None
|
|
||
|
|
||
@contextmanager
def cli_progress(total: int | None = None, label: str = "Datasources") -> Iterator[None]:
    """Show a Rich progress bar during build/index operations.

    Intercepts ``databao_context_engine`` log messages to track per-datasource
    progress and, while active, routes the ``databao_context_engine`` and
    ``databao_cli`` loggers through a Rich handler instead of their
    stderr/stdout stream handlers. The loggers' handlers and ``propagate``
    flags are restored on exit, including when the body raises.

    The context manager is a no-op when Rich cannot be imported or when the
    stderr console is not a terminal.

    Args:
        total: Number of datasources to process.
        label: Label for the overall progress bar.
    """
    try:
        from rich.console import Console
        from rich.logging import RichHandler
        from rich.progress import (
            BarColumn,
            MofNCompleteColumn,
            Progress,
            SpinnerColumn,
            TextColumn,
        )
        from rich.table import Column
    except ImportError:
        rich_available = False
    else:
        rich_available = True

    # Yield *outside* the ``except`` block: an exception thrown into a generator
    # suspended inside a handler would be chained to the ImportError.
    if not rich_available:
        yield
        return

    console = Console(stderr=True)

    # Rich's is_terminal already checks isatty(), NO_COLOR, TERM=dumb, etc.
    # This prevents progress bar ANSI output from breaking pexpect-based e2e tests.
    if not console.is_terminal:
        yield
        return

    progress = Progress(
        SpinnerColumn(),
        TextColumn(
            "[progress.description]{task.description}",
            table_column=Column(width=50, overflow="ellipsis", no_wrap=True),
        ),
        BarColumn(),
        MofNCompleteColumn(),
        transient=True,
        console=console,
    )

    overall_task = progress.add_task(label, total=total)
    datasource_task = progress.add_task(" [dim]starting…[/dim]", total=None)

    # --- logging setup ---
    engine_logger = logging.getLogger("databao_context_engine")
    cli_logger = logging.getLogger("databao_cli")

    # Snapshot (copy of the handler list, propagate flag) so both can be restored.
    prev_engine = (list(engine_logger.handlers), engine_logger.propagate)
    prev_cli = (list(cli_logger.handlers), cli_logger.propagate)

    def _is_console_handler(h: logging.Handler) -> bool:
        """Return True for stream handlers that write to stderr or stdout."""
        return isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stderr, sys.stdout)

    rich_handler = RichHandler(
        console=console,
        show_time=False,
        show_level=True,
        show_path=False,
        rich_tracebacks=False,
    )

    tracker = _ProgressTrackingHandler(progress, overall_task, datasource_task)
    tracker.setLevel(logging.DEBUG)

    try:
        for lgr in (engine_logger, cli_logger):
            # Replace plain console handlers with the Rich one; keep the rest.
            kept = [h for h in lgr.handlers if not _is_console_handler(h)]
            lgr.handlers = [*kept, rich_handler]
            lgr.propagate = False

        engine_logger.addHandler(tracker)

        with progress:
            yield
            # Count the last active datasource while the display is still live.
            tracker.finish()
    finally:
        engine_logger.handlers = prev_engine[0]
        engine_logger.propagate = prev_engine[1]
        cli_logger.handlers = prev_cli[0]
        cli_logger.propagate = prev_cli[1]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| """Tests for the cli_progress module.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import Any | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from databao_cli.shared.log.cli_progress import cli_progress | ||
|
|
||
|
|
||
def test_cli_progress_noop_when_not_terminal() -> None:
    """cli_progress is a silent no-op when the console reports it is not a terminal."""
    with patch("rich.console.Console") as console_cls:
        fake_console = console_cls.return_value
        fake_console.is_terminal = False
        # Entering and leaving the context must not raise.
        with cli_progress(total=5, label="Test"):
            pass
|
|
||
|
|
||
def test_cli_progress_noop_when_rich_not_available() -> None:
    """cli_progress is a silent no-op when importing Rich fails."""
    import builtins

    real_import = builtins.__import__

    def fake_import(name: str, *args: Any, **kwargs: Any) -> object:
        # Only the Rich console import is made to fail; everything else is real.
        if name != "rich.console":
            return real_import(name, *args, **kwargs)
        raise ImportError("No module named 'rich'")

    import_patch = patch("builtins.__import__", side_effect=fake_import)
    # Entering and leaving the context must not raise.
    with import_patch, cli_progress(total=3, label="Test"):
        pass
|
|
||
|
|
||
def test_progress_tracking_handler_advances_on_datasource_start() -> None:
    """A second "Found datasource" message advances the bar for the first one."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    overall = MagicMock()
    per_datasource = MagicMock()
    tracker = _ProgressTrackingHandler(progress, overall, per_datasource)

    # First datasource: nothing has finished yet, so only the description changes.
    first = logging.LogRecord(
        "test", logging.INFO, "", 0, 'Found datasource of type "duckdb" with name my_db', (), None
    )
    tracker.emit(first)
    progress.advance.assert_not_called()
    progress.update.assert_called_once_with(per_datasource, description=" [dim]my_db[/dim]")

    progress.reset_mock()

    # Second datasource: the first one is now counted as done.
    second = logging.LogRecord(
        "test", logging.INFO, "", 0, 'Found datasource of type "csv" with name my_csv', (), None
    )
    tracker.emit(second)
    progress.advance.assert_called_once_with(overall)
|
|
||
|
|
||
def test_progress_tracking_handler_advances_on_skip() -> None:
    """A "Skipping disabled datasource" message advances the overall task once."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    skip_record = logging.LogRecord(
        "test", logging.INFO, "", 0, "Skipping disabled datasource my_ds", (), None
    )
    tracker.emit(skip_record)

    progress.advance.assert_called_once_with("overall")
|
|
||
|
|
||
def test_progress_tracking_handler_advances_on_failure() -> None:
    """A failure message for a started datasource advances the overall task once."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    messages = (
        # Start a datasource first ...
        'Found datasource of type "duckdb" with name my_db',
        # ... then fail it.
        "Failed to build source at (my_db): connection error",
    )
    for text in messages:
        tracker.emit(logging.LogRecord("test", logging.INFO, "", 0, text, (), None))

    progress.advance.assert_called_once_with("overall")
|
|
||
|
|
||
def test_progress_tracking_handler_finish_advances_last() -> None:
    """finish() counts the datasource that was still active when processing ended."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    started = logging.LogRecord(
        "test", logging.INFO, "", 0, 'Found datasource of type "duckdb" with name my_db', (), None
    )
    tracker.emit(started)
    # Starting alone does not advance the overall task.
    progress.advance.assert_not_called()

    tracker.finish()
    progress.advance.assert_called_once_with("overall")
|
|
||
|
|
||
def test_progress_tracking_handler_finish_noop_when_no_active() -> None:
    """finish() leaves the overall task alone when no datasource ever started."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")
    tracker.finish()

    progress.advance.assert_not_called()
|
|
||
|
|
||
def test_progress_tracking_handler_index_pattern() -> None:
    """An "Indexing datasource" message updates the per-datasource description."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    index_record = logging.LogRecord(
        "test", logging.INFO, "", 0, "Indexing datasource my_ds", (), None
    )
    tracker.emit(index_record)

    progress.update.assert_called_once_with("ds", description=" [dim]my_ds[/dim]")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This adds an extra call to `manager.get_configured_datasource_list()` solely to compute `total` for the progress UI. Since `cli_progress` becomes a no-op on non-TTY stderr, this can introduce unnecessary work in non-interactive runs. Consider computing `total` conditionally (only when stderr is a TTY) or letting `cli_progress` accept `total=None` when you don't want to pre-scan datasources.