From 1d5d103575da63123a5c2cc0800bd9cd6fca842c Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 14 May 2026 14:25:41 -0500 Subject: [PATCH 01/69] feat: add multi-table support --- .gitignore | 1 + pkg-py/CHANGELOG.md | 2 + pkg-py/examples/lazy_frame_demo.py | 272 +++++++ pkg-py/src/querychat/__init__.py | 16 + pkg-py/src/querychat/_dash.py | 2 + pkg-py/src/querychat/_datasource.py | 30 +- pkg-py/src/querychat/_gradio.py | 19 +- pkg-py/src/querychat/_query_executor.py | 265 +++++++ pkg-py/src/querychat/_querychat_base.py | 358 ++++++++- pkg-py/src/querychat/_querychat_core.py | 83 ++- pkg-py/src/querychat/_shiny.py | 18 +- pkg-py/src/querychat/_shiny_module.py | 152 +++- pkg-py/src/querychat/_streamlit.py | 2 + pkg-py/src/querychat/_system_prompt.py | 64 +- pkg-py/src/querychat/_table_accessor.py | 113 +++ pkg-py/src/querychat/_viz_ggsql.py | 14 +- pkg-py/src/querychat/_viz_tools.py | 16 +- pkg-py/src/querychat/prompts/prompt.md | 21 +- pkg-py/src/querychat/prompts/tool-query.md | 5 + .../querychat/prompts/tool-reset-dashboard.md | 9 +- .../prompts/tool-update-dashboard.md | 9 + .../src/querychat/prompts/tool-visualize.md | 5 + pkg-py/src/querychat/tools.py | 120 ++- pkg-py/src/querychat/types/__init__.py | 3 +- pkg-py/tests/test_base.py | 3 +- pkg-py/tests/test_client_console.py | 10 + pkg-py/tests/test_frameworks.py | 29 + pkg-py/tests/test_multi_table.py | 694 ++++++++++++++++++ pkg-py/tests/test_query_executor.py | 393 ++++++++++ pkg-py/tests/test_querychat.py | 7 +- pkg-py/tests/test_state.py | 228 +++++- pkg-py/tests/test_system_prompt.py | 8 +- pkg-py/tests/test_tools.py | 97 ++- 33 files changed, 2887 insertions(+), 181 deletions(-) create mode 100644 pkg-py/examples/lazy_frame_demo.py create mode 100644 pkg-py/src/querychat/_query_executor.py create mode 100644 pkg-py/src/querychat/_table_accessor.py create mode 100644 pkg-py/tests/test_multi_table.py create mode 100644 pkg-py/tests/test_query_executor.py diff --git a/.gitignore b/.gitignore index df712fb0f..ef52805cf 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ animation.screenflow/ README_files/ README.html .DS_Store +test-results/ python-package/examples/titanic.db .quarto *.db diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index 1660dbb81..31f1719e9 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -79,6 +79,8 @@ Each framework's `QueryChat` provides `.app()` for quick standalone apps and `.u ### New features +* Added `PolarsLazySource` to support Polars LazyFrames as data sources. Data stays lazy until the render boundary, enabling efficient handling of large datasets. Pass a `polars.LazyFrame` directly to `QueryChat()` and queries will be executed lazily via Polars' SQLContext. + * `QueryChat.console()` was added to launch interactive console-based chat sessions with your data source, with persistent conversation state across invocations. (#168) * `QueryChat.client()` can now create standalone querychat-enabled chat clients with configurable tools and callbacks, enabling use outside of Shiny applications. (#168) diff --git a/pkg-py/examples/lazy_frame_demo.py b/pkg-py/examples/lazy_frame_demo.py new file mode 100644 index 000000000..004b49f59 --- /dev/null +++ b/pkg-py/examples/lazy_frame_demo.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +""" +Demo script comparing eager vs lazy data source performance. + +This script demonstrates the performance benefits of using PolarsLazySource +with large datasets. It creates a synthetic dataset and compares: +1. Eager loading (all data in memory upfront) +2. Lazy loading (data stays on disk until needed) + +Usage: + # Set your API key first + export OPENAI_API_KEY="your-key-here" + + # Run the demo + cd pkg-py + uv run python examples/lazy_frame_demo.py + + # Or with a custom number of rows (default: 10 million) + uv run python examples/lazy_frame_demo.py --rows 50000000 +""" + +import argparse +import os +import tempfile +import time +from pathlib import Path + +import polars as pl + + +def create_large_dataset(path: Path, n_rows: int) -> None: + """Create a large parquet file for testing.""" + print(f"Creating dataset with {n_rows:,} rows...") + start = time.perf_counter() + + # Generate data in chunks to avoid memory issues + chunk_size = 1_000_000 + chunks_written = 0 + + for i in range(0, n_rows, chunk_size): + chunk_rows = min(chunk_size, n_rows - i) + chunk = pl.DataFrame( + { + "id": range(i, i + chunk_rows), + "category": [f"cat_{j % 100}" for j in range(chunk_rows)], + "region": [["North", "South", "East", "West"][j % 4] for j in range(chunk_rows)], + "value": [float(j % 1000) + 0.5 for j in range(chunk_rows)], + "quantity": [j % 500 for j in range(chunk_rows)], + "date": pl.Series([f"2024-{(j % 12) + 1:02d}-{(j % 28) + 1:02d}" for j in range(chunk_rows)]).str.to_date(), + } + ) + + if chunks_written == 0: + chunk.write_parquet(path) + else: + # Append by reading existing and concatenating + existing = pl.read_parquet(path) + pl.concat([existing, chunk]).write_parquet(path) + + chunks_written += 1 + print(f" Written {min(i + chunk_size, n_rows):,} / {n_rows:,} rows") + + elapsed = time.perf_counter() - start + file_size_mb = path.stat().st_size / (1024 * 1024) + print(f"Dataset created: {file_size_mb:.1f} MB in {elapsed:.1f}s\n") + + +def measure_memory() -> float: + """Get current memory usage in MB (approximate).""" + import psutil + process = psutil.Process(os.getpid()) + return process.memory_info().rss / (1024 * 1024) + + +def demo_eager_vs_lazy(parquet_path: Path) -> None: + """Compare eager vs lazy data loading performance.""" + from querychat import QueryChat + + print("=" * 60) + print("COMPARING EAGER VS LAZY DATA SOURCE") + print("=" * 60) + + # Check if we have psutil for memory tracking + try: + import psutil # noqa: F401 + has_psutil = True + except ImportError: + has_psutil = False + print("(Install psutil for memory usage tracking: pip install psutil)\n") + + # --- EAGER LOADING --- + print("\n1. EAGER LOADING (polars.read_parquet → DataFrame)") + print("-" * 50) + + if has_psutil: + mem_before = measure_memory() + + start = time.perf_counter() + df = pl.read_parquet(parquet_path) + load_time = time.perf_counter() - start + + if has_psutil: + mem_after = measure_memory() + print(f" Memory increase: {mem_after - mem_before:.1f} MB") + + print(f" Load time: {load_time:.2f}s") + print(f" Rows loaded: {len(df):,}") + + # Create QueryChat with eager data + start = time.perf_counter() + qc_eager = QueryChat( + data_source=df, + table_name="sales", + greeting="Hello!", + ) + init_time = time.perf_counter() - start + print(f" QueryChat init: {init_time:.2f}s") + + # Execute a query + start = time.perf_counter() + result = qc_eager.data_source.execute_query( + "SELECT region, SUM(value) as total FROM sales GROUP BY region" + ) + query_time = time.perf_counter() - start + print(f" Query execution: {query_time:.3f}s") + print(f" Result rows: {len(result)}") + + del df, qc_eager, result + import gc + gc.collect() + + # --- LAZY LOADING --- + print("\n2. LAZY LOADING (polars.scan_parquet → LazyFrame)") + print("-" * 50) + + if has_psutil: + mem_before = measure_memory() + + start = time.perf_counter() + lf = pl.scan_parquet(parquet_path) + load_time = time.perf_counter() - start + + if has_psutil: + mem_after = measure_memory() + print(f" Memory increase: {mem_after - mem_before:.1f} MB") + + print(f" 'Load' time: {load_time:.4f}s (just metadata!)") + + # Create QueryChat with lazy data + start = time.perf_counter() + qc_lazy = QueryChat( + data_source=lf, + table_name="sales", + greeting="Hello!", + ) + init_time = time.perf_counter() - start + print(f" QueryChat init: {init_time:.2f}s") + + # Execute the same query (stays lazy) + start = time.perf_counter() + result_lazy = qc_lazy.data_source.execute_query( + "SELECT region, SUM(value) as total FROM sales GROUP BY region" + ) + query_time = time.perf_counter() - start + print(f" Query execution (lazy): {query_time:.3f}s") + + # Now collect to get actual results + start = time.perf_counter() + result_collected = result_lazy.collect() + collect_time = time.perf_counter() - start + print(f" Collect time: {collect_time:.3f}s") + print(f" Result rows: {len(result_collected)}") + + # --- SUMMARY --- + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(""" +Key differences: +- EAGER: Loads ALL data into memory immediately +- LAZY: Only reads metadata; data stays on disk until .collect() + +Benefits of lazy: +- Much faster startup (no full data load) +- Lower memory usage (only results in memory) +- Query optimization (Polars can push down filters) + +Use lazy (scan_parquet) for: +- Large files that don't fit in memory +- When you only need filtered/aggregated subsets +- Interactive exploration of big data +""") + + +def interactive_demo(parquet_path: Path) -> None: + """Launch an interactive QueryChat session with the lazy data.""" + from querychat import QueryChat + + print("\n" + "=" * 60) + print("INTERACTIVE DEMO") + print("=" * 60) + + lf = pl.scan_parquet(parquet_path) + qc = QueryChat( + data_source=lf, + table_name="sales", + greeting="I'm connected to a large sales dataset. Ask me anything!", + ) + + print("\nLaunching interactive console...") + print("Try queries like:") + print(' - "Show me total sales by region"') + print(' - "What are the top 10 categories by quantity?"') + print(' - "Filter to just the North region"') + print("\nType 'exit' to quit.\n") + + qc.console() + + +def main(): + parser = argparse.ArgumentParser(description="Demo lazy vs eager data loading") + parser.add_argument( + "--rows", + type=int, + default=10_000_000, + help="Number of rows to generate (default: 10 million)", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Launch interactive console after comparison", + ) + parser.add_argument( + "--data-path", + type=str, + default=None, + help="Path to existing parquet file (skip generation)", + ) + args = parser.parse_args() + + # Check for API key + if not os.environ.get("OPENAI_API_KEY"): + print("Warning: OPENAI_API_KEY not set. Interactive mode won't work.") + print("Set it with: export OPENAI_API_KEY='your-key-here'\n") + + # Create or use existing data file + if args.data_path: + parquet_path = Path(args.data_path) + if not parquet_path.exists(): + print(f"Error: File not found: {parquet_path}") + return + else: + # Create temporary file + temp_dir = tempfile.mkdtemp() + parquet_path = Path(temp_dir) / "large_sales_data.parquet" + create_large_dataset(parquet_path, args.rows) + + try: + demo_eager_vs_lazy(parquet_path) + + if args.interactive: + interactive_demo(parquet_path) + finally: + # Cleanup temp file if we created it + if not args.data_path and parquet_path.exists(): + print(f"\nCleaning up temporary file: {parquet_path}") + parquet_path.unlink() + parquet_path.parent.rmdir() + + +if __name__ == "__main__": + main() diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 0e3eaa5f5..f7c64e9b9 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,10 +1,26 @@ +from ._datasource import ( + DataFrameSource, + DataSource, + IbisSource, + MissingColumnsError, + PolarsLazySource, + SQLAlchemySource, +) from ._deprecated import greeting, init, sidebar, system_prompt from ._deprecated import mod_server as server from ._deprecated import mod_ui as ui from ._shiny import QueryChat +from ._table_accessor import TableAccessor __all__ = ( + "DataFrameSource", + "DataSource", + "IbisSource", + "MissingColumnsError", + "PolarsLazySource", "QueryChat", + "SQLAlchemySource", + "TableAccessor", # TODO(lifecycle): Remove these deprecated functions when we reach v1.0 "greeting", "init", diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py index da9b57da9..1e46801e4 100644 --- a/pkg-py/src/querychat/_dash.py +++ b/pkg-py/src/querychat/_dash.py @@ -283,6 +283,8 @@ def ui( data_source, self._client_factory, self.greeting, + data_sources=dict(self._data_sources), + query_executor=self._require_query_executor("ui"), ) return html.Div( diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index ed2c8ecd4..1a5b7c120 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -307,6 +307,24 @@ def get_semantic_views_description(self) -> str: return "" +def lockdown_duckdb(conn: duckdb.DuckDBPyConnection) -> None: + """Apply security lockdown to a DuckDB connection.""" + conn.execute(""" +-- extensions: lock down supply chain + auto behaviors +SET allow_community_extensions = false; +SET allow_unsigned_extensions = false; +SET autoinstall_known_extensions = false; +SET autoload_known_extensions = false; + +-- external I/O: block file/database/network access from SQL +SET enable_external_access = false; +SET disabled_filesystems = 'LocalFileSystem'; + +-- freeze configuration so user SQL can't relax anything +SET lock_configuration = true; + """) + + class DataFrameSource(DataSource[IntoDataFrameT]): """A DataSource implementation that wraps a DataFrame using DuckDB.""" @@ -335,7 +353,7 @@ def __init__(self, df: nw.DataFrame, table_name: str): self._conn = duckdb.connect(database=":memory:") # NOTE: if native representation is polars, pyarrow is required for registration self._conn.register(table_name, self._df.to_native()) - duckdb_lock_down(self._conn) + lockdown_duckdb(self._conn) # Store original column names for validation self._colnames = list(self._df.columns) @@ -536,6 +554,11 @@ def get_db_type(self) -> str: """ return self._engine.dialect.name.upper().replace(" SQL", "") + @property + def engine(self) -> Engine: + """The SQLAlchemy engine for this data source.""" + return self._engine + def get_schema(self, *, categorical_threshold: int) -> str: """ Generate schema information from database table. @@ -1029,6 +1052,11 @@ def __init__(self, table: ibis.Table, table_name: str): def get_db_type(self) -> str: return self._backend.name + @property + def backend(self) -> SQLBackend: + """The Ibis SQL backend for this data source.""" + return self._backend + def get_schema(self, *, categorical_threshold: int) -> str: columns = [ self._make_column_meta(name, dtype) for name, dtype in self._schema.items() diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index cc0067084..9fb871235 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -251,7 +251,11 @@ def ui(self) -> gr.State: import gradio as gr initial_state = create_app_state( - data_source, self._client_factory, self.greeting + data_source, + self._client_factory, + self.greeting, + data_sources=dict(self._data_sources), + query_executor=self._require_query_executor("ui"), ) state_holder = gr.State(value=initial_state.to_dict()) @@ -368,17 +372,14 @@ def app(self) -> GradioBlocksWrapper: def update_displays(state_dict: AppStateDict): """Update SQL and data displays based on state.""" - title = state_dict.get("title") if state_dict else None - error = state_dict.get("error") if state_dict else None + state = self._deserialize_state(state_dict) + df = state.get_current_data() + title = state.title + error = state.error sql_title_text = f"### {title or 'SQL Query'}" - sql_code = ( - state_dict.get("sql") - if state_dict and state_dict.get("sql") - else f"SELECT * FROM {table_name}" - ) + sql_code = state.get_display_sql() - df = self.df(state_dict) nw_df = as_narwhals(df) nrow, ncol = nw_df.shape native_df = nw_df.to_native() diff --git a/pkg-py/src/querychat/_query_executor.py b/pkg-py/src/querychat/_query_executor.py new file mode 100644 index 000000000..e77c08494 --- /dev/null +++ b/pkg-py/src/querychat/_query_executor.py @@ -0,0 +1,265 @@ +"""QueryExecutor abstraction for cross-table query execution.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import duckdb +import narwhals.stable.v1 as nw + +from ._datasource import MissingColumnsError, lockdown_duckdb +from ._utils import check_query + +if TYPE_CHECKING: + from ._datasource import DataFrameSource, DataSource, PolarsLazySource + + +class QueryExecutor(ABC): + """Thin abstraction that tools use for query execution and validation.""" + + @abstractmethod + def execute_query(self, query: str) -> Any: ... + + @abstractmethod + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: ... + + @abstractmethod + def get_db_type(self) -> str: ... + + @abstractmethod + def cleanup(self) -> None: ... + + +class DuckDBExecutor(QueryExecutor): + """Shared DuckDB connection for multi-table DataFrameSource queries.""" + + def __init__(self, sources: dict[str, DataFrameSource]): + self._df_lib = get_shared_dataframe_backend(sources) + self._conn = duckdb.connect(database=":memory:") + + for name, source in sources.items(): + self._conn.register(name, source.get_data()) + + # Cache column names per table before lockdown + self._table_columns: dict[str, list[str]] = {} + for name in sources: + result = self._conn.execute(f"SELECT * FROM {name} LIMIT 0") + self._table_columns[name] = [desc[0] for desc in result.description] + + lockdown_duckdb(self._conn) + + def execute_query(self, query: str) -> Any: + check_query(query) + result = self._conn.execute(query) + return self._convert_result(result) + + def _convert_result(self, result: duckdb.DuckDBPyConnection) -> Any: + if self._df_lib == "polars": + return result.pl() + elif self._df_lib == "pandas": + return result.df() + elif self._df_lib == "pyarrow": + return result.fetch_arrow_table() + else: + raise ValueError( + f"Unsupported DataFrame backend: '{self._df_lib}'. " + "Supported backends are: polars, pandas, pyarrow" + ) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + check_query(query) + result = self._conn.execute(f"{query} LIMIT 1") + + if require_all_columns: + result_columns = {desc[0] for desc in result.description} + expected = set(self._table_columns[table_name]) + missing = expected - result_columns + if missing: + missing_list = ", ".join(f"'{c}'" for c in sorted(missing)) + original_list = ", ".join( + f"'{c}'" for c in self._table_columns[table_name] + ) + raise MissingColumnsError( + f"Query result missing required columns: {missing_list}. " + f"The query must return all original table columns. " + f"Original columns: {original_list}" + ) + + def get_db_type(self) -> str: + return "DuckDB" + + def cleanup(self) -> None: + if self._conn: + self._conn.close() + + +class PolarsSQLExecutor(QueryExecutor): + """Shared Polars SQLContext for multi-table PolarsLazySource queries.""" + + def __init__(self, sources: dict[str, PolarsLazySource]): + import polars as pl + + frames = {name: source.get_data() for name, source in sources.items()} + self._ctx = pl.SQLContext(frames) + + self._table_columns: dict[str, list[str]] = {} + for name, source in sources.items(): + self._table_columns[name] = list(source.get_data().collect_schema().keys()) + + def execute_query(self, query: str) -> Any: + check_query(query) + return self._ctx.execute(query) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + check_query(query) + test_lf = self._ctx.execute(f"SELECT * FROM ({query}) AS subquery LIMIT 1") + test_lf.collect() + + if require_all_columns: + full_lf = self._ctx.execute(query) + result_columns = set(full_lf.collect_schema().keys()) + expected = set(self._table_columns[table_name]) + missing = expected - result_columns + if missing: + missing_list = ", ".join(f"'{c}'" for c in sorted(missing)) + original_list = ", ".join( + f"'{c}'" for c in self._table_columns[table_name] + ) + raise MissingColumnsError( + f"Query result missing required columns: {missing_list}. " + f"The query must return all original table columns. " + f"Original columns: {original_list}" + ) + + def get_db_type(self) -> str: + return "Polars" + + def cleanup(self) -> None: + pass + + +class DataSourceExecutor(QueryExecutor): + """ + Wraps existing DataSource(s) for backends that already share a connection. + + Used for single-table mode (any source type) and multi-table SQLAlchemy/Ibis + where all sources share the same database backend. + """ + + def __init__(self, data_sources: dict[str, DataSource]): + validate_source_group_compatibility(data_sources) + self._data_sources = data_sources + self._primary = next(iter(data_sources.values())) + + def execute_query(self, query: str) -> Any: + return self._primary.execute_query(query) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + self._data_sources[table_name].test_query( + query, require_all_columns=require_all_columns + ) + + def get_db_type(self) -> str: + return self._primary.get_db_type() + + def cleanup(self) -> None: + pass + + +def get_shared_dataframe_backend(sources: dict[str, DataFrameSource]) -> str: + """Return the shared backend name, rejecting mixed DataFrameSource backends.""" + source_items = iter(sources.items()) + _, first_source = next(source_items) + shared_lib = get_dataframe_backend_name(first_source) + + for name, source in source_items: + source_lib = get_dataframe_backend_name(source) + if source_lib != shared_lib: + raise ValueError( + f"Cannot add table '{name}': all DataFrameSources must use " + f"the same DataFrame backend. " + f"Existing tables use {shared_lib}, new table uses {source_lib}." + ) + + return shared_lib + + +def validate_source_group_compatibility(data_sources: dict[str, DataSource]) -> None: + """Validate that a group of sources satisfies shared executor constraints.""" + existing: dict[str, DataSource] = {} + for name, source in data_sources.items(): + check_source_compatibility(existing, source, name) + existing[name] = source + + +def check_source_compatibility( + existing: dict[str, DataSource], + new_source: DataSource, + new_name: str, +) -> None: + """Validate that a new source is compatible with existing sources.""" + if not existing: + return + + from ._datasource import ( + DataFrameSource, + IbisSource, + SQLAlchemySource, + ) + + first_source = next(iter(existing.values())) + + if type(new_source) is not type(first_source): + raise ValueError( + f"Cannot add {type(new_source).__name__} table '{new_name}': " + f"all tables must be the same type. " + f"Existing tables use {type(first_source).__name__}." + ) + + if isinstance(new_source, DataFrameSource) and isinstance( + first_source, DataFrameSource + ): + new_lib = get_dataframe_backend_name(new_source) + existing_lib = get_dataframe_backend_name(first_source) + if new_lib != existing_lib: + raise ValueError( + f"Cannot add table '{new_name}': all DataFrameSources must use " + f"the same DataFrame backend. " + f"Existing tables use {existing_lib}, new table uses {new_lib}." + ) + + if ( + isinstance(new_source, SQLAlchemySource) + and isinstance(first_source, SQLAlchemySource) + and new_source.engine is not first_source.engine + ): + raise ValueError( + f"Cannot add table '{new_name}': all SQLAlchemy tables must " + f"share the same Engine instance." + ) + + if ( + isinstance(new_source, IbisSource) + and isinstance(first_source, IbisSource) + and new_source.backend is not first_source.backend + ): + raise ValueError( + f"Cannot add table '{new_name}': all Ibis tables must " + f"share the same backend instance." + ) + + +def get_dataframe_backend_name(source: DataFrameSource) -> str: + """Return the native eager dataframe backend name for a DataFrameSource.""" + return nw.get_native_namespace( + nw.from_native(source.get_data(), eager_only=True) + ).__name__ diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index feaf3a45d..91762e36b 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -21,11 +21,21 @@ SQLAlchemySource, ) from ._pin_source import PinSource, is_pins_board +from ._query_executor import ( + DataSourceExecutor, + DuckDBExecutor, + PolarsSQLExecutor, + QueryExecutor, + check_source_compatibility, + validate_source_group_compatibility, +) from ._querychat_core import GREETING_PROMPT from ._system_prompt import QueryChatSystemPrompt +from ._table_accessor import TableAccessor from ._utils import MISSING, MISSING_TYPE, is_ibis_table from ._viz_utils import has_viz_deps, has_viz_tool from .tools import ( + ResetDashboardCallback, UpdateDashboardData, tool_query, tool_reset_dashboard, @@ -90,6 +100,16 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) + # Multi-table storage: dict of data sources keyed by table name + self._data_sources: dict[str, DataSource] = {} + + # Track server initialization state for add/remove table validation + self._server_initialized = False + + # Store metadata for multi-table support + self._table_relationships: dict[str, dict[str, str]] = {} + self._table_descriptions: dict[str, str] = {} + self.tools = normalize_tools(tools, default=DEFAULT_TOOLS) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting @@ -113,11 +133,13 @@ def __init__( data_source, table_name ) self._table_name = self._data_source.table_name + self._data_sources[table_name] = self._data_source self._auto_fill_data_description() self._build_system_prompt() else: self._data_source = None self._system_prompt = None + self._query_executor = None def _auto_fill_data_description(self) -> None: """Auto-populate data_description from data source metadata if not user-supplied.""" @@ -130,22 +152,73 @@ def _auto_fill_data_description(self) -> None: self._data_description = desc self._data_description_mode = "inferred" - def _build_system_prompt(self) -> None: - """Build/rebuild the system prompt from current data source.""" - if self._data_source is None: + def _build_system_prompt( + self, + *, + data_sources: dict[str, DataSource] | None = None, + relationships: dict[str, dict[str, str]] | None = None, + table_descriptions: dict[str, str] | None = None, + ) -> None: + """Build/rebuild the system prompt from current or staged data sources.""" + next_data_sources = self._data_sources if data_sources is None else data_sources + + if not next_data_sources: raise RuntimeError("Cannot build system prompt without data_source") prompt_template = self._prompt_template if prompt_template is None: prompt_template = Path(__file__).parent / "prompts" / "prompt.md" - self._system_prompt = QueryChatSystemPrompt( + next_relationships = ( + self._table_relationships if relationships is None else relationships + ) + next_table_descriptions = ( + self._table_descriptions + if table_descriptions is None + else table_descriptions + ) + + replacement_prompt = QueryChatSystemPrompt( prompt_template=prompt_template, - data_source=self._data_source, + data_sources=next_data_sources, data_description=self._data_description, extra_instructions=self._extra_instructions, categorical_threshold=self._categorical_threshold, + relationships=next_relationships, + table_descriptions=next_table_descriptions, ) + replacement_executor = self._build_query_executor(data_sources=next_data_sources) + previous_executor = getattr(self, "_query_executor", None) + + self._system_prompt = replacement_prompt + self._query_executor = replacement_executor + + if previous_executor is not None: + previous_executor.cleanup() + + def _build_query_executor( + self, *, data_sources: dict[str, DataSource] | None = None + ) -> QueryExecutor: + """Build a query executor from current or staged data sources.""" + sources = self._data_sources if data_sources is None else data_sources + + validate_source_group_compatibility(sources) + + if len(sources) == 1: + return DataSourceExecutor(dict(sources)) + + first_source = next(iter(sources.values())) + + if isinstance(first_source, DataFrameSource): + return DuckDBExecutor( + {n: s for n, s in sources.items() if isinstance(s, DataFrameSource)} + ) + if isinstance(first_source, PolarsLazySource): + return PolarsSQLExecutor( + {n: s for n, s in sources.items() if isinstance(s, PolarsLazySource)} + ) + + return DataSourceExecutor(dict(sources)) def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]: """Raise if data_source is not set, otherwise return it for type narrowing.""" @@ -157,13 +230,22 @@ def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]: ) return self._data_source + def _require_query_executor(self, method_name: str) -> QueryExecutor: + """Raise if query executor is not initialized, otherwise return it.""" + if self._query_executor is None: + raise RuntimeError( + f"query executor must be set before calling {method_name}(). " + "Set the data_source first so querychat can build an executor." + ) + return self._query_executor + def _create_session_client( self, *, client_spec: str | chatlas.Chat | None | MISSING_TYPE = MISSING, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, - reset_dashboard: Callable[[], None] | None = None, + reset_dashboard: ResetDashboardCallback | None = None, visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """Create a fresh, fully-configured Chat.""" @@ -178,20 +260,30 @@ def _create_session_client( if resolved_tools is None: return chat - data_source = self._require_data_source("_create_session_client") + self._require_data_source("_create_session_client") + assert self._query_executor is not None # noqa: S101 if "update" in resolved_tools: update_fn = update_dashboard or (lambda _: None) - reset_fn = reset_dashboard or (lambda: None) - chat.register_tool(tool_update_dashboard(data_source, update_fn)) - chat.register_tool(tool_reset_dashboard(reset_fn)) + user_reset = reset_dashboard or (lambda _table: None) + + chat.register_tool( + tool_update_dashboard( + self._query_executor, + list(self._data_sources.keys()), + update_fn, + ) + ) + chat.register_tool( + tool_reset_dashboard(user_reset, list(self._data_sources.keys())) + ) if "query" in resolved_tools: - chat.register_tool(tool_query(data_source)) + chat.register_tool(tool_query(self._query_executor)) if "visualize" in resolved_tools: viz_fn = visualize or (lambda _: None) - chat.register_tool(tool_visualize(data_source, viz_fn)) + chat.register_tool(tool_visualize(self._query_executor, viz_fn)) return chat @@ -200,7 +292,7 @@ def client( *, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, - reset_dashboard: Callable[[], None] | None = None, + reset_dashboard: ResetDashboardCallback | None = None, visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """ @@ -265,25 +357,221 @@ def system_prompt(self) -> str: @property def data_source(self) -> DataSource | None: - """Get the current data source.""" - return self._data_source + """ + Get the data source (for single-table backwards compatibility). + + Returns None if no data source is set. Raises ValueError if multiple + tables are present - use .table("name").data_source instead. + """ + if not self._data_sources: + return None + if len(self._data_sources) == 1: + return next(iter(self._data_sources.values())) + raise ValueError( + f"Multiple tables present ({', '.join(self._data_sources.keys())}). " + "Use qc.table('name').data_source instead." + ) @data_source.setter def data_source(self, value: IntoFrame | sqlalchemy.Engine | BaseBoard) -> None: """Set the data source, normalizing and rebuilding system prompt.""" - old_source = self._data_source - if self._table_name is None: - raise ValueError("table_name must be set before assigning a data source") - self._data_source = normalize_data_source(value, self._table_name) - if old_source is not None and old_source is not self._data_source: - old_source.cleanup() - self._auto_fill_data_description() - self._build_system_prompt() + normalized = normalize_data_source(value, self._table_name) + try: + other_sources = { + name: source + for name, source in self._data_sources.items() + if name != self._table_name + } + check_source_compatibility(other_sources, normalized, self._table_name) + + next_data_sources = dict(self._data_sources) + next_data_sources[self._table_name] = normalized + self._data_source = normalized + self._auto_fill_data_description() + self._build_system_prompt(data_sources=next_data_sources) + except Exception: + cleanup_failed_staged_source(value, normalized) + raise + + self._data_sources = next_data_sources + + def table_names(self) -> list[str]: + """ + Return the names of all registered tables. + + Returns + ------- + list[str] + List of table names in the order they were added. + + """ + return list(self._data_sources.keys()) + + def table(self, name: str) -> TableAccessor: + """ + Get an accessor for a specific table. + + Parameters + ---------- + name + The name of the table to access. + + Returns + ------- + TableAccessor + An accessor object with df(), sql(), title() methods. + + Raises + ------ + ValueError + If the table doesn't exist. + + """ + if name not in self._data_sources: + available = ", ".join(self._data_sources.keys()) + raise ValueError(f"Table '{name}' not found. Available: {available}") + + return TableAccessor(self, name) + + def add_table( + self, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, + *, + relationships: dict[str, str] | None = None, + description: str | None = None, + ) -> None: + """ + Add an additional table to the QueryChat instance. + + Parameters + ---------- + data_source + The data source (DataFrame, LazyFrame, or database connection). + table_name + Name for the table (must be unique within this QueryChat). + relationships + Optional dict mapping local columns to "other_table.column" for JOINs. + Example: {"customer_id": "customers.id"} + description + Optional free-text description of the table for the LLM. + + Raises + ------ + ValueError + If table_name already exists or is invalid. + RuntimeError + If called after server() has been invoked. + + """ + # Check if server already initialized + if self._server_initialized: + raise RuntimeError( + "Cannot add tables after server initialization. " + "Add all tables before calling .server() or .app()." + ) + + # Validate table name format + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): + raise ValueError( + "Table name must begin with a letter and contain only " + "letters, numbers, and underscores" + ) + + # Check for duplicates + if table_name in self._data_sources: + raise ValueError(f"Table '{table_name}' already exists") + + # Normalize and validate compatibility with existing sources + normalized = normalize_data_source(data_source, table_name) + try: + check_source_compatibility(self._data_sources, normalized, table_name) + next_data_sources = dict(self._data_sources) + next_data_sources[table_name] = normalized + next_relationships = dict(self._table_relationships) + next_table_descriptions = dict(self._table_descriptions) + + # Store relationship and description metadata + if relationships: + next_relationships[table_name] = relationships + if description: + next_table_descriptions[table_name] = description + + # Rebuild system prompt with new table + self._build_system_prompt( + data_sources=next_data_sources, + relationships=next_relationships, + table_descriptions=next_table_descriptions, + ) + except Exception: + cleanup_failed_staged_source(data_source, normalized) + raise + + self._data_sources = next_data_sources + self._table_relationships = next_relationships + self._table_descriptions = next_table_descriptions + + def remove_table(self, table_name: str) -> None: + """ + Remove a table from the QueryChat instance. + + Parameters + ---------- + table_name + Name of the table to remove. + + Raises + ------ + ValueError + If table doesn't exist or is the last remaining table. + RuntimeError + If called after server() has been invoked. + + """ + if self._server_initialized: + raise RuntimeError( + "Cannot remove tables after server initialization. " + "Configure all tables before calling .server() or .app()." + ) + + if table_name not in self._data_sources: + available = ", ".join(self._data_sources.keys()) + raise ValueError(f"Table '{table_name}' not found. Available: {available}") + + if len(self._data_sources) == 1: + raise ValueError( + "Cannot remove last table. At least one table is required." + ) + + removed_source = self._data_sources[table_name] + next_data_sources = dict(self._data_sources) + del next_data_sources[table_name] + next_relationships = dict(self._table_relationships) + next_relationships.pop(table_name, None) + next_table_descriptions = dict(self._table_descriptions) + next_table_descriptions.pop(table_name, None) + + # Rebuild system prompt without removed table + self._build_system_prompt( + data_sources=next_data_sources, + relationships=next_relationships, + table_descriptions=next_table_descriptions, + ) + self._data_sources = next_data_sources + self._table_relationships = next_relationships + self._table_descriptions = next_table_descriptions + removed_source.cleanup() + + def _mark_server_initialized(self) -> None: + """Mark that the server has been initialized. Prevents add/remove_table.""" + self._server_initialized = True def cleanup(self) -> None: - """Clean up resources associated with the data source.""" - if self._data_source is not None: - self._data_source.cleanup() + """Clean up resources associated with all data sources.""" + if hasattr(self, "_query_executor") and self._query_executor is not None: + self._query_executor.cleanup() + for source in self._data_sources.values(): + source.cleanup() def normalize_data_source( @@ -330,6 +618,24 @@ def normalize_data_source( ) +def cleanup_failed_staged_source( + original_source: IntoFrame | sqlalchemy.Engine | DataSource, + normalized_source: DataSource, +) -> None: + """ + Clean up transient resources created during a failed staged rebuild. + + Only DataFrameSource owns a fresh disposable connection created during + normalization. SQLAlchemySource wraps a caller-owned engine, while + PolarsLazySource and IbisSource do not allocate disposable resources here. + """ + if isinstance(original_source, (DataSource, sqlalchemy.Engine)): + return + + if isinstance(normalized_source, DataFrameSource): + normalized_source.cleanup() + + def create_client(client: str | chatlas.Chat | None) -> chatlas.Chat: """Resolve a client spec into a fresh Chat with no conversation history.""" if client is None: diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index fb3134beb..2888aeed8 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -14,8 +14,8 @@ ] from collections.abc import Callable -from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Optional, TypedDict, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Generic, NotRequired, Optional, TypedDict, Union from chatlas import Chat, ContentToolRequest, ContentToolResult from chatlas.types import Content @@ -36,10 +36,11 @@ from narwhals.stable.v1.typing import IntoFrame from ._datasource import DataSource + from ._query_executor import QueryExecutor ClientFactory = Callable[ - [Callable[[UpdateDashboardData], None], Callable[[], None]], + [Callable[[UpdateDashboardData], None], Callable[[str], None]], Chat, ] """Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks.""" @@ -48,6 +49,7 @@ class AppStateDict(TypedDict): """Serialized AppState for framework state stores.""" + table: NotRequired[str | None] sql: str | None title: str | None error: str | None @@ -65,11 +67,13 @@ class StateDictAccessorMixin(Generic[IntoFrameT]): """Mixin providing df/sql/title accessors for frameworks using serialized state dicts.""" _data_source: DataSource[IntoFrameT] | None + _data_sources: dict[str, DataSource[IntoFrameT]] + _query_executor: QueryExecutor | None def _client_factory( self, update_cb: Callable[[UpdateDashboardData], None], - reset_cb: Callable[[], None], + reset_cb: Callable[[str], None], ) -> Chat: """Create a chat client with dashboard callbacks.""" return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined] @@ -90,15 +94,33 @@ def df(self, state: AppStateDict | None) -> IntoFrameT: Returns a LazyFrame if the data source is lazy. """ - data_source = self._require_data_source("df") # type: ignore[attr-defined] + data_source = self._get_state_data_source(state) # type: ignore[attr-defined] sql = state.get("sql") if state else None if sql: try: - return data_source.execute_query(sql) + query_executor = self._require_query_executor("df") # type: ignore[attr-defined] + return query_executor.execute_query(sql) except Exception: return data_source.get_data() return data_source.get_data() + def _get_state_data_source( + self, state: AppStateDict | None + ) -> DataSource[IntoFrameT]: + """Resolve the full-data source for a serialized state payload.""" + default_source = self._require_data_source("_get_state_data_source") # type: ignore[attr-defined] + if not state: + return default_source + + table_name = state.get("table") + if table_name is None: + return default_source + + if table_name in self._data_sources: + return self._data_sources[table_name] + + return default_source + def sql(self, state: AppStateDict | None) -> str | None: """ Get the current SQL query from state. @@ -140,6 +162,8 @@ def _deserialize_state(self, state_data: AppStateDict | None) -> AppState: data_source, self._client_factory, self.greeting, # type: ignore[attr-defined] + data_sources=dict(self._data_sources), # type: ignore[attr-defined] + query_executor=self._require_query_executor("_deserialize_state"), # type: ignore[attr-defined] ) if state_data: state.update_from_dict(state_data) @@ -201,39 +225,59 @@ class AppState: data_source: DataSource client: Chat + query_executor: QueryExecutor | None = None + data_sources: dict[str, DataSource] = field(default_factory=dict) greeting: Optional[str] = None + active_table: str | None = None sql: Optional[str] = None title: Optional[str] = None error: Optional[str] = None + def __post_init__(self) -> None: + if not self.data_sources: + self.data_sources = {self.data_source.table_name: self.data_source} + if self.active_table is None: + self.active_table = self.data_source.table_name + def update_dashboard(self, data: UpdateDashboardData) -> None: + self.active_table = data["table"] self.sql = data["query"] self.title = data["title"] self.error = None # Clear any previous error on successful update - def reset_dashboard(self) -> None: + def reset_dashboard(self, table: str | None = None) -> None: + if table is not None: + self.active_table = table self.sql = None self.title = None self.error = None + def get_active_data_source(self) -> DataSource: + """Return the current full-data source for the active table.""" + if self.active_table is None: + return self.data_source + return self.data_sources.get(self.active_table, self.data_source) + def get_current_data(self) -> IntoFrame: """Get current data, falling back to default if query fails.""" + data_source = self.get_active_data_source() if self.sql: try: - result = self.data_source.execute_query(self.sql) + query_runner = self.query_executor or data_source + result = query_runner.execute_query(self.sql) self.error = None # Clear error on success return result except Exception as e: self.error = format_query_error(e) self.sql = None self.title = None - return self.data_source.get_data() - self.error = None - return self.data_source.get_data() + return data_source.get_data() + return data_source.get_data() def get_display_sql(self) -> str: - return self.sql or f"SELECT * FROM {self.data_source.table_name}" + table_name = self.active_table or self.data_source.table_name + return self.sql or f"SELECT * FROM {table_name}" def get_display_messages(self) -> list[DisplayMessage]: """ @@ -280,6 +324,7 @@ def initialize_greeting_if_preset(self) -> bool: def to_dict(self) -> AppStateDict: """Serialize state to dict for framework state stores.""" return { + "table": self.active_table, "sql": self.sql, "title": self.title, "error": self.error, @@ -290,6 +335,7 @@ def update_from_dict(self, data: AppStateDict) -> None: """Restore state from serialized dict.""" from chatlas import Turn + self.active_table = data.get("table", self.data_source.table_name) self.sql = data["sql"] self.title = data["title"] self.error = data["error"] @@ -303,6 +349,9 @@ def create_app_state( data_source: DataSource, client_factory: ClientFactory, greeting: Optional[str] = None, + *, + data_sources: dict[str, DataSource] | None = None, + query_executor: QueryExecutor | None = None, ) -> AppState: """Create AppState with callbacks connected via holder pattern.""" state_holder: dict[str, AppState | None] = {"state": None} @@ -313,16 +362,22 @@ def update_callback(data: UpdateDashboardData) -> None: raise RuntimeError("Callback invoked before state initialization") state.update_dashboard(data) - def reset_callback() -> None: + def reset_callback(_table: str) -> None: state = state_holder["state"] if state is None: raise RuntimeError("Callback invoked before state initialization") - state.reset_dashboard() + state.reset_dashboard(_table) client = client_factory(update_callback, reset_callback) state = AppState( data_source=data_source, client=client, + query_executor=query_executor, + data_sources=( + dict(data_sources) + if data_sources is not None + else {data_source.table_name: data_source} + ), greeting=greeting, ) state_holder["state"] = state diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index 9f05132db..b139d4309 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -300,9 +300,11 @@ def app_ui(request): ) def app_server(input: Inputs, output: Outputs, session: Session): + self._mark_server_initialized() vals = mod_server( self.id, - data_source=data_source, + data_sources=dict(self._data_sources), + executor=self._query_executor, greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable_bookmarking, @@ -498,7 +500,7 @@ def title(): if data_source is not None: self.data_source = data_source - resolved_data_source = self._require_data_source("server") + self._require_data_source("server") resolved_client_spec = self._client_spec if isinstance(client, MISSING_TYPE) else client def create_session_client(**kwargs) -> chatlas.Chat: @@ -506,9 +508,11 @@ def create_session_client(**kwargs) -> chatlas.Chat: client_spec=resolved_client_spec, **kwargs ) + self._mark_server_initialized() return mod_server( id or self.id, - data_source=resolved_data_source, + data_sources=dict(self._data_sources), + executor=self._query_executor, greeting=self.greeting, client=create_session_client, enable_bookmarking=enable_bookmarking, @@ -751,9 +755,15 @@ def __init__( else: enable = enable_bookmarking + if self._data_source is None and isinstance(session, ExpressStubSession): + return + + self._require_data_source("__init__") + self._mark_server_initialized() self._vals = mod_server( self.id, - data_source=self._data_source, + data_sources=dict(self._data_sources), + executor=self._query_executor, greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable, diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 73683ad28..bfed76807 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -24,6 +24,8 @@ from shiny import Inputs, Outputs, Session from ._datasource import DataSource + from ._query_executor import QueryExecutor + from ._querychat_base import TOOL_GROUPS from ._viz_tools import VisualizeData from .types import UpdateDashboardData @@ -55,6 +57,35 @@ def __getattr__(self, _name: str): ServerClient = chatlas.Chat | _DeferredStubChatClient +@dataclass +class TableState(Generic[IntoFrameT]): + """Per-table reactive state.""" + + sql: ReactiveStringOrNone + title: ReactiveStringOrNone + df: Callable[[], IntoFrameT] + + +class _MultiTableGuard: + """Raises ValueError when accessed, guiding users to per-table API.""" + + def __init__(self, field: str, table_names: list[str]): + names = ", ".join(f"'{n}'" for n in table_names) + self._msg = ( + f"Multiple tables present ({names}). " + f"Use qc.table('name').{field}() instead." + ) + + def __call__(self, *args: object, **kwargs: object) -> None: # noqa: ARG002 + raise ValueError(self._msg) + + def set(self, value: object) -> None: # noqa: ARG002 + raise ValueError(self._msg) + + def get(self) -> None: + raise ValueError(self._msg) + + @module.ui def mod_ui(*, preload_viz: bool = False, **kwargs): css_path = Path(__file__).parent / "static" / "css" / "styles.css" @@ -99,6 +130,10 @@ class ServerValues(Generic[IntoFrameT]): provides this title when generating a new SQL query. Access it with `.title()`, or set it with `.title.set("...")`. Returns `None` if no title has been set. + tables + Per-table reactive state. Keys are table names. Each value is a + `TableState` with `sql`, `title`, and `df` attributes. Always populated, + even for single-table usage. client Session chat client value. For real sessions this is a `chatlas.Chat` created by the client @@ -110,6 +145,7 @@ class ServerValues(Generic[IntoFrameT]): df: Callable[[], IntoFrameT] sql: ReactiveStringOrNone title: ReactiveStringOrNone + tables: dict[str, TableState[IntoFrameT]] client: ServerClient @@ -119,27 +155,45 @@ def mod_server( output: Outputs, session: Session, *, - data_source: DataSource[IntoFrameT] | None, + data_sources: dict[str, DataSource[IntoFrameT]] | None, + executor: QueryExecutor | None, greeting: str | None, client: Callable[..., chatlas.Chat], enable_bookmarking: bool, tools: set[str] | None = None, ) -> ServerValues[IntoFrameT]: - # Reactive values to store state - sql = ReactiveStringOrNone(None) - title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 if not callable(client): raise TypeError("mod_server() requires a callable client factory.") + table_states: dict[str, TableState[IntoFrameT]] = {} + + def _make_table_state( + source: DataSource[IntoFrameT], exec: QueryExecutor + ) -> TableState[IntoFrameT]: + table_sql = ReactiveStringOrNone(None) + table_title = ReactiveStringOrNone(None) + + @reactive.calc + def filtered_df() -> IntoFrameT: + query = table_sql.get() + if query: + return exec.execute_query(query) + return source.get_data() + + return TableState(sql=table_sql, title=table_title, df=filtered_df) + def update_dashboard(data: UpdateDashboardData): - sql.set(data["query"]) - title.set(data["title"]) + table_name = data["table"] + if table_name in table_states: + table_states[table_name].sql.set(data["query"]) + table_states[table_name].title.set(data["title"]) - def reset_dashboard(): - sql.set(None) - title.set(None) + def reset_dashboard(table_name: str): + if table_name in table_states: + table_states[table_name].sql.set(None) + table_states[table_name].title.set(None) viz_widgets: list[VizWidgetEntry] = [] @@ -155,43 +209,40 @@ def build_chat_client() -> chatlas.Chat: ) # Short-circuit for stub sessions (e.g. 1st run of an Express app) - # data_source may be None during stub session for deferred pattern + # data_sources may be None during stub session for deferred pattern if session.is_stub_session(): # Mock the error that would otherwise occur in a real session def _stub_df(): raise RuntimeError("RuntimeError: No current reactive context") stub_client = ( - _DeferredStubChatClient() if data_source is None else build_chat_client() + _DeferredStubChatClient() if data_sources is None else build_chat_client() ) return ServerValues( df=_stub_df, - sql=sql, - title=title, + sql=ReactiveStringOrNone(None), + title=ReactiveStringOrNone(None), + tables={}, client=stub_client, ) - # Real session requires data_source - if data_source is None: + # Real session requires data_sources and executor + if data_sources is None or executor is None: raise RuntimeError( "data_source must be set before the real session. " "Set it via the data_source property before users connect." ) + for name, source in data_sources.items(): + table_states[name] = _make_table_state(source, executor) + # Build the session-specific chat client through QueryChat.client(...). chat = build_chat_client() if has_viz_tool(tools): preload_viz_deps_server() - # Execute query when SQL changes - @reactive.calc - def filtered_df(): - query = sql.get() - df = data_source.get_data() if not query else data_source.execute_query(query) - return df - # Chat UI logic chat_ui = shinychat.Chat(CHAT_ID) ctrl = chatlas.StreamController() @@ -232,17 +283,16 @@ async def greet_on_startup(): @reactive.event(input.chat_update) def _(): update = input.chat_update() - if update is None: - return - if not isinstance(update, dict): + if update is None or not isinstance(update, dict): return - + table_name = update.get("table", "") new_query = update.get("query") new_title = update.get("title") - if new_query is not None: - sql.set(new_query) - if new_title is not None: - title.set(new_title) + if table_name and table_name in table_states: + if new_query is not None: + table_states[table_name].sql.set(new_query) + if new_title is not None: + table_states[table_name].title.set(new_title) if enable_bookmarking: chat_ui.enable_bookmarking(chat) @@ -250,28 +300,50 @@ def _(): @session.bookmark.on_bookmark def _on_bookmark(x: BookmarkState) -> None: vals = x.values - vals["querychat_sql"] = sql.get() - vals["querychat_title"] = title.get() vals["querychat_has_greeted"] = has_greeted.get() + for name, state in table_states.items(): + vals[f"querychat_sql_{name}"] = state.sql.get() + vals[f"querychat_title_{name}"] = state.title.get() if viz_widgets: vals["querychat_viz_widgets"] = viz_widgets @session.bookmark.on_restore def _on_restore(x: RestoreState) -> None: vals = x.values - if "querychat_sql" in vals: - sql.set(vals["querychat_sql"]) - if "querychat_title" in vals: - title.set(vals["querychat_title"]) if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) + for name, state in table_states.items(): + if f"querychat_sql_{name}" in vals: + state.sql.set(vals[f"querychat_sql_{name}"]) + if f"querychat_title_{name}" in vals: + state.title.set(vals[f"querychat_title_{name}"]) if "querychat_viz_widgets" in vals: restored = restore_viz_widgets( - data_source, vals["querychat_viz_widgets"] + executor, vals["querychat_viz_widgets"] ) viz_widgets[:] = restored - return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) + # Build return value with backward-compatible flat fields + table_names = list(data_sources.keys()) + is_multi = len(table_names) > 1 + + if is_multi: + return ServerValues( + df=_MultiTableGuard("df", table_names), # type: ignore[arg-type] + sql=_MultiTableGuard("sql", table_names), # type: ignore[arg-type] + title=_MultiTableGuard("title", table_names), # type: ignore[arg-type] + tables=table_states, + client=chat, + ) + else: + first_state = next(iter(table_states.values())) + return ServerValues( + df=first_state.df, + sql=first_state.sql, + title=first_state.title, + tables=table_states, + client=chat, + ) class GreetWarning(Warning): @@ -279,7 +351,7 @@ class GreetWarning(Warning): def restore_viz_widgets( - data_source: DataSource[IntoFrameT], + executor: QueryExecutor, saved_widgets: list[VizWidgetEntry], ) -> list[VizWidgetEntry]: """Re-execute ggsql queries, register widgets, and return restored entries.""" @@ -293,7 +365,7 @@ def restore_viz_widgets( ggsql_str = entry["ggsql"] try: validated = validate(ggsql_str) - spec = execute_ggsql(data_source, validated) + spec = execute_ggsql(executor, validated) altair_widget = AltairWidget.from_ggsql(spec, widget_id=widget_id) register_widget(widget_id, altair_widget.widget) restored.append(entry) diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index b68a6effc..b6834de66 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -172,6 +172,8 @@ def _get_state(self) -> AppState: reset_dashboard=reset_cb, ), self.greeting, + data_sources=dict(self._data_sources), + query_executor=self._require_query_executor("_get_state"), ) return st.session_state[self._state_key] diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index f690a0696..d3db65e13 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -20,22 +20,37 @@ class QueryChatSystemPrompt: def __init__( self, prompt_template: str | Path, - data_source: DataSource, + data_source: DataSource | None = None, + data_sources: dict[str, DataSource] | None = None, data_description: str | Path | None = None, extra_instructions: str | Path | None = None, categorical_threshold: int = 10, + relationships: dict[str, dict[str, str]] | None = None, + table_descriptions: dict[str, str] | None = None, ): """ Initialize with prompt components. Args: prompt_template: Mustache template string or path to template file - data_source: DataSource instance for schema generation + data_source: Single DataSource instance (backwards compatibility) + data_sources: Dictionary of DataSource instances keyed by table name data_description: Optional data context (string or path) extra_instructions: Optional custom LLM instructions (string or path) categorical_threshold: Threshold for categorical column detection + relationships: Optional dict mapping table.column to foreign table.column + table_descriptions: Optional dict mapping table names to descriptions """ + # Handle both single source (backwards compat) and dict of sources + if data_sources is not None: + self._data_sources = data_sources + elif data_source is not None: + self._data_sources = {data_source.table_name: data_source} + else: + raise ValueError("Either data_source or data_sources must be provided") + + # Load template if isinstance(prompt_template, Path): self.template = prompt_template.read_text() else: @@ -51,15 +66,36 @@ def __init__( else: self.extra_instructions = extra_instructions + self.categorical_threshold = categorical_threshold + self._relationships = relationships or {} + self._table_descriptions = table_descriptions or {} + + # Generate combined schema (skip if template doesn't reference it) if _SCHEMA_TAG_RE.search(self.template): - self.schema = data_source.get_schema( - categorical_threshold=categorical_threshold - ) + self.schema = self._generate_combined_schema() else: self.schema = "" - self.categorical_threshold = categorical_threshold - self.data_source = data_source + def _generate_combined_schema(self) -> str: + """Generate schema string for all tables.""" + schemas = [] + for name, source in self._data_sources.items(): + schema = source.get_schema(categorical_threshold=self.categorical_threshold) + schemas.append(f'\n{schema}\n
') + + return "\n\n".join(schemas) + + def _generate_relationships_text(self) -> str: + """Generate relationship information text.""" + if not self._relationships: + return "" + + lines = [] + for table, rels in self._relationships.items(): + for local_col, foreign_ref in rels.items(): + lines.append(f"- {table}.{local_col} references {foreign_ref}") + + return "\n".join(lines) def render(self, tools: set[str] | None) -> str: """ @@ -72,13 +108,14 @@ def render(self, tools: set[str] | None) -> str: Fully rendered system prompt string """ - db_type = self.data_source.get_db_type() + first_source = next(iter(self._data_sources.values())) + db_type = first_source.get_db_type() is_duck_db = db_type.lower() == "duckdb" context = { "db_type": db_type, "is_duck_db": is_duck_db, - "semantic_views": self.data_source.get_semantic_views_description(), + "semantic_views": first_source.get_semantic_views_description(), "schema": self.schema, "data_description": self.data_description, "extra_instructions": self.extra_instructions, @@ -86,6 +123,8 @@ def render(self, tools: set[str] | None) -> str: "has_tool_query": "query" in tools if tools else False, "has_tool_visualize": has_viz_tool(tools), "include_query_guidelines": len(tools or ()) > 0, + "relationships": self._generate_relationships_text(), + "multi_table": len(self._data_sources) > 1, } prompts_dir = str(Path(__file__).parent / "prompts") @@ -95,3 +134,10 @@ def render(self, tools: set[str] | None) -> str: partials_path=prompts_dir, partials_ext="md", ) + + @property + def data_source(self) -> DataSource: + """Return single data source for backwards compatibility.""" + if len(self._data_sources) == 1: + return next(iter(self._data_sources.values())) + raise ValueError("Multiple data sources present; use _data_sources instead") diff --git a/pkg-py/src/querychat/_table_accessor.py b/pkg-py/src/querychat/_table_accessor.py new file mode 100644 index 000000000..f6f1ad056 --- /dev/null +++ b/pkg-py/src/querychat/_table_accessor.py @@ -0,0 +1,113 @@ +"""TableAccessor class for accessing per-table state and data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from shiny import ui + + from ._datasource import DataSource + from ._querychat_base import QueryChatBase + from ._shiny_module import ServerValues + + +class TableAccessor: + """ + Accessor for a specific table's state and data. + + This class provides access to per-table data source and (when server is initialized) + reactive state. It is returned by QueryChat.table("name"). + + Parameters + ---------- + querychat + The parent QueryChat instance. + table_name + The name of the table this accessor represents. + + """ + + def __init__(self, querychat: QueryChatBase, table_name: str): + self._querychat = querychat + self._table_name = table_name + + @property + def table_name(self) -> str: + """The name of this table.""" + return self._table_name + + @property + def data_source(self) -> DataSource: + """The data source for this table.""" + return self._querychat._data_sources[self._table_name] + + def _require_server_values(self) -> ServerValues[Any]: + """Return typed per-session state after verifying server initialization.""" + vals = getattr(self._querychat, "_vals", None) + if vals is None: + raise RuntimeError("Server not initialized. Call .server() first.") + return cast("ServerValues[Any]", vals) + + def df(self) -> Any: + """ + Return the current filtered data for this table (reactive). + + Returns the native DataFrame type (polars, pandas, ibis.Table, etc.) + for this table's data source. + + Raises + ------ + RuntimeError + If called before server initialization. + + """ + return self._require_server_values().tables[self._table_name].df() + + def sql(self) -> str | None: + """ + Return the current SQL filter for this table (reactive). + + Raises + ------ + RuntimeError + If called before server initialization. + + """ + return self._require_server_values().tables[self._table_name].sql.get() + + def title(self) -> str | None: + """ + Return the current filter title for this table (reactive). + + Raises + ------ + RuntimeError + If called before server initialization. + + """ + return self._require_server_values().tables[self._table_name].title.get() + + def ui(self) -> ui.Tag: + """ + Render the UI for this table (data table + SQL display). + + Returns + ------- + Tag + A Shiny UI element containing the data table and SQL display. + + """ + from shiny import ui as shiny_ui + + querychat_id = getattr(self._querychat, "id", None) + if not isinstance(querychat_id, str): + raise RuntimeError("QueryChat instance is missing an id.") + + table_id = f"{querychat_id}_{self._table_name}" + + return shiny_ui.card( + shiny_ui.card_header(self._table_name), + shiny_ui.output_data_frame(f"{table_id}_dt"), + shiny_ui.output_text(f"{table_id}_sql"), + ) diff --git a/pkg-py/src/querychat/_viz_ggsql.py b/pkg-py/src/querychat/_viz_ggsql.py index 076b4f3b3..9e166c151 100644 --- a/pkg-py/src/querychat/_viz_ggsql.py +++ b/pkg-py/src/querychat/_viz_ggsql.py @@ -10,20 +10,20 @@ if TYPE_CHECKING: import ggsql - from ._datasource import DataSource + from ._query_executor import QueryExecutor -def execute_ggsql(data_source: DataSource, validated: ggsql.Validated) -> ggsql.Spec: +def execute_ggsql(executor: QueryExecutor, validated: ggsql.Validated) -> ggsql.Spec: """ - Execute a pre-validated ggsql query against a DataSource, returning a Spec. + Execute a pre-validated ggsql query against a QueryExecutor, returning a Spec. - Executes the SQL portion through DataSource (preserving database pushdown), + Executes the SQL portion through the executor (preserving database pushdown), then feeds the result into a ggsql DuckDBReader to produce a Spec. Parameters ---------- - data_source - The querychat DataSource to execute the SQL portion against. + executor + The querychat QueryExecutor to execute the SQL portion against. validated A pre-validated ggsql query (from ``ggsql.validate()``). @@ -47,7 +47,7 @@ def execute_ggsql(data_source: DataSource, validated: ggsql.Validated) -> ggsql. "result." ) - pl_df = to_polars(data_source.execute_query(validated.sql())) + pl_df = to_polars(executor.execute_query(validated.sql())) reader = DuckDBReader("duckdb://memory") table = extract_visualise_table(visual) diff --git a/pkg-py/src/querychat/_viz_tools.py b/pkg-py/src/querychat/_viz_tools.py index 8000aa461..a4664e3d8 100644 --- a/pkg-py/src/querychat/_viz_tools.py +++ b/pkg-py/src/querychat/_viz_tools.py @@ -27,7 +27,7 @@ import altair as alt from ipywidgets.widgets.widget import Widget - from ._datasource import DataSource + from ._query_executor import QueryExecutor class VisualizeData(TypedDict): @@ -55,7 +55,7 @@ class VisualizeData(TypedDict): def tool_visualize( - data_source: DataSource, + executor: QueryExecutor, update_fn: Callable[[VisualizeData], None], ) -> Tool: """ @@ -63,8 +63,8 @@ def tool_visualize( Parameters ---------- - data_source - The data source to query against + executor + The query executor to query against update_fn Callback function to call with VisualizeData when visualization succeeds @@ -74,10 +74,10 @@ def tool_visualize( A tool that can be registered with chatlas """ - impl = visualize_impl(data_source, update_fn) + impl = visualize_impl(executor, update_fn) impl.__doc__ = read_prompt_template( "tool-visualize.md", - db_type=data_source.get_db_type(), + db_type=executor.get_db_type(), ) return Tool.from_func( @@ -146,7 +146,7 @@ def __init__( def visualize_impl( - data_source: DataSource, + executor: QueryExecutor, update_fn: Callable[[VisualizeData], None], ) -> Callable[[str, str], ContentToolResult]: """Create the visualize implementation function.""" @@ -172,7 +172,7 @@ def visualize( "\n".join(error["message"] for error in validated.errors()) ) - spec = execute_ggsql(data_source, validated) + spec = execute_ggsql(executor, validated) raw_chart = VegaLiteWriter().render_chart(spec) altair_widget = AltairWidget(copy.deepcopy(raw_chart)) diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 00272cada..e44cc21a8 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -3,7 +3,7 @@ You are a data dashboard chatbot that operates in a sidebar interface. Your role You have access to a {{db_type}} SQL database with the following schema: -{{schema}} +{{{schema}}} {{#data_description}} @@ -14,7 +14,15 @@ Here is additional information about the data: {{/data_description}} -For security reasons, you may only query this specific table. +{{#relationships}} + +{{{relationships}}} + + +When answering questions that span multiple tables, use JOINs based on these relationships. +{{/relationships}} + +For security reasons, you may only query {{#relationships}}these specific tables{{/relationships}}{{^relationships}}this specific table{{/relationships}}. {{#include_query_guidelines}} ## SQL Query Guidelines @@ -82,18 +90,19 @@ You can handle these types of requests: When the user asks you to filter or sort the dashboard, e.g. "Show me..." or "Which ____ have the highest ____?" or "Filter to only include ____": - Write a {{db_type}} SQL SELECT query -- Call `querychat_update_dashboard` with the query and a descriptive title -- The query MUST return all columns from the schema (you can use `SELECT *`) +- Call `querychat_update_dashboard` with the query, table name, and a descriptive title +- You MUST specify the `table` parameter to indicate which table to filter +- The query MUST return all columns from the specified table's schema (you can use `SELECT *`) - Use a single SQL query even if complex (subqueries and CTEs are fine) - Optimize for **readability over efficiency** - Include SQL comments to explain complex logic - No confirmation messages are needed: the user will see your query in the dashboard. -The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `querychat_reset_dashboard()`. +The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `querychat_reset_dashboard` with the relevant `table`. **Filtering Example:** User: "Show only rows where sales are above average" -Tool Call: `querychat_update_dashboard({query: "SELECT * FROM table WHERE sales > (SELECT AVG(sales) FROM table)", title: "Above average sales"})` +Tool Call: `querychat_update_dashboard({query: "SELECT * FROM sales_data WHERE sales > (SELECT AVG(sales) FROM sales_data)", table: "sales_data", title: "Above average sales"})` Response: "" No further response needed, the user will see the updated dashboard. diff --git a/pkg-py/src/querychat/prompts/tool-query.md b/pkg-py/src/querychat/prompts/tool-query.md index 65bc7d899..bd8954738 100644 --- a/pkg-py/src/querychat/prompts/tool-query.md +++ b/pkg-py/src/querychat/prompts/tool-query.md @@ -25,6 +25,11 @@ Always use SQL for counting, averaging, summing, and other calculations—NEVER - When using `collapsed=false`, avoid duplicating the same rows/values in both the tool result and your response text - Do not reproduce large result sets in your response — summarize the key takeaways instead +{{#multi_table}} + +**Multi-table queries:** Your schema includes multiple tables. You can reference any table in your queries and use JOINs when the data spans tables. Use the relationships described in the schema to determine join conditions. + +{{/multi_table}} Parameters ---------- query : diff --git a/pkg-py/src/querychat/prompts/tool-reset-dashboard.md b/pkg-py/src/querychat/prompts/tool-reset-dashboard.md index 7d78b4b43..a44d0a7f1 100644 --- a/pkg-py/src/querychat/prompts/tool-reset-dashboard.md +++ b/pkg-py/src/querychat/prompts/tool-reset-dashboard.md @@ -2,9 +2,14 @@ Reset the dashboard to its original state Resets the dashboard to use the original unfiltered dataset and clears any custom title. -If the user asks to reset the dashboard, simply call this tool with no other response. The reset action will be obvious to the user. +If the user asks to reset the dashboard, call this tool with the relevant table name and no other response. The reset action will be obvious to the user. -If the user asks to start over, call this tool and then provide a new set of suggestions for next steps. Include suggestions that encourage exploration of the data in new directions. +If the user asks to start over, call this tool with the relevant table name and then provide a new set of suggestions for next steps. Include suggestions that encourage exploration of the data in new directions. + +Parameters +---------- +table + The name of the table to reset. Returns ------- diff --git a/pkg-py/src/querychat/prompts/tool-update-dashboard.md b/pkg-py/src/querychat/prompts/tool-update-dashboard.md index dae9861c0..809c3b447 100644 --- a/pkg-py/src/querychat/prompts/tool-update-dashboard.md +++ b/pkg-py/src/querychat/prompts/tool-update-dashboard.md @@ -2,6 +2,8 @@ Filter and sort the dashboard data This tool executes a {{db_type}} SQL SELECT query to filter or sort the data used in the dashboard. +The `table` parameter specifies which table to filter. Use the table name exactly as shown in the schema. + **When to use:** Call this tool whenever the user requests filtering, sorting, or data manipulation on the dashboard with questions like "Show me..." or "Which records have...". This tool is appropriate for any request that involves showing a subset of the data or reordering it. **When not to use:** Do NOT use this tool for general questions about the data that can be answered with a single value or summary statistic. For those questions, use the `querychat_query` tool instead. @@ -14,8 +16,15 @@ This tool executes a {{db_type}} SQL SELECT query to filter or sort the data use - Assume the user will only see the original columns in the dataset +{{#multi_table}} + +**Multi-table filters:** When filtering a table, you may reference other tables in WHERE clauses, subqueries, or CTEs (e.g., filtering orders by a condition on customers). The result must still return all columns of the target table specified by the `table` parameter. + +{{/multi_table}} Parameters ---------- +table : + The name of the table to filter. Must match exactly one of the table names from the schema. query : A {{db_type}} SQL SELECT query that MUST return all existing schema columns (use SELECT * or explicitly list all columns). May include additional computed columns, subqueries, CTEs, WHERE clauses, ORDER BY, and any {{db_type}}-supported SQL functions. title : diff --git a/pkg-py/src/querychat/prompts/tool-visualize.md b/pkg-py/src/querychat/prompts/tool-visualize.md index c43f4da4d..4475cd3ef 100644 --- a/pkg-py/src/querychat/prompts/tool-visualize.md +++ b/pkg-py/src/querychat/prompts/tool-visualize.md @@ -10,6 +10,11 @@ Render a ggsql query (SQL with a VISUALISE clause) as an Altair chart displayed - Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead. - If a visualization fails, read the error message carefully and retry with a corrected query. Common fixes: correcting column names, adding `SCALE DISCRETE` for integer categories, moving SQL expressions out of `VISUALISE` into the `SELECT` clause, and using `DRAW range` for interval-style marks instead of deprecated `errorbar`.{{#has_tool_query}} If the error persists, fall back to `querychat_query` for a tabular answer.{{/has_tool_query}} +{{#multi_table}} + +**Multi-table queries:** The SELECT portion of your ggsql query can reference any table from the schema and use JOINs. + +{{/multi_table}} Parameters ---------- ggsql : diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 48a17b5cc..000824216 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast, runtime_checkable from chatlas import ContentToolResult, Tool from shinychat.types import ToolResultDisplay @@ -23,9 +25,10 @@ ] if TYPE_CHECKING: - from collections.abc import Callable + from ._query_executor import QueryExecutor - from ._datasource import DataSource + +ResetDashboardCallback = Callable[[], None] | Callable[[str], None] @runtime_checkable @@ -52,6 +55,8 @@ class UpdateDashboardData(TypedDict): Attributes ---------- + table + The name of the table being filtered. query The SQL query string to execute for filtering/sorting the dashboard. title @@ -66,6 +71,7 @@ class UpdateDashboardData(TypedDict): def log_update(data: UpdateDashboardData): + print(f"Table: {data['table']}") print(f"Executing: {data['query']}") print(f"Title: {data['title']}") @@ -77,35 +83,45 @@ def log_update(data: UpdateDashboardData): """ + table: str query: str title: str def _update_dashboard_impl( - data_source: DataSource, + executor: QueryExecutor, + table_names: list[str], update_fn: Callable[[UpdateDashboardData], None], -) -> Callable[[str, str], ContentToolResult]: +) -> Callable[[str, str, str], ContentToolResult]: """Create the implementation function for updating the dashboard.""" - def update_dashboard(query: str, title: str) -> ContentToolResult: + def update_dashboard(table: str, query: str, title: str) -> ContentToolResult: error = None markdown = f"```sql\n{query}\n```" value = "Dashboard updated. Use `query` tool to review results, if needed." + # Validate table exists + if table not in table_names: + available = ", ".join(table_names) + error = f"Table '{table}' not found. Available: {available}" + markdown += f"\n\n> Error: {error}" + return ContentToolResult(value=markdown, error=Exception(error)) + try: # Test the query but don't execute it yet - data_source.test_query(query, require_all_columns=True) + executor.test_query(query, table_name=table, require_all_columns=True) # Add Apply Filter button button_html = f"""""" # Call the callback with TypedDict data on success - update_fn({"query": query, "title": title}) + update_fn({"table": table, "query": query, "title": title}) except Exception as e: error = truncate_error(str(e)) @@ -130,30 +146,33 @@ def update_dashboard(query: str, title: str) -> ContentToolResult: def tool_update_dashboard( - data_source: DataSource, + executor: QueryExecutor, + table_names: list[str], update_fn: Callable[[UpdateDashboardData], None], ) -> Tool: """ - Create a tool that modifies the data presented in the dashboard based on the SQL query. + Create a tool that modifies the data presented in the dashboard. Parameters ---------- - data_source - The data source to query against + executor + The query executor to validate queries against. + table_names + List of valid table names for validation. update_fn - Callback function to call with UpdateDashboardData when update succeeds + Callback function to call with UpdateDashboardData when update succeeds. Returns ------- Tool - A tool that can be registered with chatlas + A tool that can be registered with chatlas. """ - impl = _update_dashboard_impl(data_source, update_fn) + impl = _update_dashboard_impl(executor, table_names, update_fn) description = read_prompt_template( "tool-update-dashboard.md", - db_type=data_source.get_db_type(), + db_type=executor.get_db_type(), ) impl.__doc__ = description @@ -165,17 +184,27 @@ def tool_update_dashboard( def _reset_dashboard_impl( - reset_fn: Callable[[], None], -) -> Callable[[], ContentToolResult]: + reset_fn: ResetDashboardCallback, + table_names: list[str] | None, +) -> Callable[[str], ContentToolResult]: """Create the implementation function for resetting the dashboard.""" - def reset_dashboard() -> ContentToolResult: + def reset_dashboard(table: str) -> ContentToolResult: + if table_names is not None and table not in table_names: + available = ", ".join(table_names) + error = f"Table '{table}' not found. Available: {available}" + return ContentToolResult( + value=error, + error=Exception(error), + ) + # Call the callback to reset - reset_fn() + _call_reset_dashboard(reset_fn, table) # Add Reset Filter button - button_html = """""" @@ -275,6 +276,8 @@ def tool_update_dashboard( executor: QueryExecutor, table_names: list[str], update_fn: Callable[[UpdateDashboardData], None], + *, + multi_table: bool = False, ) -> Tool: """ Create a tool that modifies the data presented in the dashboard. @@ -287,6 +290,8 @@ def tool_update_dashboard( List of valid table names for validation. update_fn Callback function to call with UpdateDashboardData when update succeeds. + multi_table + Whether multiple tables are registered. Returns ------- @@ -299,6 +304,7 @@ def tool_update_dashboard( description = read_prompt_template( "tool-update-dashboard.md", db_type=executor.get_db_type(), + multi_table=multi_table, ) impl.__doc__ = description @@ -330,7 +336,7 @@ def reset_dashboard(table: str) -> ContentToolResult: # Add Reset Filter button button_html = f"""