Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"adbc-driver-snowflake>=1.10.0",
"adbc-driver-manager>=1.10.0",
"nh3>=0.2.15",
"rich>=14.3.3",
]

[project.optional-dependencies]
Expand Down
9 changes: 6 additions & 3 deletions src/databao_cli/features/build.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from databao_context_engine import BuildDatasourceResult, DatabaoContextDomainManager

from databao_cli.shared.log.cli_progress import cli_progress
from databao_cli.shared.project.layout import ProjectLayout


def build_impl(project_layout: ProjectLayout, domain: str, should_index: bool) -> list[BuildDatasourceResult]:
dce_project_dir = project_layout.domains_dir / domain
results: list[BuildDatasourceResult] = DatabaoContextDomainManager(domain_dir=dce_project_dir).build_context(
datasource_ids=None, should_index=should_index
)
manager = DatabaoContextDomainManager(domain_dir=dce_project_dir)

datasources = manager.get_configured_datasource_list()
with cli_progress(total=len(datasources), label="Building datasources"):
results: list[BuildDatasourceResult] = manager.build_context(datasource_ids=None, should_index=should_index)
Comment on lines +9 to +13
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds an extra call to manager.get_configured_datasource_list() solely to compute total for the progress UI. Since cli_progress becomes a no-op on non-TTY stderr, this can introduce unnecessary work in non-interactive runs. Consider computing total conditionally (only when stderr is a TTY) or letting cli_progress accept total=None when you don't want to pre-scan datasources.

Copilot uses AI. Check for mistakes.
return results
12 changes: 9 additions & 3 deletions src/databao_cli/features/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from databao_context_engine import ChunkEmbeddingMode, DatabaoContextDomainManager, DatasourceId, IndexDatasourceResult

from databao_cli.shared.log.cli_progress import cli_progress
from databao_cli.shared.project.layout import ProjectLayout


Expand All @@ -10,8 +11,13 @@ def index_impl(

datasource_ids = [DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None

results: list[IndexDatasourceResult] = DatabaoContextDomainManager(domain_dir=dce_project_dir).index_built_contexts(
datasource_ids=datasource_ids, chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY
)
manager = DatabaoContextDomainManager(domain_dir=dce_project_dir)

total = len(datasource_ids) if datasource_ids is not None else len(manager.get_configured_datasource_list())

with cli_progress(total=total, label="Indexing datasources"):
results: list[IndexDatasourceResult] = manager.index_built_contexts(
datasource_ids=datasource_ids, chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY
)

return results
158 changes: 158 additions & 0 deletions src/databao_cli/shared/log/cli_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

import logging
import re
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

# Log patterns emitted by databao_context_engine.build_sources.build_runner.
# Each pattern captures the datasource name/identifier in group 1 so the
# progress handler can display which datasource is currently in flight.
_BUILD_START_RE = re.compile(r'^Found datasource of type ".*" with name (.+)$')
_INDEX_START_RE = re.compile(r"^Indexing datasource (.+)$")
_ENRICH_START_RE = re.compile(r"^Enriching context for datasource (.+)$")
_SKIP_RE = re.compile(r"^Skipping disabled datasource (.+)$")
_FAIL_RE = re.compile(r"^Failed to build source at \((.+?)\)")
# NOTE(review): identical to _FAIL_RE — presumably the engine emits a distinct
# "failed to index" message; confirm the real pattern or drop the duplicate.
_FAIL_INDEX_RE = re.compile(r"^Failed to build source at \((.+?)\)")
_FAIL_ENRICH_RE = re.compile(r"^Failed to enrich context for datasource \((.+?)\)")
Comment thread
gasparian marked this conversation as resolved.

Comment on lines +10 to +17
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description mentions build/index emitting structured progress events and showing final per-datasource completion status, but the implementation here derives progress by regex-parsing log messages and only updates the current datasource + overall count. If structured events/final status output are still intended, the code and/or PR description should be aligned so the documented mechanism matches what ships.

Copilot uses AI. Check for mistakes.

class _ProgressTrackingHandler(logging.Handler):
    """Drives a Rich progress bar by watching databao_context_engine log output.

    The engine processes datasources one at a time and announces each at its
    start ("Found datasource...").  A new start therefore implies the previous
    datasource completed, so the overall bar is advanced at that moment, and
    once more via :meth:`finish` for the final datasource.
    """

    def __init__(
        self,
        progress: Any,
        overall_task: Any,
        datasource_task: Any,
    ) -> None:
        super().__init__()
        self._progress = progress
        self._overall_task = overall_task
        self._datasource_task = datasource_task
        # True while some datasource is mid-processing.
        self._has_active = False

    def emit(self, record: logging.LogRecord) -> None:
        message = record.getMessage()

        # A datasource has started processing.
        started = None
        for pattern in (_BUILD_START_RE, _INDEX_START_RE, _ENRICH_START_RE):
            started = pattern.match(message)
            if started:
                break
        if started:
            # A fresh start means the previous datasource (if any) is done.
            if self._has_active:
                self._progress.advance(self._overall_task)
            self._has_active = True
            self._progress.update(
                self._datasource_task,
                description=f" [dim]{started.group(1)}[/dim]",
            )
            return

        # A skipped datasource counts as done immediately.
        if _SKIP_RE.match(message):
            self._progress.advance(self._overall_task)
            return

        # Failure ends the active datasource early (it was started by a
        # "Found datasource" message, so _has_active is normally True here).
        if any(p.match(message) for p in (_FAIL_RE, _FAIL_INDEX_RE, _FAIL_ENRICH_RE)):
            if self._has_active:
                self._progress.advance(self._overall_task)
            self._has_active = False
            return

    def finish(self) -> None:
        """Advance once for the datasource still in flight, if any."""
        if self._has_active:
            self._progress.advance(self._overall_task)
            self._has_active = False


@contextmanager
def cli_progress(total: int | None = None, label: str = "Datasources") -> Iterator[None]:
    """Show a Rich progress bar during build/index operations.

    Intercepts ``databao_context_engine`` log messages (via
    ``_ProgressTrackingHandler``) to track per-datasource progress, and
    redirects library logging through Rich for clean TTY output.  Degrades to
    a plain no-op ``yield`` when Rich is unavailable or stderr is not a TTY,
    so non-interactive runs (CI, pipes) are unaffected.

    Args:
        total: Number of datasources to process; ``None`` renders an
            indeterminate bar.
        label: Label for the overall progress bar.
    """
    # Import lazily and defensively: if Rich is missing, the CLI still works,
    # just without a progress bar.
    try:
        from rich.console import Console
        from rich.logging import RichHandler
        from rich.progress import (
            BarColumn,
            MofNCompleteColumn,
            Progress,
            SpinnerColumn,
            TextColumn,
        )
        from rich.table import Column
    except ImportError:
        yield
        return

    # No interactive terminal — skip all progress rendering.
    if not sys.stderr.isatty():
        yield
        return

    console = Console(stderr=True)

    # transient=True removes the bar from the terminal once the context exits,
    # leaving only the (Rich-routed) log lines behind.
    progress = Progress(
        SpinnerColumn(),
        TextColumn(
            "[progress.description]{task.description}",
            table_column=Column(width=50, overflow="ellipsis", no_wrap=True),
        ),
        BarColumn(),
        MofNCompleteColumn(),
        transient=True,
        console=console,
    )

    # Overall bar counts datasources; the second row shows the current one
    # (total=None makes it an indeterminate spinner row).
    overall_task = progress.add_task(label, total=total)
    datasource_task = progress.add_task(" [dim]starting…[/dim]", total=None)

    # --- logging setup ---
    engine_logger = logging.getLogger("databao_context_engine")
    cli_logger = logging.getLogger("databao_cli")

    # Snapshot handler lists and propagate flags so they can be restored
    # exactly in the finally block below.
    prev_engine = (list(engine_logger.handlers), engine_logger.propagate)
    prev_cli = (list(cli_logger.handlers), cli_logger.propagate)

    def _is_console_handler(h: logging.Handler) -> bool:
        # Identifies plain stderr/stdout StreamHandlers that would garble the
        # live progress display; those are replaced by the RichHandler.
        return isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stderr, sys.stdout)

    rich_handler = RichHandler(
        console=console,
        show_time=False,
        show_level=True,
        show_path=False,
        rich_tracebacks=False,
    )

    # The tracker only observes messages to drive the bar; DEBUG level so it
    # sees everything the logger lets through.
    tracker = _ProgressTrackingHandler(progress, overall_task, datasource_task)
    tracker.setLevel(logging.DEBUG)

    try:
        # Swap console handlers for the Rich one; keep any file/other handlers.
        # propagate=False stops duplicates from ancestor loggers.
        for lgr in (engine_logger, cli_logger):
            kept = [h for h in lgr.handlers if not _is_console_handler(h)]
            lgr.handlers = [*kept, rich_handler]
            lgr.propagate = False

        engine_logger.addHandler(tracker)

        with progress:
            yield

        # Count the final datasource (only reached on a clean exit; on error
        # the transient bar is torn down by the `with progress` exit anyway).
        tracker.finish()
    finally:
        # Restore original logging state; this also detaches the tracker,
        # since it is not in the saved handler lists.
        engine_logger.handlers = prev_engine[0]
        engine_logger.propagate = prev_engine[1]
        cli_logger.handlers = prev_cli[0]
        cli_logger.propagate = prev_cli[1]
179 changes: 179 additions & 0 deletions tests/test_cli_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Tests for the cli_progress module."""

from __future__ import annotations

import logging
from typing import Any
from unittest.mock import MagicMock, patch

from databao_cli.shared.log.cli_progress import cli_progress


def test_cli_progress_noop_when_not_tty() -> None:
    """Progress context manager yields without error when stderr is not a TTY."""
    with patch("sys.stderr") as fake_stderr:
        fake_stderr.isatty.return_value = False
        # Non-TTY stderr should make the context manager a silent no-op.
        with cli_progress(total=5, label="Test"):
            pass


def test_cli_progress_noop_when_rich_not_available() -> None:
    """Progress context manager yields without error when Rich is not installed."""
    import builtins

    real_import = builtins.__import__

    def failing_import(name: str, *args: Any, **kwargs: Any) -> object:
        # Simulate an absent Rich installation by failing its first import.
        if name == "rich.console":
            raise ImportError("No module named 'rich'")
        return real_import(name, *args, **kwargs)

    # Must degrade to a no-op rather than raise.
    with patch("builtins.__import__", side_effect=failing_import), cli_progress(total=3, label="Test"):
        pass


def test_progress_tracking_handler_advances_on_datasource_start() -> None:
    """The tracking handler advances the progress bar when a new datasource starts."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    def _record(text: str) -> logging.LogRecord:
        return logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg=text,
            args=(),
            exc_info=None,
        )

    progress = MagicMock()
    overall = MagicMock()
    current = MagicMock()
    handler = _ProgressTrackingHandler(progress, overall, current)

    # First datasource: nothing has finished yet, so no advance.
    handler.emit(_record('Found datasource of type "duckdb" with name my_db'))
    progress.advance.assert_not_called()
    progress.update.assert_called_once_with(current, description=" [dim]my_db[/dim]")

    progress.reset_mock()

    # Second datasource: the first is implicitly complete, so advance once.
    handler.emit(_record('Found datasource of type "csv" with name my_csv'))
    progress.advance.assert_called_once_with(overall)


def test_progress_tracking_handler_advances_on_skip() -> None:
    """The tracking handler advances when a datasource is skipped."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    skip_record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="",
        lineno=0,
        msg="Skipping disabled datasource my_ds",
        args=(),
        exc_info=None,
    )
    tracker.emit(skip_record)

    # A skipped datasource counts as done immediately.
    progress.advance.assert_called_once_with("overall")


def test_progress_tracking_handler_advances_on_failure() -> None:
    """The tracking handler advances when a datasource fails."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    def _record(text: str) -> logging.LogRecord:
        return logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg=text,
            args=(),
            exc_info=None,
        )

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    # Begin processing a datasource, then report it as failed.
    tracker.emit(_record('Found datasource of type "duckdb" with name my_db'))
    tracker.emit(_record("Failed to build source at (my_db): connection error"))

    # The failed datasource must be counted exactly once.
    progress.advance.assert_called_once_with("overall")


def test_progress_tracking_handler_finish_advances_last() -> None:
    """finish() advances the bar for the last datasource that was being processed."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    start_record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="",
        lineno=0,
        msg='Found datasource of type "duckdb" with name my_db',
        args=(),
        exc_info=None,
    )
    tracker.emit(start_record)
    # Still in flight — nothing advanced yet.
    progress.advance.assert_not_called()

    # finish() accounts for the final, still-active datasource.
    tracker.finish()
    progress.advance.assert_called_once_with("overall")


def test_progress_tracking_handler_finish_noop_when_no_active() -> None:
    """finish() does nothing if no datasource was active."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    # Nothing was ever started, so there is nothing to complete.
    tracker.finish()
    progress.advance.assert_not_called()


def test_progress_tracking_handler_index_pattern() -> None:
    """The handler recognizes indexing log messages."""
    from databao_cli.shared.log.cli_progress import _ProgressTrackingHandler

    progress = MagicMock()
    tracker = _ProgressTrackingHandler(progress, "overall", "ds")

    index_record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="",
        lineno=0,
        msg="Indexing datasource my_ds",
        args=(),
        exc_info=None,
    )
    tracker.emit(index_record)

    # The current-datasource row should show the indexed datasource's name.
    progress.update.assert_called_once_with("ds", description=" [dim]my_ds[/dim]")
Loading
Loading