diff --git a/.gitignore b/.gitignore index df712fb0f..ef52805cf 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ animation.screenflow/ README_files/ README.html .DS_Store +test-results/ python-package/examples/titanic.db .quarto *.db diff --git a/js/build.mjs b/js/build.mjs index 6b95400d0..dda68cb85 100644 --- a/js/build.mjs +++ b/js/build.mjs @@ -24,6 +24,14 @@ const jsTargets = [ source: "src/viz.ts", output: "../pkg-r/inst/htmldep/viz.js", }, + { + source: "src/schema-display.js", + output: "../pkg-py/src/querychat/static/js/schema-display.js", + }, + { + source: "src/schema-display.js", + output: "../pkg-r/inst/htmldep/schema-display.js", + }, ]; const cssTargets = [ diff --git a/js/src/schema-display.js b/js/src/schema-display.js new file mode 100644 index 000000000..95966542e --- /dev/null +++ b/js/src/schema-display.js @@ -0,0 +1,229 @@ +let lastDisplay = null; +let lastDisplayTime = 0; +const BATCH_MS = 1000; +let activePanel = null; + +// -- Schema text parser -------------------------------------------------- + +function parseColumnsJson(json) { + return JSON.parse(json).map((col) => ({ + name: col.name, + type: col.sql_type, + units: col.units || null, + description: col.description || null, + constraints: col.constraints && col.constraints.length > 0 ? col.constraints.join(', ') : null, + range: + col.min_val != null && col.max_val != null ? `${col.min_val} to ${col.max_val}` : null, + categories: + col.categories && col.categories.length > 0 + ? 
col.categories.map((v) => `'${v}'`).join(', ') + : null, + })); +} + +// -- Table rendering ----------------------------------------------------- + +function esc(s) { + return String(s) + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"'); +} + +const TH = + 'padding:0.35em 0.75em;text-align:left;white-space:nowrap;font-weight:600;' + + 'border-bottom:2px solid var(--bs-border-color,#dee2e6);' + + 'background:var(--bs-tertiary-bg,#f8f9fa);' + + 'position:sticky;top:0;z-index:1;'; +const TD_MONO = + 'padding:0.3em 0.75em;white-space:nowrap;' + + 'font-family:var(--bs-font-monospace,monospace);font-size:0.875em;' + + 'border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));'; +const TD_WRAP = + 'padding:0.3em 0.75em;max-width:22em;overflow-wrap:break-word;' + + 'border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));'; +const TD_NOWRAP = + 'padding:0.3em 0.75em;white-space:nowrap;' + + 'border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));'; + +function renderTable(columns) { + const rows = columns + .map((col) => { + let typeCell = esc(col.type); + if (col.units) { + typeCell += ` [${esc(col.units)}]`; + } + const details = col.range + ? esc(col.range) + : col.categories + ? esc(col.categories) + : ''; + + return ( + `` + + `${esc(col.name)}` + + `${typeCell}` + + `${col.description ? esc(col.description) : ''}` + + `${col.constraints ? esc(col.constraints) : ''}` + + `${details}` + + `` + ); + }) + .join(''); + + return ( + `` + + `` + + `` + + `` + + `` + + `` + + `` + + `` + + `${rows}` + + `
ColumnTypeDescriptionConstraintsRange / Values
` + ); +} + +// -- Panel positioning & lifecycle --------------------------------------- + +const PANEL_STYLE = + 'position:fixed;z-index:9999;' + + 'background:var(--bs-body-bg,#fff);color:var(--bs-body-color,#212529);' + + 'border:1px solid var(--bs-border-color,#dee2e6);' + + 'border-radius:var(--bs-border-radius,0.375rem);' + + 'box-shadow:0 4px 16px rgba(0,0,0,.15);' + + 'overflow:auto;' + + 'max-height:min(420px,60vh);'; + +function positionPanel(btn, panel) { + const rect = btn.getBoundingClientRect(); + const vw = window.innerWidth; + const vh = window.innerHeight; + + const pw = Math.min(Math.max(360, vw * 0.55), vw - 16); + panel.style.width = `${pw}px`; + panel.style.left = `${Math.max(8, Math.min(rect.left, vw - pw - 8))}px`; + + // Prefer below; fall back to above if there's more room there + const spaceBelow = vh - rect.bottom - 8; + const spaceAbove = rect.top - 8; + if (spaceBelow >= 120 || spaceBelow >= spaceAbove) { + panel.style.top = `${rect.bottom + 4}px`; + } else { + const panelH = Math.min(420, spaceAbove); + panel.style.top = `${Math.max(8, rect.top - panelH - 4)}px`; + } +} + +function closePanel() { + if (activePanel) { + activePanel.panel.hidden = true; + activePanel.btn.setAttribute('aria-expanded', 'false'); + activePanel = null; + } +} + +document.addEventListener('click', closePanel); +document.addEventListener('keydown', (e) => { + if (e.key === 'Escape') closePanel(); +}); +window.addEventListener( + 'scroll', + (e) => { + if (activePanel && !activePanel.panel.contains(/** @type {Node} */ (e.target))) { + closePanel(); + } + }, + true, +); +window.addEventListener('resize', closePanel); + +// -- Button + panel construction ----------------------------------------- + +function createBtn(tableName, columnsJson) { + const columns = parseColumnsJson(columnsJson); + + const btn = document.createElement('button'); + btn.type = 'button'; + btn.style.cssText = + 'background:none;border:none;padding:0;color:inherit;' + + 
'text-decoration:underline dotted;cursor:pointer;font-size:inherit;border-radius:2px;'; + btn.textContent = tableName; + btn.setAttribute('aria-label', `Show schema for ${tableName}`); + btn.setAttribute('aria-expanded', 'false'); + btn.setAttribute('aria-haspopup', 'dialog'); + + const panel = document.createElement('div'); + panel.setAttribute('role', 'dialog'); + panel.setAttribute('aria-label', `${tableName} schema`); + panel.style.cssText = PANEL_STYLE; + panel.hidden = true; + panel.innerHTML = renderTable(columns); + document.body.appendChild(panel); + + btn.addEventListener('click', (e) => { + e.stopPropagation(); + if (activePanel && activePanel.panel === panel) { + closePanel(); + return; + } + closePanel(); + positionPanel(btn, panel); + panel.hidden = false; + btn.setAttribute('aria-expanded', 'true'); + activePanel = { btn, panel }; + }); + + panel.addEventListener('click', (e) => e.stopPropagation()); + + return btn; +} + +// -- Focus ring for keyboard users (Bootstrap resets button outline) ----- + +const style = document.createElement('style'); +style.textContent = + '.qc-schema-display button:focus-visible{' + + 'outline:2px solid currentColor;outline-offset:2px;border-radius:2px}'; +document.head.appendChild(style); + +// -- MutationObserver --------------------------------------------------- + +function processCollector(sentinel) { + const now = Date.now(); + const tableName = sentinel.dataset.table; + const btn = createBtn(tableName, sentinel.dataset.schemaJson); + + if (lastDisplay && document.contains(lastDisplay) && now - lastDisplayTime < BATCH_MS) { + lastDisplay.appendChild(document.createTextNode(', ')); + lastDisplay.appendChild(btn); + sentinel.remove(); + } else { + const p = document.createElement('p'); + p.className = 'qc-schema-display'; + p.style.cssText = + 'color:var(--bs-secondary-color,#6c757d);font-size:0.875em;margin:0.1rem 0;'; + p.appendChild(document.createTextNode('πŸ” Fetched schemas: ')); + p.appendChild(btn); + 
sentinel.replaceWith(p); + lastDisplay = p; + } + lastDisplayTime = now; +} + +new MutationObserver((mutations) => { + for (const { addedNodes } of mutations) { + for (const node of addedNodes) { + if (node.nodeType !== 1) continue; + if (/** @type {Element} */ (node).classList.contains('qc-schema-collector')) { + processCollector(/** @type {HTMLElement} */ (node)); + } else { + /** @type {Element} */ (node) + .querySelectorAll('.qc-schema-collector') + .forEach((el) => processCollector(/** @type {HTMLElement} */ (el))); + } + } + } +}).observe(document.body, { subtree: true, childList: true }); diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index a00e8722e..384356136 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -9,13 +9,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* `QueryChat()` now supports **multiple related tables**. Register additional tables with `add_table()` and the LLM can reason across all of them β€” joins, cross-table filters, aggregations. Per-table reactive state (`df()`, `sql()`, `title()`) is accessible via `qc_vals.table("name")` on the value returned by `server()`. (#195) + + ```python + qc = QueryChat(orders_df, "orders") + qc.add_table(customers_df, "customers") + + qc_vals = qc.server() + qc_vals.table("orders").df() + qc_vals.table("customers").sql() + ``` + +* A new **`DataDict`** type β€” integrating with the [data-dict](https://data-dict.tidyverse.org/) spec β€” lets you annotate tables and columns with plain-English descriptions loaded from a YAML file. This is the preferred way to provide additional context for the data, especially when multiple tables are relevant. The LLM receives these descriptions when it fetches the schema, helping it interpret ambiguous or domain-specific column names without any extra prompting. 
(#195) + + ```python + QueryChat(data_dict="data_dict.yaml") + ``` + * Added `PinSource`, a data source for chatting with datasets pinned to a [pins](https://pins.rstudio.com/) board. Works with parquet, CSV, JSON, and Arrow pins, and uses the pin's title, description, and tags as the default data description. Install the optional dependency with `pip install querychat[pins]`. (#246) * File attachments are now enabled by default in the Shiny chat UI. Users can attach images, PDFs, and text files to their messages and the LLM will receive them. Disable with `allow_attachments=False` in `mod_ui()` or `QueryChat.ui()`. (#253) +### Breaking changes + +* The `data_source` property has been removed. Use `qc.table("name").data_source` to read a table's data source, and `qc.add_table(df, "name", replace=True)` to replace it. The `data_source` parameter to `server()` (Shiny) has also been removed; call `add_table()` before `server()` instead. (#195) + ### Improvements * Chat greetings now use shinychat's greeting API (requires shinychat >= 0.4.0). A provided `greeting` renders instantly when the app loads, and when no `greeting` is given one is generated on demand without being added to the conversation history. Generated greetings are now preserved across bookmark/restore. (#249) + +* The system prompt is now lighter: full schema is no longer embedded upfront. Instead the LLM fetches per-table schema on demand via the new `querychat_get_schema` tool — and only when it needs to. When a `DataDict` is provided, the tool skips columns that already have descriptions, so the LLM only pays for what isn't already documented. (#195) * The query tool result card now starts collapsed by default. Users can still expand it to see the SQL query and results. Set `QUERYCHAT_TOOL_DETAILS=expanded` to restore the previous behavior. 
(#239) ## [0.6.1] - 2026-05-26 @@ -86,6 +109,8 @@ Each framework's `QueryChat` provides `.app()` for quick standalone apps and `.u ### New features +* Added `PolarsLazySource` to support Polars LazyFrames as data sources. Data stays lazy until the render boundary, enabling efficient handling of large datasets. Pass a `polars.LazyFrame` directly to `QueryChat()` and queries will be executed lazily via Polars' SQLContext. + * `QueryChat.console()` was added to launch interactive console-based chat sessions with your data source, with persistent conversation state across invocations. (#168) * `QueryChat.client()` can now create standalone querychat-enabled chat clients with configurable tools and callbacks, enabling use outside of Shiny applications. (#168) diff --git a/pkg-py/docs/build-intro.qmd b/pkg-py/docs/build-intro.qmd index ecac93711..ad2570559 100644 --- a/pkg-py/docs/build-intro.qmd +++ b/pkg-py/docs/build-intro.qmd @@ -4,21 +4,14 @@ title: Intro While the `.app()` method is a great [quick start](index.qmd#quick-start) for exploring data, building custom apps with querychat unlocks the full power of integrating natural language data exploration with custom visualizations, layouts, and interactivity. -querychat is a particularly good fit for apps that have: - -1. **A single data source** (or a set of related tables that can be joined) -2. **Multiple filters** that let users slice and explore the data in different ways -3. **Several visualizations and outputs** that all depend on the same filtered data - -In these apps, querychat can replace or augment your filtering UI by allowing users to describe what they want to see in natural language. Instead of building complex filter controls, users can simply ask questions like "show me customers from California who spent over $1000 last quarter" and querychat will generate the appropriate SQL query. 
+querychat lets users ask questions of their data in plain language — filtering, sorting, summarizing, joining across tables, and creating visualizations — all without needing to write SQL or navigate complex filter UIs. You can use it as the primary exploration interface in a standalone app, or embed it alongside curated views in an existing dashboard to let users go deeper than the views you designed. This is especially valuable when: - Your data has many columns and building a UI for all possible filters would be overwhelming - Users want to explore ad-hoc combinations of filters that you didn't anticipate -- You want to make data exploration more accessible to users who aren't comfortable with traditional filtering UIs - -If you have an existing app with a data frame that flows through multiple outputs, querychat can be a natural addition to provide an alternative way to filter that data. +- You have multiple related tables that users may want to query and join +- You want to make data exploration more accessible to non-technical users ## General pattern diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd index 460d36cc7..68ea3ceb1 100644 --- a/pkg-py/docs/build.qmd +++ b/pkg-py/docs/build.qmd @@ -303,22 +303,98 @@ def _(): Users can also ask the LLM to "reset" or "show all data" to clear filters through the chat interface. ::: -### Advanced patterns - -#### Programmatic updates +### Programmatic updates You can update the query state programmatically using `.sql()` and `.title()` as setters. This is useful for adding preset filter buttons or linking filters to other UI controls. -#### Multiple datasets +## Multiple tables + +querychat can work with multiple related tables in a single chat interface, letting users query across tables, join data, and filter any table independently. Register additional tables with `.add_table()` after creating the `QueryChat` instance, then access per-table state through the `.table()` method. 
+ +### Registering tables + +Pass the first table when creating `QueryChat`, then call `.add_table()` for each additional table: + +```python +from querychat import QueryChat + +qc = QueryChat(orders, "orders") +qc.add_table(customers, "customers") +qc.add_table(products, "products") +``` + +The LLM can query any registered table and write joins across them. You can inspect which tables are registered with `qc.table_names()`. + +### Per-table reactive access + +When working with multiple tables, access filtered data and SQL for each table individually using `.table()`: + +::: {.panel-tabset group="shiny-mode"} + +#### Express + +```python +from shiny.express import render + +qc.sidebar() + +@render.data_frame +def orders_table(): + return qc.table("orders").df() + +@render.data_frame +def customers_table(): + return qc.table("customers").df() +``` + +#### Core -To explore multiple datasets, use separate `QueryChat` instances (i.e., separate chat interfaces). +```python +def server(input, output, session): + qc_vals = qc.server() -::: {.callout-note} -### Multiple tables in one chat? + @render.data_frame + def orders_table(): + return qc_vals.table("orders").df() + + @render.data_frame + def customers_table(): + return qc_vals.table("customers").df() +``` -In some cases, you might be able to "pre-join" datasets into a single table and use one `QueryChat` instance to explore them together. In the future, we may support multiple filtered tables in one chat interface, but this is not currently available. Please upvote [the relevant issue](https://github.com/posit-dev/querychat/issues/6) if this is a feature you'd like to see! ::: +Each table has its own `.df()`, `.sql()`, and `.title()` reactives that update independently when the user filters that specific table. + +### Tracking the active table + +Use `.current_table()` to find out which table the LLM most recently queried. 
This is useful for auto-switching a tabbed UI to the relevant table: + +```python +@reactive.effect +def _(): + name = qc_vals.current_table() + if name: + ui.update_navs("table_tabs", selected=name) +``` + +### Data dictionary + +When working with multiple related tables, providing a [data dictionary](context.qmd#data-dictionary) is strongly recommended. It tells the LLM how tables relate to each other, which columns are keys, and what domain terms mean β€” all of which help it write accurate joins and queries. + +```python +from pathlib import Path + +qc = QueryChat(orders, "orders", data_dict=Path("data-dict.yaml")) +qc.add_table(customers, "customers") +``` + +See [Provide context](context.qmd#data-dictionary) for the full data dictionary format. + +### Separate chat interfaces + +If your tables are truly independent (not related), you may prefer separate `QueryChat` instances, each with its own chat interface: + ```{python} # | eval: false # | code-fold: true @@ -328,10 +404,6 @@ In some cases, you might be able to "pre-join" datasets into a single table and ![](/images/multiple-datasets.png){fig-alt="Screenshot of a querychat app with two datasets: titanic and penguins." class="lightbox shadow rounded mb-3"} -::: {.callout-note} -Each dataset gets its own chat interface and maintains separate state. -::: - ## See also - [Greet users](greet.qmd) - Create welcoming onboarding experiences diff --git a/pkg-py/docs/context.qmd b/pkg-py/docs/context.qmd index df6cb25ae..123335dc0 100644 --- a/pkg-py/docs/context.qmd +++ b/pkg-py/docs/context.qmd @@ -2,73 +2,135 @@ title: Provide context --- -querychat automatically gathers information about your table to help the LLM write accurate SQL queries. 
This includes column names and types, numerical ranges, and categorical value examples.^[All of this information is provided to the LLM as part of the **system prompt** -- a string of text containing instructions and context for the LLM to consider when responding to user queries.] +querychat automatically gathers schema information about your tables β€” column names, types, numerical ranges, and categorical values β€” and makes it available to the LLM on demand via the `querychat_get_schema` [tool](tools.qmd#schema-retrieval). The LLM calls this tool before writing SQL to understand the structure of the tables it's querying. Importantly, we are **not** sending your raw data to the LLM and asking it to do complicated math. The LLM only needs to understand the structure and schema of your data in order to write SQL queries. -You can get even better results by customizing the system prompt in three ways: +You can get even better results by providing additional context: -1. Add a [data description](#data-description) to provide more context about what the data represents -2. Add [custom instructions](#extra-instructions) to guide the LLM's behavior -3. Use a fully [custom prompt template](#custom-template) if you want complete control (useful if you want to be certain the model cannot see any literal values from your data) +1. Add a [data dictionary](#data-dictionary) to describe tables, columns, relationships, and domain terminology (recommended) +2. Add a [data description](#data-description) for a simpler alternative when working with a single, straightforward table +3. Add [custom instructions](#extra-instructions) to guide the LLM's behavior +4. Use a fully [custom prompt template](#custom-template) if you want complete control -## Default prompt +## Data dictionary {#data-dictionary} -For full visibility into the system prompt that querychat generates for the LLM, you can inspect the `system_prompt` property. 
This is useful for debugging and understanding exactly what context the LLM is using: +A **data dictionary** is a YAML file that describes your tables, columns, relationships, and domain-specific terminology. It's the recommended way to provide context, especially when working with [multiple tables](build.qmd#multiple-tables) or when your data has domain-specific meaning that isn't obvious from column names alone. -```python +```{.python filename="app.py"} +from pathlib import Path from querychat import QueryChat -from querychat.data import titanic -qc = QueryChat(titanic(), "titanic") -print(qc.system_prompt) +qc = QueryChat( + orders, "orders", + data_dict=Path("data-dict.yaml") +) +qc.add_table(customers, "customers") ``` -By default, the system prompt contains the following components: +### Format + +A data dictionary has three top-level sections: `tables`, `relationships`, and `glossary`. + +```{.yaml filename="data-dict.yaml"} +version: "0.2.0" + +tables: + orders: + description: One row per customer order. + columns: + - name: order_id + type: number(id) + constraints: [primary_key] + description: Unique order identifier. + - name: customer_id + type: number(id) + constraints: [foreign_key] + description: References customers.id. + - name: total + type: number(quantity) + description: Order total in USD. + - name: status + type: enum + values: [pending, shipped, delivered, cancelled] + description: Current order status. + + customers: + description: One row per customer. + columns: + - name: id + type: number(id) + constraints: [primary_key] + description: Unique customer identifier. + - name: name + type: string + description: Full name. + - name: region + type: string + description: Geographic sales region. + +relationships: + - description: Each order belongs to one customer. + cardinality: many-to-one + join: orders.customer_id = customers.id + +glossary: + AOV: Average order value β€” total revenue divided by number of orders. 
+ churn: A customer who has not placed an order in the last 90 days. +``` + +#### Tables + +Each entry under `tables` describes one table. The key must match the table name you pass to `QueryChat` or `.add_table()`. + +- **`description`**: What this table represents (one sentence is usually enough). +- **`columns`**: A list of column annotations. Each column can have: + - `name`: Column name (must match the actual column) + - `type`: Semantic type hint — `string`, `number`, `number(id)`, `number(quantity)`, `date`, `enum` + - `constraints`: Optional list — `primary_key`, `foreign_key` + - `description`: What this column means in plain English + - `values`: For `enum` columns, the list of possible values + +Columns listed in the data dictionary are excluded from the auto-generated schema (since your description supersedes the auto-detected metadata). Columns not listed are still auto-detected as usual. + +#### Relationships + +The `relationships` section tells the LLM how to join tables. Each entry has: -1. The basic set of behaviors and guidelines the LLM must follow in order for querychat to work properly, including how to use [tools](tools.qmd) to execute queries and update the app. -2. The SQL schema of the data frame you provided. This includes: - Column names - Data types (integer, float, boolean, datetime, text) - For text columns with less than 10 unique values, we assume they are categorical variables and include the list of values - For integer and float columns, we include the range -3. A [data description](#data-description) (if provided via `data_description`) -4. [Additional instructions](#additional-instructions) you want to use to guide querychat's behavior (if provided via `extra_instructions`). 
+- `description`: A plain-English description of the relationship +- `cardinality`: `one-to-one`, `one-to-many`, or `many-to-one` +- `join`: The join condition (e.g., `orders.customer_id = customers.id`) + +#### Glossary + +The `glossary` section defines domain-specific terms that users might use in their questions. This helps the LLM translate business language into correct SQL. ## Data description {#data-description} -If your column names are descriptive, querychat may already work well without additional context. However, if your columns are named `x`, `V1`, `value`, etc., you should provide a data description. Use the `data_description` parameter for this: +For simple single-table use cases where a full data dictionary would be overkill, you can provide a **data description** — a free-form markdown file or string that describes what the data represents. Use the `data_description` parameter: -```{.python filename="titanic-app.py"} +```python from pathlib import Path -from querychat import QueryChat qc = QueryChat( titanic, "titanic", data_description=Path("data_description.md") ) -app = qc.app() ``` -querychat doesn't need this information in any particular format -- just provide what a human would find helpful: +querychat doesn't need this in any particular format — just provide what a human would find helpful: ```{.markdown filename="data_description.md"} -This dataset contains information about Titanic passengers, collected for predicting survival. +This dataset contains information about Titanic passengers. 
- survived: Survival (0 = No, 1 = Yes) - pclass: Ticket class (1 = 1st, 2 = 2nd, 3 = 3rd) -- sex: Sex of passenger -- age: Age in years -- sibsp: Number of siblings/spouses aboard -- parch: Number of parents/children aboard -- fare: Passenger fare - embarked: Port of embarkation (C = Cherbourg, Q = Queenstown, S = Southampton) ``` - ## Additional instructions {#extra-instructions} You can add custom instructions to guide the LLM's behavior using the `extra_instructions` parameter: @@ -86,7 +148,7 @@ Or as a string: ```python instructions = """ - Use British spelling conventions -- Stay on topic and only discuss the data dashboard +- Stay on topic and only discuss the data - Refuse to answer unrelated questions """ @@ -98,6 +160,17 @@ qc = QueryChat(titanic, "titanic", extra_instructions=instructions) LLMs may not always follow your instructions perfectly. Test extensively when changing instructions or models. ::: +## Default prompt + +For full visibility into the system prompt that querychat generates for the LLM, you can inspect the `system_prompt` property. This is useful for debugging and understanding exactly what context the LLM is working with: + +```python +from querychat import QueryChat +from querychat.data import titanic + +qc = QueryChat(titanic(), "titanic") +print(qc.system_prompt) +``` ## Custom template {#custom-template} diff --git a/pkg-py/docs/index.qmd b/pkg-py/docs/index.qmd index 8d119ae3b..d571e333b 100644 --- a/pkg-py/docs/index.qmd +++ b/pkg-py/docs/index.qmd @@ -126,8 +126,7 @@ pip install "querychat[streamlit]" # or [gradio] or [dash] ## Build custom apps -querychat is designed to be highly extensible -- it provides programmatic access to the chat interface, the filtered/sorted data frame, SQL queries, and more. -This makes it easy to build custom web apps that leverage natural language interaction with your data. 
+querychat is designed to be highly extensible — it provides programmatic access to the chat interface, the filtered/sorted data frame, SQL queries, and more. You can register [multiple related tables](build.qmd#multiple-tables) for cross-table queries and joins, and connect all of it to your own visualizations and layouts. For example, [here](https://github.com/posit-conf-2025/llm/blob/main/_solutions/25_querychat/25_querychat_02-end-app.R)'s a bespoke app for exploring Airbnb listings in Asheville, NC: ![](/images/airbnb.png){fig-alt="A custom app for exploring Airbnb listings, powered by querychat." class="lightbox shadow rounded mb-3"} @@ -136,9 +135,9 @@ To learn more, see the build guides for your framework: [Shiny](build.qmd), [Str ## How it works -querychat uses LLMs to translate natural language into SQL queries. Models of all sizes, from small ones you can run locally to large frontier models from major AI providers, are remarkably effective at this task. But even the best models need to understand your data's overall structure to perform well. +querychat uses LLMs to translate natural language into SQL queries. Models of all sizes, from small ones you can run locally to large frontier models from major AI providers, are remarkably effective at this task. But even the best models need to understand your data's structure to perform well. -To address this, querychat includes schema metadata -- column names, types, ranges, categorical values -- in the LLM's [system prompt](context.qmd). Importantly, querychat **does not** send raw data to the LLM; it shares only enough structural information for the model to generate accurate queries. When the LLM produces a query, querychat executes it in a SQL database (DuckDB[^duckdb], by default) to obtain precise results. 
+To address this, querychat provides a schema retrieval [tool](tools.qmd#schema-retrieval) that the LLM calls on demand to learn about table structure — column names, types, ranges, and categorical values. You can further improve results by providing a [data dictionary](context.qmd#data-dictionary) with column descriptions, table relationships, and domain terminology. Importantly, querychat **does not** send raw data to the LLM; it shares only enough structural information for the model to generate accurate queries. When the LLM produces a query, querychat executes it in a SQL database (DuckDB[^duckdb], by default) to obtain precise results. This design makes querychat reliable, safe, and reproducible: diff --git a/pkg-py/docs/tools.qmd b/pkg-py/docs/tools.qmd index 44301f1d4..702bc865d 100644 --- a/pkg-py/docs/tools.qmd +++ b/pkg-py/docs/tools.qmd @@ -2,14 +2,24 @@ title: Tools --- -querychat combines [tool calling](https://posit-dev.github.io/chatlas/get-started/tools.html) with [reactivity](https://shiny.posit.co/py/docs/reactive-foundations.html) to not only execute SQL, but also reactively update dependent data views. Understanding how these tools work will help you better understand what querychat is capable of and how to customize/extend to its behavior. +querychat combines [tool calling](https://posit-dev.github.io/chatlas/get-started/tools.html) with [reactivity](https://shiny.posit.co/py/docs/reactive-foundations.html) to not only execute SQL, but also reactively update dependent data views. Understanding how these tools work will help you better understand what querychat is capable of and how to customize/extend its behavior. One important thing to understand generally about querychat's tools is they are Python functions, and that execution happens on _your machine_, not on the LLM provider's side. In other words, the SQL queries generated by the LLM are executed locally in the Python process running the app. 
-querychat provides the LLM access to three tool groups: +querychat provides the LLM access to four tool groups: -1. **Data updating** - Filter and sort data (without sending results to the LLM). -2. **Data analysis** - Calculate summaries and return results for interpretation by the LLM. +1. **Schema retrieval** - Fetch table structure before writing SQL. +2. **Data updating** - Filter and sort data (without sending results to the LLM). +3. **Data analysis** - Calculate summaries and return results for interpretation by the LLM. +4. **Data visualization** - Create charts inline in the chat. + +## Schema retrieval {#schema-retrieval} + +Before writing any SQL query, the LLM calls the `querychat_get_schema` tool to retrieve column names, types, value ranges, and descriptions for a specific table. This on-demand approach means the LLM only fetches schema for the tables it actually needs, keeping the system prompt lean and startup fast — especially when working with [multiple tables](build.qmd#multiple-tables) or large databases. + +If you've provided a [data dictionary](context.qmd#data-dictionary), the schema response includes your column descriptions and relationship information. Columns annotated in the data dictionary are excluded from the auto-generated schema metadata (since your description supersedes it). + +This tool is always registered and cannot be disabled. ## Data updating @@ -61,7 +71,7 @@ This tool: 2. Renders the `VISUALISE` clause as an Altair chart 3. Displays the chart inline in the chat -Unlike the data updating tools, visualization queries don't affect the dashboard filter. +Unlike the data updating tools, visualization queries don't affect the active data filter. They query the full dataset independently, and each call produces a new inline chart message in the chat. The inline chart includes controls for fullscreen viewing, saving as PNG/SVG, and a "Show Query" toggle that reveals the underlying ggsql code. 
@@ -96,6 +106,7 @@ If you'd like to better understand how the tools work and how the LLM is prompte **Prompts:** +- [`prompts/tool-get-schema.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-get-schema.md) - [`prompts/tool-update-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-update-dashboard.md) - [`prompts/tool-reset-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-reset-dashboard.md) - [`prompts/tool-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-query.md) diff --git a/pkg-py/examples/11-multi-table-nutrition/app.py b/pkg-py/examples/11-multi-table-nutrition/app.py new file mode 100644 index 000000000..0a6e5ac88 --- /dev/null +++ b/pkg-py/examples/11-multi-table-nutrition/app.py @@ -0,0 +1,290 @@ +"""USDA Foundation Foods nutrition dashboard with querychat. + +Real nutrition data from the USDA Foundation Foods dataset (via the {foodbank} +R package), organized across six tables: + foods, food_categories, nutrients, food_nutrients, food_portions, measure_units + +The main content area shows reactive value boxes and Plotly Express charts that +update whenever querychat filters the data. 
+ +Usage: + cd pkg-py + uv run shiny run examples/multi-table-nutrition.py +""" +from pathlib import Path + +import plotly.express as px +import polars as pl +import shinychat +from shiny import App, reactive, render, ui +from shinywidgets import output_widget, render_plotly + +from querychat import QueryChat + +# ── Data ───────────────────────────────────────────────────────────────────── + +_DATA_DIR = Path(__file__).parent / "data" / "foodbank" + +foods = pl.read_parquet(_DATA_DIR / "foods.parquet") +food_categories = pl.read_parquet(_DATA_DIR / "food_categories.parquet") +nutrients = pl.read_parquet(_DATA_DIR / "nutrients.parquet") +food_nutrients = pl.read_parquet(_DATA_DIR / "food_nutrients.parquet") +food_portions = pl.read_parquet(_DATA_DIR / "food_portions.parquet") +measure_units = pl.read_parquet(_DATA_DIR / "measure_units.parquet") + +# Mapping from USDA nutrient ID to friendly column name +_NUTRIENT_ID_TO_COL = { + 1008: "energy_kcal", + 1003: "protein_g", + 1004: "fat_g", + 1005: "carbs_g", + 1079: "fiber_g", + 1063: "sugars_g", + 1258: "sat_fat_g", + 1087: "calcium_mg", + 1089: "iron_mg", + 1093: "sodium_mg", + 1162: "vitamin_c_mg", + 1092: "potassium_mg", +} + +_col_map = pl.DataFrame( + { + "nutrient_id": list(_NUTRIENT_ID_TO_COL.keys()), + "col": list(_NUTRIENT_ID_TO_COL.values()), + } +).with_columns(pl.col("nutrient_id").cast(pl.Int32)) + +_wide_nutrients = ( + food_nutrients.join(_col_map, on="nutrient_id", how="left").pivot( + index="fdc_id", on="col", values="amount" + ) +) + +foods_wide = ( + foods.join( + food_categories.select(["id", "description"]).rename( + {"description": "category"} + ), + left_on="food_category_id", + right_on="id", + how="left", + ).join(_wide_nutrients, on="fdc_id", how="left") +) + +# ── QueryChat ───────────────────────────────────────────────────────────────── + +qc = QueryChat( + foods, + "foods", + data_dict=Path(__file__).parent / "nutrition-data-dict.yaml", + greeting="", +) +qc.add_table(food_categories, 
"food_categories") +qc.add_table(nutrients, "nutrients") +qc.add_table(food_nutrients, "food_nutrients") +qc.add_table(food_portions, "food_portions") +qc.add_table(measure_units, "measure_units") + +_GREETING = shinychat.chat_greeting( + "## USDA Foundation Foods Explorer\n\n" + "Real nutrition data for **436 foods** across 19 categories β€” " + "macronutrients, minerals, vitamins, and serving sizes.\n\n" + "**Filter this view**\n\n" + 'Show only foods where fiber exceeds sugar\n\n' + 'High-protein, low-fat foods: protein > 20g and fat < 5g per 100g\n\n' + 'Foods higher in potassium than sodium\n\n' + "**Dig deeper**\n\n" + 'Which fruits or vegetables beat whole milk for calcium?\n\n' + 'Rank all foods by protein per calorie\n\n' + 'For 1 cup of oats, how much protein and fiber am I getting?\n\n' +) + +# ── App ─────────────────────────────────────────────────────────────────────── + + +def app_ui(request): + return ui.page_sidebar( + ui.sidebar( + qc.ui(greeting=_GREETING), + width=400, + height="100%", + fillable=True, + class_="querychat-sidebar", + ), + ui.layout_columns( + ui.value_box( + "Foods", + ui.output_text("n_foods"), + showcase=ui.tags.span("🍽️", style="font-size:2rem"), + theme="primary", + ), + ui.value_box( + "Avg Protein", + ui.output_text("avg_protein"), + showcase=ui.tags.span("πŸ₯©", style="font-size:2rem"), + theme="success", + ), + ui.value_box( + "Avg Fiber", + ui.output_text("avg_fiber"), + showcase=ui.tags.span("πŸ₯¦", style="font-size:2rem"), + theme="info", + ), + ui.value_box( + "Avg Calories", + ui.output_text("avg_calories"), + showcase=ui.tags.span("πŸ”₯", style="font-size:2rem"), + theme="warning", + ), + col_widths=[3, 3, 3, 3], + gap="1rem", + fill=False, + ), + ui.layout_columns( + ui.card( + ui.card_header(ui.output_text("protein_chart_title")), + output_widget("protein_chart"), + full_screen=True, + ), + ui.card( + ui.card_header("Avg protein by category (top 10)"), + output_widget("macro_chart"), + full_screen=True, + ), + ), 
+ ui.navset_card_underline( + *[ + ui.nav_panel(name, ui.output_data_frame(f"dt_{name}")) + for name in qc.table_names() + ], + id="table_tabs", + full_screen=True, + ), + title="USDA Foundation Foods", + fillable=True, + class_="bslib-page-dashboard", + ) + + +def server(input, output, session): + qc_vals = qc.server() + + @reactive.calc + def current_subset() -> pl.DataFrame: + queried = qc_vals.table("foods").df() + # queried may be polars or pandas depending on the data source + if hasattr(queried, "to_pandas"): # polars DataFrame + ids = queried["fdc_id"].to_list() + else: # pandas DataFrame + ids = queried["fdc_id"].tolist() + return foods_wide.filter(pl.col("fdc_id").is_in(ids)) + + @render.text + def n_foods(): + return str(current_subset().height) + + @render.text + def avg_protein(): + v = current_subset()["protein_g"].drop_nulls().mean() + return f"{v:.1f} g" if v is not None else "β€”" + + @render.text + def avg_fiber(): + v = current_subset()["fiber_g"].drop_nulls().mean() + return f"{v:.1f} g" if v is not None else "β€”" + + @render.text + def avg_calories(): + v = current_subset()["energy_kcal"].drop_nulls().mean() + return f"{v:.0f} kcal" if v is not None else "β€”" + + @render.text + def protein_chart_title(): + n = current_subset().filter(pl.col("protein_g").is_not_null()).height + shown = min(n, 15) + return f"Top {shown} foods by protein (g/100g)" + + @render_plotly + def protein_chart(): + df = ( + current_subset() + .filter(pl.col("protein_g").is_not_null()) + .sort("protein_g", descending=True) + .head(15) + .with_columns( + pl.col("description") + .str.slice(0, 35) + .str.replace(r"(.{35}).+", "${1}…") + .alias("label") + ) + ) + fig = px.bar( + df, + x="protein_g", + y="label", + orientation="h", + hover_data={"category": True, "label": False, "description": True}, + labels={"protein_g": "Protein (g/100g)", "label": ""}, + color_discrete_sequence=["#2196F3"], + ) + fig.update_layout( + showlegend=False, + yaxis={"categoryorder": "total 
ascending"}, + margin={"l": 10, "r": 40, "t": 10, "b": 40}, + ) + return fig + + @render_plotly + def macro_chart(): + subset = current_subset() + # Limit to top 10 categories by food count to keep the chart readable + top_cats = ( + subset.group_by("category") + .len() + .sort("len", descending=True) + .head(10)["category"] + ) + agg = ( + subset.filter(pl.col("category").is_in(top_cats)) + .group_by("category") + .agg(pl.col("protein_g").mean().alias("avg_protein")) + .sort("avg_protein", descending=True) + ) + fig = px.bar( + agg, + x="avg_protein", + y="category", + orientation="h", + labels={"avg_protein": "Avg protein (g/100g)", "category": ""}, + color_discrete_sequence=["#4CAF50"], + ) + fig.update_layout( + showlegend=False, + yaxis={"categoryorder": "total ascending"}, + margin={"l": 10, "r": 10, "t": 10, "b": 10}, + ) + return fig + + # Auto-switch tab when LLM queries a table + @reactive.effect + def _switch_tab(): + name = qc_vals.current_table() + if name is not None: + ui.update_navs("table_tabs", selected=name) + + # Register one data frame render per table. + # Value boxes and charts above remain tied to the `foods` table β€” they + # use foods-specific wide-format joins and are not generic per-table views. + def _make_dt_renderer(table_name: str): + @render.data_frame + def _renderer(): + return qc_vals.table(table_name).df() + + return _renderer + + for _tname in qc.table_names(): + output(id=f"dt_{_tname}")(_make_dt_renderer(_tname)) + + +app = App(app_ui, server) diff --git a/pkg-py/examples/11-multi-table-nutrition/data/extract_foodbank.R b/pkg-py/examples/11-multi-table-nutrition/data/extract_foodbank.R new file mode 100644 index 000000000..f3b575db8 --- /dev/null +++ b/pkg-py/examples/11-multi-table-nutrition/data/extract_foodbank.R @@ -0,0 +1,59 @@ +#!/usr/bin/env Rscript +# Extract USDA Foundation Foods data from the {foodbank} R package and write +# as parquet files for use in the multi-table-nutrition Shiny for Python app. 
+# +# Run from the pkg-py/examples directory: +# Rscript data/extract_foodbank.R +# +# Requires: foodbank (github::hadley/foodbank), nanoparquet +# pak::pkg_install(c("hadley/foodbank", "nanoparquet")) + +library(foodbank) +library(nanoparquet) + +script_dir <- tryCatch( + dirname(normalizePath(sys.frame(1)$ofile)), + error = function(e) getwd() +) +out_dir <- file.path(script_dir, "foodbank") +dir.create(out_dir, showWarnings = FALSE, recursive = TRUE) + +# Key nutrient IDs to include (a curated subset of the 477 available) +key_ids <- c( + 1008L, # Energy (kcal) + 1003L, # Protein (g) + 1004L, # Total lipid / fat (g) + 1005L, # Carbohydrate, by difference (g) + 1079L, # Fiber, total dietary (g) + 1063L, # Sugars, Total (g) + 1258L, # Fatty acids, total saturated (g) + 1087L, # Calcium, Ca (mg) + 1089L, # Iron, Fe (mg) + 1093L, # Sodium, Na (mg) + 1162L, # Vitamin C, total ascorbic acid (mg) + 1092L # Potassium, K (mg) +) + +write_parquet(food, file.path(out_dir, "foods.parquet")) +write_parquet(food_category, file.path(out_dir, "food_categories.parquet")) +write_parquet( + nutrient[nutrient$id %in% key_ids, c("id", "name", "unit_name")], + file.path(out_dir, "nutrients.parquet") +) +write_parquet( + food_nutrient[food_nutrient$nutrient_id %in% key_ids, + c("fdc_id", "nutrient_id", "amount")], + file.path(out_dir, "food_nutrients.parquet") +) +write_parquet( + food_portion[, c("fdc_id", "seq_num", "amount", "measure_unit_id", + "gram_weight", "modifier")], + file.path(out_dir, "food_portions.parquet") +) +used_unit_ids <- unique(food_portion$measure_unit_id) +write_parquet( + measure_unit[measure_unit$id %in% used_unit_ids, ], + file.path(out_dir, "measure_units.parquet") +) + +cat("Wrote parquet files to", out_dir, "\n") diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_categories.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_categories.parquet new file mode 100644 index 000000000..10f6d585a Binary files /dev/null 
and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_categories.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_nutrients.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_nutrients.parquet new file mode 100644 index 000000000..87edd01ac Binary files /dev/null and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_nutrients.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_portions.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_portions.parquet new file mode 100644 index 000000000..b9fc1b411 Binary files /dev/null and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/food_portions.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/foods.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/foods.parquet new file mode 100644 index 000000000..c52fd37fe Binary files /dev/null and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/foods.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/measure_units.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/measure_units.parquet new file mode 100644 index 000000000..7a4019495 Binary files /dev/null and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/measure_units.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/data/foodbank/nutrients.parquet b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/nutrients.parquet new file mode 100644 index 000000000..23c533adb Binary files /dev/null and b/pkg-py/examples/11-multi-table-nutrition/data/foodbank/nutrients.parquet differ diff --git a/pkg-py/examples/11-multi-table-nutrition/nutrition-data-dict.yaml b/pkg-py/examples/11-multi-table-nutrition/nutrition-data-dict.yaml new file mode 100644 index 000000000..a8455e6fc --- /dev/null +++ b/pkg-py/examples/11-multi-table-nutrition/nutrition-data-dict.yaml @@ -0,0 
+1,161 @@ +version: "0.2.0" + +tables: + foods: + description: > + One row per USDA Foundation Food item. Nutrient values are not stored here β€” + see food_nutrients. All nutrient amounts in food_nutrients are per 100g of the food. + columns: + - name: fdc_id + type: number(id) + constraints: [primary_key] + description: Unique food identifier from USDA FoodData Central. + - name: description + type: string + description: Full name of the food item (e.g., "Broccoli, raw"). + - name: food_category_id + type: number(id) + constraints: [foreign_key] + description: Category of the food. Joins to food_categories.id. + - name: publication_date + type: date + description: Date the food record was published by USDA. + + food_categories: + description: Lookup table of USDA food categories. One row per category. + columns: + - name: id + type: number(id) + constraints: [primary_key] + description: Unique category identifier. + - name: code + type: number + description: USDA numeric category code (e.g., 1100 for Vegetables). + - name: description + type: string + description: Human-readable category name (e.g., "Vegetables and Vegetable Products"). + + nutrients: + description: > + Registry of the 12 key nutritional compounds tracked in this dataset. + One row per nutrient type. See the glossary for the full list of nutrient IDs. + columns: + - name: id + type: number(id) + constraints: [primary_key] + description: USDA nutrient identifier (e.g., 1003 for Protein). + - name: name + type: string + description: > + Official USDA nutrient name (e.g., "Protein", "Total lipid (fat)", + "Carbohydrate, by difference"). + - name: unit_name + type: enum + values: [G, KCAL, MG] + description: Unit of measurement. G = grams, KCAL = kilocalories, MG = milligrams. + + food_nutrients: + description: > + Nutrient content per 100g of food. One row per food–nutrient combination. + Only the 12 key nutrients are included (see nutrients table and glossary). 
+ To get nutrient values for a food, join on fdc_id. + To get nutrient names and units, join nutrients on nutrient_id. + columns: + - name: fdc_id + type: number(id) + constraints: [foreign_key] + description: Identifies the food. Joins to foods.fdc_id. + - name: nutrient_id + type: number(id) + constraints: [foreign_key] + description: Identifies the nutrient. Joins to nutrients.id. + - name: amount + type: number(quantity) + description: > + Amount of the nutrient per 100g of food, in the unit given by + nutrients.unit_name. May be null if the nutrient was not measured + for this food. + + food_portions: + description: > + Common serving sizes for foods. One row per portion definition. + A single food may have multiple portions (e.g., "1 cup" and "1 oz"). + Not all foods have portion data β€” only 116 of the 436 foods are covered. + columns: + - name: fdc_id + type: number(id) + constraints: [foreign_key] + description: Identifies the food. Joins to foods.fdc_id. + - name: seq_num + type: number + description: Sequence number ordering multiple portions for the same food. + - name: amount + type: number(quantity) + description: > + The number of measure units in this portion (e.g., 1.0 for "1 cup", + 2.0 for "2 tablespoons"). + - name: measure_unit_id + type: number(id) + constraints: [foreign_key] + description: Unit of the portion. Joins to measure_units.id. + - name: gram_weight + type: number(quantity) + description: Weight in grams of this portion (e.g., 240 for 1 cup of milk). + - name: modifier + type: string + description: > + Optional preparation note for the portion (e.g., "chopped", "drained", + "cooked"). Null when no modifier applies. + + measure_units: + description: > + Lookup table of measurement unit names used in food_portions. + Only units that appear in food_portions are included (~32 of the 123 total). + columns: + - name: id + type: number(id) + constraints: [primary_key] + description: Unique unit identifier. 
+ - name: name + type: string + description: > + Human-readable unit name (e.g., "cup", "tablespoon", "oz", "slice", + "piece", "package"). + +relationships: + - description: Each food belongs to one category. + cardinality: many-to-one + join: foods.food_category_id = food_categories.id + - description: Each food has one row per tracked nutrient. + cardinality: one-to-many + join: foods.fdc_id = food_nutrients.fdc_id + - description: Each nutrient type appears in many food measurements. + cardinality: one-to-many + join: nutrients.id = food_nutrients.nutrient_id + - description: Each food may have one or more common portion sizes. + cardinality: one-to-many + join: foods.fdc_id = food_portions.fdc_id + - description: Each portion references a measurement unit. + cardinality: many-to-one + join: food_portions.measure_unit_id = measure_units.id + +glossary: + per 100g: All nutrient amounts in food_nutrients are normalized to a 100-gram serving for fair comparison across foods. + fdc_id: USDA FoodData Central identifier β€” the primary key for food items in this dataset. + energy: Total caloric value, measured in kilocalories (kcal). Nutrient ID 1008. + macronutrient: One of the three main energy-providing nutrients β€” protein (ID 1003), fat (ID 1004), or carbohydrates (ID 1005). + nutrient IDs: > + Key nutrient IDs in this dataset: + 1003 = Protein (G), + 1004 = Total lipid / fat (G), + 1005 = Carbohydrate by difference (G), + 1008 = Energy (KCAL), + 1063 = Sugars Total (G), + 1079 = Fiber total dietary (G), + 1258 = Fatty acids total saturated (G), + 1087 = Calcium Ca (MG), + 1089 = Iron Fe (MG), + 1093 = Sodium Na (MG), + 1162 = Vitamin C total ascorbic acid (MG), + 1092 = Potassium K (MG). + gram_weight: The actual weight in grams of a described portion. Use this to convert per-100g nutrient values to per-serving values by multiplying by (gram_weight / 100). 
diff --git a/pkg-py/examples/12-multi-table-express.py b/pkg-py/examples/12-multi-table-express.py new file mode 100644 index 000000000..1f87e0828 --- /dev/null +++ b/pkg-py/examples/12-multi-table-express.py @@ -0,0 +1,54 @@ +"""Minimal multi-table querychat example using Shiny Express. + +Two related tables (orders + customers) are registered with a single QueryChat +instance. The LLM can query either table or write joins across them. +Per-table filtered data is accessed with `qc.table("name").df()`. + +Usage: + cd pkg-py + uv run shiny run examples/12-multi-table-express.py +""" + +import pandas as pd +from shiny.express import render, ui + +from querychat.express import QueryChat + +orders = pd.DataFrame( + { + "order_id": [1, 2, 3, 4, 5], + "customer_id": [101, 102, 101, 103, 102], + "amount": [250.0, 180.0, 320.0, 90.0, 450.0], + "status": ["shipped", "pending", "shipped", "delivered", "pending"], + } +) + +customers = pd.DataFrame( + { + "customer_id": [101, 102, 103], + "name": ["Alice", "Bob", "Carol"], + "city": ["Boston", "Chicago", "Denver"], + } +) + +qc = QueryChat(orders, "orders") +qc.add_table(customers, "customers") +qc.sidebar() + +with ui.navset_card_underline(): + with ui.nav_panel("Orders"): + + @render.data_frame + def orders_table(): + return qc.table("orders").df() + + with ui.nav_panel("Customers"): + + @render.data_frame + def customers_table(): + return qc.table("customers").df() + +ui.page_opts( + title="Orders & Customers", + fillable=True, +) diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 0e3eaa5f5..953f73840 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,9 +1,11 @@ +from ._data_dict import DataDict from ._deprecated import greeting, init, sidebar, system_prompt from ._deprecated import mod_server as server from ._deprecated import mod_ui as ui from ._shiny import QueryChat __all__ = ( + "DataDict", "QueryChat", # TODO(lifecycle): Remove these 
deprecated functions when we reach v1.0 "greeting", diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py index da9b57da9..5d4c4df10 100644 --- a/pkg-py/src/querychat/_dash.py +++ b/pkg-py/src/querychat/_dash.py @@ -8,12 +8,11 @@ from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT, IntoLazyFrameT from ._dash_ui import IDs, card_ui, chat_container_ui, chat_messages_ui -from ._querychat_base import TOOL_GROUPS, QueryChatBase +from ._querychat_base import TOOL_GROUPS, StateDictQueryChat from ._querychat_core import ( GREETING_PROMPT, AppState, AppStateDict, - StateDictAccessorMixin, create_app_state, stream_response_async, ) @@ -33,8 +32,10 @@ import dash from dash import html + from ._data_dict import DataDict -class QueryChat(QueryChatBase[IntoFrameT], StateDictAccessorMixin[IntoFrameT]): + +class QueryChat(StateDictQueryChat[IntoFrameT]): """ QueryChat for Dash applications. @@ -92,16 +93,17 @@ def update_sql(state): @overload def __init__( self: QueryChat[Any], - data_source: None, - table_name: str, + data_source: None = None, + table_name: str | None = None, *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ) -> None: ... @@ -114,10 +116,11 @@ def __init__( greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ) -> None: ... @@ -130,10 +133,11 @@ def __init__( greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ) -> None: ... @@ -146,10 +150,11 @@ def __init__( greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ) -> None: ... @@ -162,25 +167,27 @@ def __init__( greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ) -> None: ... def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None, - table_name: str, + data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None = None, + table_name: str | None = None, *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | PathType] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | PathType | None = None, extra_instructions: Optional[str | PathType] = None, prompt_template: Optional[str | PathType] = None, + categorical_threshold: int = 20, + data_description: Optional[str | PathType] = None, storage_type: Literal["memory", "session", "local"] = "memory", ): super().__init__( @@ -190,12 +197,13 @@ def __init__( client=client, tools=tools, data_description=data_description, + data_dict=data_dict, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, ) self._storage_type: Literal["memory", "session", "local"] = storage_type - self._ids = IDs.from_table_name(table_name) + self._ids = IDs.from_table_name(table_name or "querychat") self._initialized_apps: set[int] = set() @property @@ -217,12 +225,18 @@ def app(self) -> dash.Dash: A Dash app ready to run. 
""" - data_source = self._require_data_source("app") + self._require_initialized("app") + if len(self._data_sources) > 1: + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + raise RuntimeError( + f"app() does not support multiple tables ({table_list}). " + "Build a custom layout using ui() and table('name') instead." + ) import dash_bootstrap_components as dbc import dash - table_name = data_source.table_name + table_name = next(iter(self._data_sources)) app = dash.Dash( __name__, @@ -235,7 +249,7 @@ def app(self) -> dash.Dash: register_app_callbacks( app, self._ids, - data_source.table_name, + table_name, self._deserialize_state, ) @@ -276,13 +290,14 @@ def ui( ... return f"Current SQL: {sql}" """ - data_source = self._require_data_source("ui") + self._require_initialized("ui") from dash import dcc, html initial_state = create_app_state( - data_source, - self._client_factory, - self.greeting, + data_sources=dict(self._data_sources), + client_factory=self._client_factory, + greeting=self.greeting, + query_executor=self._require_query_executor("ui"), ) return html.Div( diff --git a/pkg-py/src/querychat/_data_dict.py b/pkg-py/src/querychat/_data_dict.py new file mode 100644 index 000000000..6fa909c8c --- /dev/null +++ b/pkg-py/src/querychat/_data_dict.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +if TYPE_CHECKING: + from ._datasource import ColumnMeta + from ._query_executor import QueryExecutor + + +class ColumnRange(BaseModel): + """Inclusive numeric range for a column, used instead of live min/max queries.""" + + min: Any = None + max: Any = None + + +class ColumnSpec(BaseModel): + """ + Per-column metadata entry in a :class:`DataDict`. + + All fields are optional. Only ``name`` is required, and is used to match + this spec against columns returned by the data source. 
+ + Parameters + ---------- + name + Column name as it appears in the data source. + type + Human-readable type override (e.g. ``"date"``, ``"currency"``). When + supplied, this replaces the inferred SQL type in the LLM schema view. + constraints + Free-text constraints conveyed to the LLM (e.g. ``"non-negative"``). + description + Short description forwarded verbatim to the LLM's schema view. + details + Longer narrative about the column, used only in the on-demand + ``get_schema`` tool response. + units + Unit label (e.g. ``"kg"``, ``"USD"``), included in the schema view. + values + Exhaustive list of valid values. Replaces categorical inference for + this column β€” querychat will not query the data source for distinct + values when this is set. + range + Inclusive min/max bounds. Replaces live min/max statistics queries + when set. + examples + Representative sample values shown to the LLM as context. + + """ + + name: str + type: str | None = None + constraints: list[str] = [] + description: str | None = None + details: str | None = None + units: str | None = None + values: list[Any] | None = None + range: ColumnRange | None = None + examples: list[Any] | None = None + + +class TableSpec(BaseModel): + """ + Metadata for a single table in a :class:`DataDict`. + + Parameters + ---------- + description + Short description of the table, forwarded to the LLM's schema view. + details + Longer narrative shown only in the on-demand ``get_schema`` tool + response. + columns + Per-column specifications. Columns not listed here are documented + using live statistics inferred from the data. + + """ + + description: str | None = None + details: str | None = None + columns: list[ColumnSpec] = [] + + +class RelationshipSpec(BaseModel): + """ + A declared relationship between two tables. + + Parameters + ---------- + description + Human-readable description of the relationship. + cardinality + Cardinality string (e.g. ``"one-to-many"``). 
+ join + SQL JOIN clause or expression that links the tables. + + """ + + description: str | None = None + cardinality: str | None = None + join: str + + +class DataDict(BaseModel): + """ + A data dictionary providing rich per-table and per-column metadata. + + Pass a ``DataDict`` to ``QueryChat`` (or load one from YAML via + :meth:`from_yaml`) to give the LLM better context about your data without + querying the data source for statistics at startup. + + For columns listed in a ``DataDict``: + + * ``values`` replaces categorical inference (no ``SELECT DISTINCT`` query). + * ``range`` replaces live min/max statistics queries. + * ``description`` is forwarded verbatim to the LLM's schema view. + + Columns not listed fall back to the normal live-statistics path. + + Parameters + ---------- + name + Short identifier for this dictionary's domain (e.g. ``"sales"``). + Used as the ``name`` attribute on the ```` tag in the system + prompt. When loading from YAML via :meth:`from_yaml`, defaults to the + file stem if not set explicitly. + description + One-line summary of the domain, shown alongside ``name`` in the + system prompt. + tables + Per-table metadata, keyed by table name. Each value is a + :class:`TableSpec` with optional description and column specs. + Table names must match those registered with ``QueryChat``. + relationships + Cross-table relationship declarations. Useful context for multi-table + apps where the LLM needs to know how tables join. + glossary + Domain-specific term definitions passed to the LLM as context + (e.g. ``{"ARR": "Annual Recurring Revenue"}``). 
+ + Examples + -------- + Load from a YAML file: + + ```python + from querychat import QueryChat, DataDict + + qc = QueryChat(df, "sales", data_dict=DataDict.from_yaml("data_dict.yaml")) + ``` + + Or pass a path directly and let QueryChat load it: + + ```python + qc = QueryChat(df, "sales", data_dict="data_dict.yaml") + ``` + + """ + + name: str | None = None + description: str | None = None + tables: dict[str, TableSpec] = {} + relationships: list[RelationshipSpec] = [] + glossary: dict[str, str] = {} + + def to_prompt_dict(self) -> dict[str, Any]: + """Return a filtered dict for the system prompt (excludes per-column details).""" + result: dict[str, Any] = {} + if self.name is not None: + result["name"] = self.name + if self.description is not None: + result["description"] = self.description + if self.tables: + result["tables"] = { + name: ({"description": spec.description} if spec.description else None) + for name, spec in self.tables.items() + } + if self.relationships: + result["relationships"] = [ + {k: v for k, v in rel.model_dump().items() if v is not None} + for rel in self.relationships + ] + if self.glossary: + result["glossary"] = self.glossary + return result + + def get_table_schema( + self, + table_name: str, + executor: QueryExecutor, + categorical_threshold: int, + ) -> list[ColumnMeta]: + # Get authoritative column names + types via cheap LIMIT 0 + metas: list[ColumnMeta] = executor.get_column_metas(table_name) + + # Build lookup from data_dict for this table + table_spec = self.tables.get(table_name) + documented: dict[str, ColumnSpec] = ( + {col.name: col for col in table_spec.columns} if table_spec else {} + ) + + undocumented: list[ColumnMeta] = [] + for meta in metas: + spec = documented.get(meta.name) + if spec is not None: + if spec.type is not None: + meta.sql_type = spec.type + if spec.range is not None: + meta.min_val = spec.range.min + meta.max_val = spec.range.max + if spec.values is not None: + meta.categories = [str(v) for v in 
spec.values] + if spec.description is not None: + meta.description = spec.description + if spec.units is not None: + meta.units = spec.units + if spec.constraints: + meta.constraints = list(spec.constraints) + else: + undocumented.append(meta) + + if undocumented: + executor.populate_column_stats(table_name, undocumented, categorical_threshold) + + return metas + + @classmethod + def from_yaml(cls, path: Path | str) -> DataDict: + """ + Load a :class:`DataDict` from a YAML file. + + Parameters + ---------- + path + Path to the YAML file. + + """ + import yaml + + path = Path(path) + with path.open() as f: + data = yaml.safe_load(f) or {} + dd = cls.model_validate(data) + if dd.name is None: + dd = dd.model_copy(update={"name": path.stem}) + return dd diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index ed2c8ecd4..55af587db 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -51,14 +51,30 @@ class ColumnMeta: categories: list[str] = field(default_factory=list) """Unique values for text columns below the categorical threshold.""" + description: str | None = None + """Optional human-readable description of the column.""" + + units: str | None = None + """Unit label (e.g. 'kg', 'USD').""" + + constraints: list[str] = field(default_factory=list) + """Free-text constraints (e.g. 
'non-negative').""" + def format_schema(table_name: str, columns: list[ColumnMeta]) -> str: """Format column metadata into schema string.""" lines = [f"Table: {table_name}", "Columns:"] for col in columns: - lines.append(f"- {col.name} ({col.sql_type})") - + header = f"- {col.name} ({col.sql_type})" + if col.units: + header += f" [{col.units}]" + lines.append(header) + + if col.description: + lines.append(f" Description: {col.description}") + if col.constraints: + lines.append(f" Constraints: {', '.join(col.constraints)}") if col.kind in ("numeric", "date") and col.min_val is not None and col.max_val is not None: lines.append(f" Range: {col.min_val} to {col.max_val}") elif col.categories: @@ -152,19 +168,6 @@ def duckdb_column_stats( pass -def duckdb_get_schema( - conn: duckdb.DuckDBPyConnection, - table_name: str, - categorical_threshold: int, -) -> str: - """Generate schema string from a DuckDB connection and table name.""" - result = conn.execute(f'SELECT * FROM "{table_name}" LIMIT 0') - columns = [ - duckdb_column_meta(desc[0], desc[1]) for desc in result.description - ] - duckdb_column_stats(conn, table_name, columns, categorical_threshold) - return format_schema(table_name, columns) - def duckdb_lock_down(conn: duckdb.DuckDBPyConnection) -> None: """Lock down a DuckDB connection to prevent LLM-generated SQL from accessing external resources.""" @@ -220,6 +223,18 @@ def get_schema(self, *, categorical_threshold: int) -> str: """ ... + @abstractmethod + def get_column_metas(self) -> list[ColumnMeta]: + """Return column names and types without running stats queries.""" + ... + + @abstractmethod + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + """Populate min/max/categories on the given ColumnMeta list in place.""" + ... 
+ @abstractmethod def execute_query(self, query: str) -> IntoFrameT: """ @@ -368,7 +383,18 @@ def get_schema(self, *, categorical_threshold: int) -> str: String describing the schema """ - return duckdb_get_schema(self._conn, self.table_name, categorical_threshold) + metas = self.get_column_metas() + self.populate_column_stats(metas, categorical_threshold) + return format_schema(self.table_name, metas) + + def get_column_metas(self) -> list[ColumnMeta]: + result = self._conn.execute(f'SELECT * FROM "{self.table_name}" LIMIT 0') + return [duckdb_column_meta(desc[0], desc[1]) for desc in result.description] + + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + duckdb_column_stats(self._conn, self.table_name, columns, categorical_threshold) def execute_query(self, query: str) -> IntoDataFrameT: """ @@ -536,6 +562,11 @@ def get_db_type(self) -> str: """ return self._engine.dialect.name.upper().replace(" SQL", "") + @property + def engine(self) -> Engine: + """The SQLAlchemy engine for this data source.""" + return self._engine + def get_schema(self, *, categorical_threshold: int) -> str: """ Generate schema information from database table. 
@@ -552,12 +583,20 @@ def get_schema(self, *, categorical_threshold: int) -> str: String describing the schema """ - columns = [ + metas = self.get_column_metas() + self.populate_column_stats(metas, categorical_threshold) + return format_schema(self.table_name, metas) + + def get_column_metas(self) -> list[ColumnMeta]: + return [ self._make_column_meta(col["name"], col["type"]) for col in self._columns_info ] + + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: self._add_column_stats(columns, categorical_threshold) - return format_schema(self.table_name, columns) def get_semantic_views_description(self) -> str: """Get information about semantic views (if any) for the system prompt.""" @@ -827,14 +866,17 @@ def get_db_type(self) -> str: def get_schema(self, *, categorical_threshold: int) -> str: """Generate schema information from LazyFrame using lazy aggregates.""" - # Build column metadata (classification happens here) - columns = [ - self._make_column_meta(name, dtype) for name, dtype in self._schema.items() - ] + metas = self.get_column_metas() + self.populate_column_stats(metas, categorical_threshold) + return format_schema(self.table_name, metas) + + def get_column_metas(self) -> list[ColumnMeta]: + return [self._make_column_meta(name, dtype) for name, dtype in self._schema.items()] - # Add stats to the metadata and format schema string + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: self._add_column_stats(columns, self._lf, categorical_threshold) - return format_schema(self.table_name, columns) def execute_query(self, query: str) -> pl.LazyFrame: """ @@ -1029,12 +1071,25 @@ def __init__(self, table: ibis.Table, table_name: str): def get_db_type(self) -> str: return self._backend.name + @property + def backend(self) -> SQLBackend: + """The Ibis SQL backend for this data source.""" + return self._backend + def get_schema(self, *, categorical_threshold: 
int) -> str: - columns = [ + metas = self.get_column_metas() + self.populate_column_stats(metas, categorical_threshold) + return format_schema(self.table_name, metas) + + def get_column_metas(self) -> list[ColumnMeta]: + return [ self._make_column_meta(name, dtype) for name, dtype in self._schema.items() ] + + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: self._add_column_stats(columns, self._table, categorical_threshold) - return format_schema(self.table_name, columns) def get_semantic_views_description(self) -> str: """Get information about semantic views (if any) for the system prompt.""" diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index cc0067084..18a87cb88 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -10,11 +10,10 @@ if TYPE_CHECKING: import narwhals.stable.v1 as nw -from ._querychat_base import TOOL_GROUPS, QueryChatBase +from ._querychat_base import TOOL_GROUPS, StateDictQueryChat from ._querychat_core import ( GREETING_PROMPT, AppStateDict, - StateDictAccessorMixin, create_app_state, stream_response, ) @@ -31,8 +30,10 @@ import gradio as gr + from ._data_dict import DataDict -class QueryChat(QueryChatBase[IntoFrameT], StateDictAccessorMixin[IntoFrameT]): + +class QueryChat(StateDictQueryChat[IntoFrameT]): """ QueryChat for Gradio applications. @@ -86,16 +87,17 @@ def update_outputs(state_dict): @overload def __init__( self: QueryChat[Any], - data_source: None, - table_name: str, + data_source: None = None, + table_name: str | None = None, *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -107,10 +109,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -122,10 +125,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -137,10 +141,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... 
@overload @@ -152,24 +157,26 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None, - table_name: str, + data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None = None, + table_name: str | None = None, *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ): super().__init__( data_source, @@ -178,6 +185,7 @@ def __init__( client=client, tools=tools, data_description=data_description, + data_dict=data_dict, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, @@ -247,11 +255,14 @@ def ui(self) -> gr.State: >>> app.launch(css=qc.css, head=qc.head) """ - data_source = self._require_data_source("ui") + self._require_initialized("ui") import gradio as gr initial_state = create_app_state( - data_source, self._client_factory, self.greeting + data_sources=dict(self._data_sources), + client_factory=self._client_factory, + greeting=self.greeting, + query_executor=self._require_query_executor("ui"), ) state_holder = gr.State(value=initial_state.to_dict()) @@ -328,12 
+339,12 @@ def app(self) -> GradioBlocksWrapper: querychat CSS/JS at launch time for Gradio 6.0+ compatibility. """ - data_source = self._require_data_source("app") + self._require_initialized("app") from gradio.themes import Soft import gradio as gr - table_name = data_source.table_name + table_name = next(iter(self._data_sources)) with gr.Blocks( title=f"querychat with {table_name}", @@ -368,17 +379,14 @@ def app(self) -> GradioBlocksWrapper: def update_displays(state_dict: AppStateDict): """Update SQL and data displays based on state.""" - title = state_dict.get("title") if state_dict else None - error = state_dict.get("error") if state_dict else None + state = self._deserialize_state(state_dict) + df = state.get_current_data() + title = state.title + error = state.error sql_title_text = f"### {title or 'SQL Query'}" - sql_code = ( - state_dict.get("sql") - if state_dict and state_dict.get("sql") - else f"SELECT * FROM {table_name}" - ) + sql_code = state.get_display_sql() - df = self.df(state_dict) nw_df = as_narwhals(df) nrow, ncol = nw_df.shape native_df = nw_df.to_native() diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index fc484c9c0..25cabdf00 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -9,6 +9,7 @@ "download", "funnel-fill", "graph-up", + "search", "terminal-fill", "table", ] @@ -30,6 +31,7 @@ def bs_icon(name: ICON_NAMES, cls: str = "") -> ui.HTML: "chevron-down": '', "download": '', "funnel-fill": '', + "search": '', "graph-up": '', "terminal-fill": '', "table": '', diff --git a/pkg-py/src/querychat/_pin_source.py b/pkg-py/src/querychat/_pin_source.py index 58240f8cb..fd62ae9ef 100644 --- a/pkg-py/src/querychat/_pin_source.py +++ b/pkg-py/src/querychat/_pin_source.py @@ -7,10 +7,13 @@ import narwhals.stable.v1 as nw from ._datasource import ( + ColumnMeta, DataSource, MissingColumnsError, - duckdb_get_schema, + duckdb_column_meta, + duckdb_column_stats, duckdb_lock_down, + 
format_schema, ) from ._utils import check_query @@ -186,7 +189,18 @@ def get_db_type(self) -> str: return "DuckDB" def get_schema(self, *, categorical_threshold: int) -> str: - return duckdb_get_schema(self._conn, self.table_name, categorical_threshold) + metas = self.get_column_metas() + self.populate_column_stats(metas, categorical_threshold) + return format_schema(self.table_name, metas) + + def get_column_metas(self) -> list[ColumnMeta]: + result = self._conn.execute(f'SELECT * FROM "{self.table_name}" LIMIT 0') + return [duckdb_column_meta(desc[0], desc[1]) for desc in result.description] + + def populate_column_stats( + self, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + duckdb_column_stats(self._conn, self.table_name, columns, categorical_threshold) def execute_query(self, query: str) -> nw.DataFrame: check_query(query) diff --git a/pkg-py/src/querychat/_query_executor.py b/pkg-py/src/querychat/_query_executor.py new file mode 100644 index 000000000..ab5a46b02 --- /dev/null +++ b/pkg-py/src/querychat/_query_executor.py @@ -0,0 +1,306 @@ +"""QueryExecutor abstraction for cross-table query execution.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import duckdb +import narwhals.stable.v1 as nw + +from ._datasource import ( + ColumnMeta, + MissingColumnsError, + duckdb_column_meta, + duckdb_column_stats, + duckdb_lock_down, + format_schema, +) +from ._utils import check_query + +if TYPE_CHECKING: + from ._datasource import DataFrameSource, DataSource, PolarsLazySource + + +class QueryExecutor(ABC): + """Thin abstraction that tools use for query execution and validation.""" + + @abstractmethod + def execute_query(self, query: str) -> Any: ... + + @abstractmethod + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: ... + + @abstractmethod + def get_db_type(self) -> str: ... 
+ + @abstractmethod + def cleanup(self) -> None: ... + + @abstractmethod + def get_column_metas(self, table_name: str) -> list[ColumnMeta]: ... + + @abstractmethod + def populate_column_stats( + self, table_name: str, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: ... + + def get_column_details(self, table_name: str, categorical_threshold: int) -> list[ColumnMeta]: + metas = self.get_column_metas(table_name) + self.populate_column_stats(table_name, metas, categorical_threshold) + return metas + + def get_schema(self, table_name: str, categorical_threshold: int) -> str: + return format_schema(table_name, self.get_column_details(table_name, categorical_threshold)) + + @staticmethod + def _validate_missing_columns( + result_columns: set[str], expected_columns: list[str] + ) -> None: + missing = set(expected_columns) - result_columns + if missing: + missing_list = ", ".join(f"'{c}'" for c in sorted(missing)) + original_list = ", ".join(f"'{c}'" for c in expected_columns) + raise MissingColumnsError( + f"Query result missing required columns: {missing_list}. " + f"The query must return all original table columns. 
" + f"Original columns: {original_list}" + ) + + +class DuckDBExecutor(QueryExecutor): + """Shared DuckDB connection for multi-table DataFrameSource queries.""" + + def __init__(self, sources: dict[str, DataFrameSource]): + self._df_lib = get_shared_dataframe_backend(sources) + self._conn = duckdb.connect(database=":memory:") + + for name, source in sources.items(): + self._conn.register(name, source.get_data()) + + # Cache column names per table before lockdown + self._table_columns: dict[str, list[str]] = {} + for name in sources: + result = self._conn.execute(f'SELECT * FROM "{name}" LIMIT 0') + self._table_columns[name] = [desc[0] for desc in result.description] + + duckdb_lock_down(self._conn) + + def execute_query(self, query: str) -> Any: + check_query(query) + result = self._conn.execute(query) + return self._convert_result(result) + + def _convert_result(self, result: duckdb.DuckDBPyConnection) -> Any: + if self._df_lib == "polars": + return result.pl() + elif self._df_lib == "pandas": + return result.df() + elif self._df_lib == "pyarrow": + return result.fetch_arrow_table() + else: + raise ValueError( + f"Unsupported DataFrame backend: '{self._df_lib}'. 
" + "Supported backends are: polars, pandas, pyarrow" + ) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + check_query(query) + result = self._conn.execute(f"{query} LIMIT 1") + + if require_all_columns: + result_columns = {desc[0] for desc in result.description} + self._validate_missing_columns(result_columns, self._table_columns[table_name]) + + def get_db_type(self) -> str: + return "DuckDB" + + def cleanup(self) -> None: + if self._conn: + self._conn.close() + + def get_column_metas(self, table_name: str) -> list[ColumnMeta]: + result = self._conn.execute(f'SELECT * FROM "{table_name}" LIMIT 0') + return [duckdb_column_meta(desc[0], desc[1]) for desc in result.description] + + def populate_column_stats( + self, table_name: str, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + duckdb_column_stats(self._conn, table_name, columns, categorical_threshold) + + +class PolarsSQLExecutor(QueryExecutor): + """Shared Polars SQLContext for multi-table PolarsLazySource queries.""" + + def __init__(self, sources: dict[str, PolarsLazySource]): + import polars as pl + + frames = {name: source.get_data() for name, source in sources.items()} + self._ctx = pl.SQLContext(frames) + self._sources = sources # stored for schema delegation + + self._table_columns: dict[str, list[str]] = {} + for name, source in sources.items(): + self._table_columns[name] = list(source.get_data().collect_schema().keys()) + + def execute_query(self, query: str) -> Any: + check_query(query) + return self._ctx.execute(query) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + check_query(query) + test_lf = self._ctx.execute(f"SELECT * FROM ({query}) AS subquery LIMIT 1") + test_lf.collect() + + if require_all_columns: + full_lf = self._ctx.execute(query) + result_columns = set(full_lf.collect_schema().keys()) + self._validate_missing_columns(result_columns, 
self._table_columns[table_name]) + + def get_db_type(self) -> str: + return "Polars" + + def cleanup(self) -> None: + pass + + def get_column_metas(self, table_name: str) -> list[ColumnMeta]: + return self._sources[table_name].get_column_metas() + + def populate_column_stats( + self, table_name: str, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + self._sources[table_name].populate_column_stats(columns, categorical_threshold) + + +class DataSourceExecutor(QueryExecutor): + """ + Wraps existing DataSource(s) for backends that already share a connection. + + Used for single-table mode (any source type) and multi-table SQLAlchemy/Ibis + where all sources share the same database backend. + """ + + def __init__(self, data_sources: dict[str, DataSource]): + validate_source_group_compatibility(data_sources) + self._data_sources = data_sources + self._primary = next(iter(data_sources.values())) + + def execute_query(self, query: str) -> Any: + return self._primary.execute_query(query) + + def test_query( + self, query: str, *, table_name: str, require_all_columns: bool = False + ) -> None: + self._data_sources[table_name].test_query( + query, require_all_columns=require_all_columns + ) + + def get_db_type(self) -> str: + return self._primary.get_db_type() + + def cleanup(self) -> None: + pass + + def get_column_metas(self, table_name: str) -> list[ColumnMeta]: + return self._data_sources[table_name].get_column_metas() + + def populate_column_stats( + self, table_name: str, columns: list[ColumnMeta], categorical_threshold: int + ) -> None: + self._data_sources[table_name].populate_column_stats(columns, categorical_threshold) + + +def get_shared_dataframe_backend(sources: dict[str, DataFrameSource]) -> str: + """Return the shared backend name, rejecting mixed DataFrameSource backends.""" + source_items = iter(sources.items()) + _, first_source = next(source_items) + shared_lib = get_dataframe_backend_name(first_source) + + for name, source in 
source_items: + source_lib = get_dataframe_backend_name(source) + if source_lib != shared_lib: + raise ValueError( + f"Cannot add table '{name}': all DataFrameSources must use " + f"the same DataFrame backend. " + f"Existing tables use {shared_lib}, new table uses {source_lib}." + ) + + return shared_lib + + +def validate_source_group_compatibility(data_sources: dict[str, DataSource]) -> None: + """Validate that a group of sources satisfies shared executor constraints.""" + existing: dict[str, DataSource] = {} + for name, source in data_sources.items(): + check_source_compatibility(existing, source, name) + existing[name] = source + + +def check_source_compatibility( + existing: dict[str, DataSource], + new_source: DataSource, + new_name: str, +) -> None: + """Validate that a new source is compatible with existing sources.""" + if not existing: + return + + from ._datasource import ( + DataFrameSource, + IbisSource, + SQLAlchemySource, + ) + + first_source = next(iter(existing.values())) + + if type(new_source) is not type(first_source): + raise ValueError( + f"Cannot add {type(new_source).__name__} table '{new_name}': " + f"all tables must be the same type. " + f"Existing tables use {type(first_source).__name__}." + ) + + if isinstance(new_source, DataFrameSource) and isinstance( + first_source, DataFrameSource + ): + new_lib = get_dataframe_backend_name(new_source) + existing_lib = get_dataframe_backend_name(first_source) + if new_lib != existing_lib: + raise ValueError( + f"Cannot add table '{new_name}': all DataFrameSources must use " + f"the same DataFrame backend. " + f"Existing tables use {existing_lib}, new table uses {new_lib}." + ) + + if ( + isinstance(new_source, SQLAlchemySource) + and isinstance(first_source, SQLAlchemySource) + and new_source.engine is not first_source.engine + ): + raise ValueError( + f"Cannot add table '{new_name}': all SQLAlchemy tables must " + f"share the same Engine instance." 
+ ) + + if ( + isinstance(new_source, IbisSource) + and isinstance(first_source, IbisSource) + and new_source.backend is not first_source.backend + ): + raise ValueError( + f"Cannot add table '{new_name}': all Ibis tables must " + f"share the same backend instance." + ) + + +def get_dataframe_backend_name(source: DataFrameSource) -> str: + """Return the native eager dataframe backend name for a DataFrameSource.""" + return nw.get_native_namespace( + nw.from_native(source.get_data(), eager_only=True) + ).__name__ diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index feaf3a45d..782660ff8 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -2,9 +2,11 @@ from __future__ import annotations +import contextlib import copy import os import re +import warnings from pathlib import Path from typing import TYPE_CHECKING, Generic, Literal, Optional @@ -21,12 +23,28 @@ SQLAlchemySource, ) from ._pin_source import PinSource, is_pins_board -from ._querychat_core import GREETING_PROMPT +from ._query_executor import ( + DataSourceExecutor, + DuckDBExecutor, + PolarsSQLExecutor, + QueryExecutor, + check_source_compatibility, + validate_source_group_compatibility, +) +from ._querychat_core import ( + GREETING_PROMPT, + AppState, + AppStateDict, + create_app_state, + warn_multi_table_flat_accessor, +) from ._system_prompt import QueryChatSystemPrompt from ._utils import MISSING, MISSING_TYPE, is_ibis_table from ._viz_utils import has_viz_deps, has_viz_tool from .tools import ( + ResetDashboardCallback, UpdateDashboardData, + tool_get_schema, tool_query, tool_reset_dashboard, tool_update_dashboard, @@ -39,6 +57,7 @@ from narwhals.stable.v1.typing import IntoFrame from pins.boards import BaseBoard + from ._data_dict import DataDict from ._viz_tools import VisualizeData TOOL_GROUPS = Literal["filter", "update", "query", "visualize"] @@ -59,36 +78,26 @@ class QueryChatBase(Generic[IntoFrameT]): 
def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | BaseBoard | None, + data_source: IntoFrame | sqlalchemy.Engine | BaseBoard | None = None, table_name: str | None = None, *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | list[DataDict | str | Path] | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ): - if table_name is None: - if isinstance(data_source, DataSource): - table_name = data_source.table_name - elif data_source is not None: - raise ValueError( - "table_name is required when data_source is not a DataSource" - ) + self._data_dicts: list[DataDict] = _normalize_data_dicts(data_dict) - # Store table_name for later normalization - self._table_name = table_name + # Multi-table storage: dict of data sources keyed by table name + self._data_sources: dict[str, DataSource] = {} + self._query_executor: QueryExecutor | None = None - if ( - table_name is not None - and not is_pins_board(data_source) - and not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name) - ): - raise ValueError( - "Table name must begin with a letter and contain only letters, numbers, and underscores", - ) + # Track server initialization state for add/remove table validation + self._server_initialized = False self.tools = normalize_tools(tools, default=DEFAULT_TOOLS) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting @@ -96,66 +105,109 @@ def __init__( # Store init parameters for deferred system prompt building self._prompt_template = prompt_template self._data_description = data_description - self._data_description_mode: Literal["supplied", "inferred", "empty"] = ( - "supplied" if 
data_description is not None else "empty" - ) self._extra_instructions = extra_instructions self._categorical_threshold = categorical_threshold self._client_spec: str | chatlas.Chat | None = client self._client_console = None - # Initialize data source (may be None for deferred pattern) + self._system_prompt: QueryChatSystemPrompt | None = None + if data_source is not None: if table_name is None: - raise ValueError("table_name is required when data_source is provided") - self._data_source: DataSource | None = normalize_data_source( - data_source, table_name - ) - self._table_name = self._data_source.table_name - self._auto_fill_data_description() - self._build_system_prompt() - else: - self._data_source = None - self._system_prompt = None - - def _auto_fill_data_description(self) -> None: - """Auto-populate data_description from data source metadata if not user-supplied.""" - if self._data_description_mode == "inferred": - self._data_description = None - self._data_description_mode = "empty" - if self._data_description_mode == "empty" and self._data_source is not None: - desc = self._data_source.get_data_description() - if desc: - self._data_description = desc - self._data_description_mode = "inferred" - - def _build_system_prompt(self) -> None: - """Build/rebuild the system prompt from current data source.""" - if self._data_source is None: + if isinstance(data_source, DataSource): + table_name = data_source.table_name + else: + raise ValueError( + "table_name is required when data_source is provided" + ) + self.add_table(data_source, table_name) + + def _build_system_prompt( + self, + *, + data_sources: dict[str, DataSource] | None = None, + ) -> None: + """Build/rebuild the system prompt from current or staged data sources.""" + next_data_sources = self._data_sources if data_sources is None else data_sources + + if not next_data_sources: raise RuntimeError("Cannot build system prompt without data_source") - prompt_template = self._prompt_template - if 
prompt_template is None: - prompt_template = Path(__file__).parent / "prompts" / "prompt.md" + client_has_history = ( + isinstance(self._client_spec, chatlas.Chat) and bool(self._client_spec.get_turns()) + ) or ( + self._client_console is not None and bool(self._client_console.get_turns()) + ) + if client_has_history: + warnings.warn( + "System prompt rebuilt after chat history exists. " + "This invalidates any prompt caching from prior turns. " + "Configure all tables before starting a conversation.", + UserWarning, + stacklevel=3, + ) self._system_prompt = QueryChatSystemPrompt( - prompt_template=prompt_template, - data_source=self._data_source, + prompt_template=self._prompt_template, + data_sources=next_data_sources, data_description=self._data_description, extra_instructions=self._extra_instructions, categorical_threshold=self._categorical_threshold, + data_dicts=self._data_dicts, ) - def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]: - """Raise if data_source is not set, otherwise return it for type narrowing.""" - if self._data_source is None: + def _build_query_executor( + self, *, data_sources: dict[str, DataSource] | None = None + ) -> QueryExecutor: + """Build a query executor from current or staged data sources.""" + sources = self._data_sources if data_sources is None else data_sources + + validate_source_group_compatibility(sources) + + if len(sources) == 1: + return DataSourceExecutor(dict(sources)) + + first_source = next(iter(sources.values())) + + if isinstance(first_source, DataFrameSource): + return DuckDBExecutor( + {n: s for n, s in sources.items() if isinstance(s, DataFrameSource)} + ) + if isinstance(first_source, PolarsLazySource): + return PolarsSQLExecutor( + {n: s for n, s in sources.items() if isinstance(s, PolarsLazySource)} + ) + + return DataSourceExecutor(dict(sources)) + + def _require_initialized(self, method_name: str) -> None: + """Raise if no data sources have been registered.""" + if not 
self._data_sources: raise RuntimeError( - f"data_source must be set before calling {method_name}(). " - "Either pass data_source to __init__(), set the data_source property, " - "or pass data_source to server()." + f"At least one data source must be set before calling {method_name}(). " + "Either pass data_source to __init__() or call add_table()." + ) + + def _require_single_table(self, method_name: str) -> None: + """Raise if multiple tables are registered, directing to per-table API.""" + if len(self._data_sources) > 1: + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + raise AttributeError( + f"Cannot use .{method_name}() with multiple tables ({table_list}). " + f"Use .table('name').{method_name}() for per-table access." ) - return self._data_source + + def _require_query_executor(self, method_name: str) -> QueryExecutor: + """Return the cached executor, building it lazily on first use.""" + if self._query_executor is None: + if not self._data_sources: + raise RuntimeError( + f"query executor must be set before calling {method_name}(). " + "Set the data_source first so querychat can build an executor." + ) + self._query_executor = self._build_query_executor() + return self._query_executor def _create_session_client( self, @@ -163,7 +215,7 @@ def _create_session_client( client_spec: str | chatlas.Chat | None | MISSING_TYPE = MISSING, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, - reset_dashboard: Callable[[], None] | None = None, + reset_dashboard: ResetDashboardCallback | None = None, visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """Create a fresh, fully-configured Chat.""" @@ -178,20 +230,44 @@ def _create_session_client( if resolved_tools is None: return chat - data_source = self._require_data_source("_create_session_client") + executor = self._require_query_executor("_create_session_client") + + # Always register the schema tool (for all non-None tool sets) + chat.register_tool( + tool_get_schema( + self._data_dicts, + executor, + list(self._data_sources.keys()), + self._categorical_threshold, + ) + ) if "update" in resolved_tools: update_fn = update_dashboard or (lambda _: None) - reset_fn = reset_dashboard or (lambda: None) - chat.register_tool(tool_update_dashboard(data_source, update_fn)) - chat.register_tool(tool_reset_dashboard(reset_fn)) + user_reset = reset_dashboard or (lambda _table: None) + + chat.register_tool( + tool_update_dashboard( + executor, + list(self._data_sources.keys()), + update_fn, + multi_table=len(self._data_sources) > 1, + ) + ) + chat.register_tool( + tool_reset_dashboard(user_reset, list(self._data_sources.keys())) + ) if "query" in resolved_tools: - chat.register_tool(tool_query(data_source)) + chat.register_tool( + tool_query(executor, multi_table=len(self._data_sources) > 1) + ) if "visualize" in resolved_tools: viz_fn = visualize or (lambda _: None) - chat.register_tool(tool_visualize(data_source, viz_fn)) + chat.register_tool( + tool_visualize(executor, viz_fn, multi_table=len(self._data_sources) > 1) + ) return chat @@ -200,7 +276,7 @@ def client( *, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, - reset_dashboard: Callable[[], None] | None = None, + reset_dashboard: ResetDashboardCallback | None = None, visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """ @@ -225,7 +301,7 @@ def client( A configured chat client. """ - self._require_data_source("client") + self._require_initialized("client") return self._create_session_client( tools=tools, update_dashboard=update_dashboard, @@ -235,7 +311,7 @@ def client( def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: """Generate a welcome greeting for the chat.""" - self._require_data_source("generate_greeting") + self._require_initialized("generate_greeting") chat = create_client(self._client_spec) if self._system_prompt is not None: chat.system_prompt = self._system_prompt.render(self.tools) @@ -249,7 +325,7 @@ def console( **kwargs, ) -> None: """Launch an interactive console chat with the data.""" - self._require_data_source("console") + self._require_initialized("console") if new or self._client_console is None: self._client_console = self.client(tools=tools, **kwargs) @@ -258,32 +334,268 @@ def console( @property def system_prompt(self) -> str: """Get the system prompt.""" - self._require_data_source("system_prompt") + self._require_initialized("system_prompt") if self._system_prompt is None: raise RuntimeError("System prompt not initialized") return self._system_prompt.render(self.tools) @property - def data_source(self) -> DataSource | None: - """Get the current data source.""" - return self._data_source + def data_source(self) -> DataSource: + """Removed. Use ``add_table()`` and ``remove_table()`` to manage tables.""" + raise AttributeError( + "The .data_source property has been removed. " + "Use qc.add_table(df, 'name') to add a new table, " + "or qc.add_table(df, 'name', replace=True) to replace an existing one." 
+ ) @data_source.setter - def data_source(self, value: IntoFrame | sqlalchemy.Engine | BaseBoard) -> None: - """Set the data source, normalizing and rebuilding system prompt.""" - old_source = self._data_source - if self._table_name is None: - raise ValueError("table_name must be set before assigning a data source") - self._data_source = normalize_data_source(value, self._table_name) - if old_source is not None and old_source is not self._data_source: + def data_source(self, _value: object) -> None: + raise AttributeError( + "The .data_source setter has been removed. " + "Use qc.add_table(df, 'name') to add a new table, " + "or qc.add_table(df, 'name', replace=True) to replace an existing one." + ) + + def table_names(self) -> list[str]: + """ + Return the names of all registered tables. + + Returns + ------- + list[str] + List of table names in the order they were added. + + """ + return list(self._data_sources.keys()) + + def add_table( + self, + data_source: IntoFrame | sqlalchemy.Engine | BaseBoard, + table_name: str, + *, + replace: bool = False, + ) -> None: + """ + Add or replace a table in the QueryChat instance. + + Parameters + ---------- + data_source + The data source (DataFrame, LazyFrame, database connection, or pins board). + table_name + Name for the table. + replace + If True, replace an existing table with the same name. + If False (default), raise ValueError if the table already exists. + + Raises + ------ + ValueError + If table_name already exists (and replace=False) or is invalid. + RuntimeError + If called after server() has been invoked. + + """ + if self._server_initialized: + raise RuntimeError( + "Cannot add tables after server initialization. " + "Add all tables before calling .server() or .app()." 
+ ) + + if not is_pins_board(data_source) and not re.match( + r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name + ): + raise ValueError( + "Table name must begin with a letter and contain only " + "letters, numbers, and underscores" + ) + + if table_name in self._data_sources and not replace: + raise ValueError(f"Table '{table_name}' already exists") + + normalized = normalize_data_source(data_source, table_name) + try: + other_sources = { + name: source + for name, source in self._data_sources.items() + if name != table_name + } + check_source_compatibility(other_sources, normalized, table_name) + next_data_sources = dict(self._data_sources) + next_data_sources[table_name] = normalized + + self._build_system_prompt(data_sources=next_data_sources) + except Exception: + cleanup_failed_staged_source(data_source, normalized) + raise + + old_source = self._data_sources.get(table_name) + self._data_sources = next_data_sources + if old_source is not None and old_source is not normalized: old_source.cleanup() - self._auto_fill_data_description() - self._build_system_prompt() + if self._query_executor is not None: + with contextlib.suppress(Exception): + self._query_executor.cleanup() + self._query_executor = None + + def add_tables( + self, + data_source: sqlalchemy.Engine, + tables: list[str] | None = None, + *, + replace: bool = False, + ) -> None: + """ + Add multiple tables from a SQLAlchemy engine in a single call. + + Unlike calling :meth:`add_table` repeatedly, this method builds the + system prompt exactly once after all tables have been staged, avoiding + N-1 spurious intermediate rebuilds. + + Parameters + ---------- + data_source + A SQLAlchemy engine. Only engines are supported; pass individual + DataFrames or other sources via :meth:`add_table`. + tables + Table names to register. When ``None``, all tables returned by + ``sqlalchemy.inspect(data_source).get_table_names()`` are used. + replace + If ``True``, replace any existing table whose name appears in + ``tables``. 
If ``False`` (default), raise ``ValueError`` if any + name already exists. + + Raises + ------ + TypeError + If ``data_source`` is not a ``sqlalchemy.Engine``. + ValueError + If the resolved table list is empty, any name is invalid, or any + name already exists (and ``replace=False``). + RuntimeError + If called after :meth:`server` has been invoked. + + Examples + -------- + Register all tables from an engine: + + >>> qc = QueryChat() + >>> qc.add_tables(engine) + + Register a specific subset: + + >>> qc.add_tables(engine, ["orders", "customers"]) + + """ + if self._server_initialized: + raise RuntimeError( + "Cannot add tables after server initialization. " + "Add all tables before calling .server() or .app()." + ) + + if not isinstance(data_source, sqlalchemy.Engine): + raise TypeError( + f"add_tables() requires a sqlalchemy.Engine, got {type(data_source).__name__}. " + "Use add_table() for DataFrames and other source types." + ) + + if tables is None: + tables = sqlalchemy.inspect(data_source).get_table_names() + + if not tables: + raise ValueError("No tables found in database") + + for table_name in tables: + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): + raise ValueError( + "Table name must begin with a letter and contain only " + "letters, numbers, and underscores" + ) + if table_name in self._data_sources and not replace: + raise ValueError(f"Table '{table_name}' already exists") + + normalized = { + name: normalize_data_source(data_source, name) for name in tables + } + + staged: dict[str, DataSource] = {} + for name, source in normalized.items(): + other_sources = { + n: s + for n, s in self._data_sources.items() + if n != name + } + check_source_compatibility({**other_sources, **staged}, source, name) + staged[name] = source + + next_data_sources = {**self._data_sources, **normalized} + self._build_system_prompt(data_sources=next_data_sources) + + for name, normalized_source in normalized.items(): + old_source = self._data_sources.get(name) + 
if old_source is not None and old_source is not normalized_source: + old_source.cleanup() + + self._data_sources = next_data_sources + if self._query_executor is not None: + with contextlib.suppress(Exception): + self._query_executor.cleanup() + self._query_executor = None + + def remove_table(self, table_name: str) -> None: + """ + Remove a table from the QueryChat instance. + + Parameters + ---------- + table_name + Name of the table to remove. + + Raises + ------ + ValueError + If table doesn't exist or is the last remaining table. + RuntimeError + If called after server() has been invoked. + + """ + if self._server_initialized: + raise RuntimeError( + "Cannot remove tables after server initialization. " + "Configure all tables before calling .server() or .app()." + ) + + if table_name not in self._data_sources: + available = ", ".join(self._data_sources.keys()) + raise ValueError(f"Table '{table_name}' not found. Available: {available}") + + if len(self._data_sources) == 1: + raise ValueError( + "Cannot remove last table. At least one table is required." + ) + + removed_source = self._data_sources[table_name] + next_data_sources = dict(self._data_sources) + del next_data_sources[table_name] + + self._build_system_prompt(data_sources=next_data_sources) + self._data_sources = next_data_sources + if self._query_executor is not None: + with contextlib.suppress(Exception): + self._query_executor.cleanup() + self._query_executor = None + removed_source.cleanup() + + def _mark_server_initialized(self) -> None: + """Mark that the server has been initialized. 
Prevents add/remove_table.""" + self._server_initialized = True def cleanup(self) -> None: - """Clean up resources associated with the data source.""" - if self._data_source is not None: - self._data_source.cleanup() + """Clean up resources associated with all data sources.""" + if self._query_executor is not None: + self._query_executor.cleanup() + for source in self._data_sources.values(): + source.cleanup() def normalize_data_source( @@ -330,6 +642,24 @@ def normalize_data_source( ) +def cleanup_failed_staged_source( + original_source: IntoFrame | sqlalchemy.Engine | BaseBoard | DataSource, + normalized_source: DataSource, +) -> None: + """ + Clean up transient resources created during a failed staged rebuild. + + DataFrameSource and PinSource both allocate disposable connections during + normalization. SQLAlchemySource wraps a caller-owned engine, while + PolarsLazySource and IbisSource do not allocate disposable resources here. + """ + if isinstance(original_source, (DataSource, sqlalchemy.Engine)): + return + + if isinstance(normalized_source, (DataFrameSource, PinSource)): + normalized_source.cleanup() + + def create_client(client: str | chatlas.Chat | None) -> chatlas.Chat: """Resolve a client spec into a fresh Chat with no conversation history.""" if client is None: @@ -371,3 +701,175 @@ def normalize_tools( "vl-convert-python. 
Install them with: pip install querychat[viz]" ) return resolved + + +def _normalize_data_dicts( + data_dict: DataDict | str | Path | list[DataDict | str | Path] | None, +) -> list[DataDict]: + from ._data_dict import DataDict as _DataDict + + if data_dict is None: + return [] + if isinstance(data_dict, list): + return [ + _DataDict.from_yaml(item) if isinstance(item, (str, Path)) else item + for item in data_dict + ] + if isinstance(data_dict, (str, Path)): + return [_DataDict.from_yaml(data_dict)] + return [data_dict] + + +def _get_table_sql(state: AppStateDict | None, table: str) -> str | None: + """Extract the SQL for a specific table from a serialized state dict.""" + if state is None: + return None + per_table = state.get("table_states") + if per_table and table in per_table: + return per_table[table].get("sql") + # Backward compat: if table matches the active table and no table_states key exists + if state.get("table") == table: + return state.get("sql") + return None + + +class StateDictQueryChat(QueryChatBase[IntoFrameT]): + """Base for Dash and Gradio adapters that pass serialized state dicts per request.""" + + def _client_factory( + self, + update_cb: Callable[[UpdateDashboardData], None], + reset_cb: Callable[[str], None], + ) -> chatlas.Chat: + """Create a chat client with dashboard callbacks.""" + return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) + + def _df_for_source( + self, data_source: DataSource[IntoFrameT], sql: str | None + ) -> IntoFrameT: + if sql: + with contextlib.suppress(Exception): + return self._require_query_executor("df").execute_query(sql) + return data_source.get_data() + + def df(self, state: AppStateDict | None, *, table: str | None = None) -> IntoFrameT: + """ + Get the current DataFrame from state. + + Parameters + ---------- + state + The state dictionary from a framework callback. + table + Table name to read. Defaults to the active table when None. 
+ + Returns + ------- + : + The filtered data if a SQL query is active, otherwise the full dataset. + Returns a LazyFrame if the data source is lazy. + + """ + if table is not None: + return self._df_for_source( + self._data_sources[table], _get_table_sql(state, table) + ) + if len(self._data_sources) > 1: + primary_name = next(iter(self._data_sources)) + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + warn_multi_table_flat_accessor("df", primary_name, table_list) + return self._df_for_source( + self._data_sources[primary_name], _get_table_sql(state, primary_name) + ) + data_source = self._get_state_data_source(state) + return self._df_for_source(data_source, state.get("sql") if state else None) + + def _get_state_data_source( + self, state: AppStateDict | None + ) -> DataSource[IntoFrameT]: + """Resolve the full-data source for a serialized state payload.""" + self._require_initialized("_get_state_data_source") + first_source: DataSource[IntoFrameT] = next(iter(self._data_sources.values())) + if not state: + return first_source + table_name = state.get("table") + if table_name is not None and table_name in self._data_sources: + return self._data_sources[table_name] + return first_source + + def sql(self, state: AppStateDict | None, *, table: str | None = None) -> str | None: + """ + Get the current SQL query from state. + + Parameters + ---------- + state + The state dictionary from a framework callback. + table + Table name. Defaults to the active table when None. + + Returns + ------- + : + The current SQL query, or None if showing full dataset. 
+ + """ + if table is not None: + return _get_table_sql(state, table) + if len(self._data_sources) > 1: + primary_name = next(iter(self._data_sources)) + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + warn_multi_table_flat_accessor("sql", primary_name, table_list) + return _get_table_sql(state, primary_name) + return state.get("sql") if state else None + + def _title_for_table( + self, state: AppStateDict | None, table: str + ) -> str | None: + if state is None: + return None + per_table = state.get("table_states") + if per_table and table in per_table: + return per_table[table].get("title") + if state.get("table") == table: + return state.get("title") + return None + + def title(self, state: AppStateDict | None, *, table: str | None = None) -> str | None: + """ + Get the current query title from state. + + Parameters + ---------- + state + The state dictionary from a framework callback. + table + Table name. Defaults to the active table when None. + + Returns + ------- + : + A short description of the current filter, or None if showing full dataset. 
+ + """ + if table is not None: + return self._title_for_table(state, table) + if len(self._data_sources) > 1: + primary_name = next(iter(self._data_sources)) + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + warn_multi_table_flat_accessor("title", primary_name, table_list) + return self._title_for_table(state, primary_name) + return state.get("title") if state else None + + def _deserialize_state(self, state_data: AppStateDict | None) -> AppState: + """Reconstruct AppState from a serialized state dict.""" + self._require_initialized("_deserialize_state") + state = create_app_state( + data_sources=dict(self._data_sources), + client_factory=self._client_factory, + greeting=self.greeting, + query_executor=self._require_query_executor("_deserialize_state"), + ) + if state_data: + state.update_from_dict(state_data) + return state diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index fb3134beb..87d110ab7 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -7,19 +7,19 @@ "AppState", "AppStateDict", "ClientFactory", - "StateDictAccessorMixin", "create_app_state", "stream_response", "stream_response_async", ] +import warnings from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Optional, TypedDict, Union from chatlas import Chat, ContentToolRequest, ContentToolResult from chatlas.types import Content -from narwhals.stable.v1.typing import IntoFrameT +from typing_extensions import NotRequired from .tools import UpdateDashboardData @@ -36,21 +36,46 @@ from narwhals.stable.v1.typing import IntoFrame from ._datasource import DataSource + from ._query_executor import QueryExecutor ClientFactory = Callable[ - [Callable[[UpdateDashboardData], None], Callable[[], None]], + [Callable[[UpdateDashboardData], None], Callable[[str], None]], Chat, ] 
"""Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks.""" +def warn_multi_table_flat_accessor( + accessor_name: str, primary_table: str, table_list: str, stacklevel: int = 3 +) -> None: + """Emit a FutureWarning when a flat accessor is used with multiple tables registered.""" + warnings.warn( + f".{accessor_name}() called without a table name, but multiple tables are registered " + f"({table_list}). Defaulting to primary table '{primary_table}'. " + f"Use .table('{primary_table}').{accessor_name}() to suppress this warning. " + f"In a future version of querychat, this will raise an error.", + FutureWarning, + stacklevel=stacklevel, + ) + + +class TableStateData(TypedDict): + """Per-table state for serialization.""" + + sql: str | None + title: str | None + error: str | None + + class AppStateDict(TypedDict): """Serialized AppState for framework state stores.""" + table: NotRequired[str | None] sql: str | None title: str | None error: str | None + table_states: NotRequired[dict[str, TableStateData]] turns: list[dict] # Serialized chatlas Turns via model_dump() @@ -61,91 +86,6 @@ class DisplayMessage(TypedDict): content: str -class StateDictAccessorMixin(Generic[IntoFrameT]): - """Mixin providing df/sql/title accessors for frameworks using serialized state dicts.""" - - _data_source: DataSource[IntoFrameT] | None - - def _client_factory( - self, - update_cb: Callable[[UpdateDashboardData], None], - reset_cb: Callable[[], None], - ) -> Chat: - """Create a chat client with dashboard callbacks.""" - return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined] - - def df(self, state: AppStateDict | None) -> IntoFrameT: - """ - Get the current DataFrame from state. - - Parameters - ---------- - state - The state dictionary from a framework callback. - - Returns - ------- - : - The filtered data if a SQL query is active, otherwise the full dataset. - Returns a LazyFrame if the data source is lazy. 
- - """ - data_source = self._require_data_source("df") # type: ignore[attr-defined] - sql = state.get("sql") if state else None - if sql: - try: - return data_source.execute_query(sql) - except Exception: - return data_source.get_data() - return data_source.get_data() - - def sql(self, state: AppStateDict | None) -> str | None: - """ - Get the current SQL query from state. - - Parameters - ---------- - state - The state dictionary from a framework callback. - - Returns - ------- - : - The current SQL query, or None if showing full dataset. - - """ - return state.get("sql") if state else None - - def title(self, state: AppStateDict | None) -> str | None: - """ - Get the current query title from state. - - Parameters - ---------- - state - The state dictionary from a framework callback. - - Returns - ------- - : - A short description of the current filter, or None if showing full dataset. - - """ - return state.get("title") if state else None - - def _deserialize_state(self, state_data: AppStateDict | None) -> AppState: - """Reconstruct AppState from a serialized state dict.""" - data_source = self._require_data_source("_deserialize_state") # type: ignore[attr-defined] - state = create_app_state( - data_source, - self._client_factory, - self.greeting, # type: ignore[attr-defined] - ) - if state_data: - state.update_from_dict(state_data) - return state - - def format_chunk(chunk: Union[str, Content]) -> str: """Extract displayable text from a chat chunk.""" if isinstance(chunk, ContentToolRequest): @@ -163,7 +103,9 @@ def format_tool_result(result: ContentToolResult) -> str: display_info = result.extra.get("display") if result.extra else None if display_info and hasattr(display_info, "markdown"): return display_info.markdown - return str(result) + if result.value is not None: + return str(result.value) + return "" @@ -199,41 +141,93 @@ def format_query_error(e: Exception) -> str: class AppState: """Framework-agnostic application state for a querychat session.""" - 
data_source: DataSource + data_sources: dict[str, DataSource] client: Chat + query_executor: QueryExecutor | None = None greeting: Optional[str] = None - sql: Optional[str] = None - title: Optional[str] = None - error: Optional[str] = None + active_table: str | None = None + # sql, title, error are per-table properties backed by _table_states - def update_dashboard(self, data: UpdateDashboardData) -> None: - self.sql = data["query"] - self.title = data["title"] - self.error = None # Clear any previous error on successful update + def __post_init__(self) -> None: + if self.active_table is None: + self.active_table = next(iter(self.data_sources)) + self._table_states: dict[str, dict[str, str | None]] = { + name: {"sql": None, "title": None, "error": None} + for name in self.data_sources + } - def reset_dashboard(self) -> None: + def _get_active_state(self) -> dict[str, str | None]: + table = self.active_table or next(iter(self.data_sources)) + if table not in self._table_states: + self._table_states[table] = {"sql": None, "title": None, "error": None} + return self._table_states[table] + + @property + def sql(self) -> str | None: + return self._get_active_state()["sql"] + + @sql.setter + def sql(self, value: str | None) -> None: + self._get_active_state()["sql"] = value + + @property + def title(self) -> str | None: + return self._get_active_state()["title"] + + @title.setter + def title(self, value: str | None) -> None: + self._get_active_state()["title"] = value + + @property + def error(self) -> str | None: + return self._get_active_state()["error"] + + @error.setter + def error(self, value: str | None) -> None: + self._get_active_state()["error"] = value + + def update_dashboard(self, data: UpdateDashboardData) -> None: + table_name = data["table"] + self.active_table = table_name + if table_name not in self._table_states: + self._table_states[table_name] = {"sql": None, "title": None, "error": None} + self._table_states[table_name]["sql"] = data["query"] + 
self._table_states[table_name]["title"] = data["title"] + self._table_states[table_name]["error"] = None + + def reset_dashboard(self, table: str | None = None) -> None: + if table is not None: + self.active_table = table self.sql = None self.title = None self.error = None + def get_active_data_source(self) -> DataSource: + """Return the current full-data source for the active table.""" + if self.active_table is not None and self.active_table in self.data_sources: + return self.data_sources[self.active_table] + return next(iter(self.data_sources.values())) + def get_current_data(self) -> IntoFrame: """Get current data, falling back to default if query fails.""" + data_source = self.get_active_data_source() if self.sql: try: - result = self.data_source.execute_query(self.sql) + query_runner = self.query_executor or data_source + result = query_runner.execute_query(self.sql) self.error = None # Clear error on success return result except Exception as e: self.error = format_query_error(e) self.sql = None self.title = None - return self.data_source.get_data() - self.error = None - return self.data_source.get_data() + return data_source.get_data() + return data_source.get_data() def get_display_sql(self) -> str: - return self.sql or f"SELECT * FROM {self.data_source.table_name}" + table_name = self.active_table or next(iter(self.data_sources)) + return self.sql or f"SELECT * FROM {table_name}" def get_display_messages(self) -> list[DisplayMessage]: """ @@ -280,9 +274,14 @@ def initialize_greeting_if_preset(self) -> bool: def to_dict(self) -> AppStateDict: """Serialize state to dict for framework state stores.""" return { + "table": self.active_table, "sql": self.sql, "title": self.title, "error": self.error, + "table_states": { + name: {"sql": ts["sql"], "title": ts["title"], "error": ts["error"]} + for name, ts in self._table_states.items() + }, "turns": [turn.model_dump() for turn in self.client.get_turns()], } @@ -290,9 +289,22 @@ def update_from_dict(self, data: 
AppStateDict) -> None: """Restore state from serialized dict.""" from chatlas import Turn - self.sql = data["sql"] - self.title = data["title"] - self.error = data["error"] + self.active_table = data.get("table", next(iter(self.data_sources))) + + per_table = data.get("table_states") + if per_table: + for name, ts in per_table.items(): + if name in self._table_states: + self._table_states[name]["sql"] = ts.get("sql") + self._table_states[name]["title"] = ts.get("title") + self._table_states[name]["error"] = ts.get("error") + else: + # Backward compat: restore single active-table state from flat fields. + active = self.active_table or next(iter(self.data_sources)) + if active in self._table_states: + self._table_states[active]["sql"] = data["sql"] + self._table_states[active]["title"] = data["title"] + self._table_states[active]["error"] = data["error"] turns_data = data["turns"] turns = [Turn.model_validate(t) for t in turns_data] @@ -300,9 +312,11 @@ def update_from_dict(self, data: AppStateDict) -> None: def create_app_state( - data_source: DataSource, + *, + data_sources: dict[str, DataSource], client_factory: ClientFactory, greeting: Optional[str] = None, + query_executor: QueryExecutor | None = None, ) -> AppState: """Create AppState with callbacks connected via holder pattern.""" state_holder: dict[str, AppState | None] = {"state": None} @@ -313,16 +327,17 @@ def update_callback(data: UpdateDashboardData) -> None: raise RuntimeError("Callback invoked before state initialization") state.update_dashboard(data) - def reset_callback() -> None: + def reset_callback(_table: str) -> None: state = state_holder["state"] if state is None: raise RuntimeError("Callback invoked before state initialization") - state.reset_dashboard() + state.reset_dashboard(_table) client = client_factory(update_callback, reset_callback) state = AppState( - data_source=data_source, + data_sources=dict(data_sources), client=client, + query_executor=query_executor, greeting=greeting, ) 
state_holder["state"] = state diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index 4bf0f2681..f616671fe 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -24,6 +24,9 @@ import sqlalchemy from narwhals.stable.v1.typing import IntoFrame + from ._data_dict import DataDict + from ._table_accessor import TableAccessor + class QueryChat(QueryChatBase[IntoFrameT]): """ @@ -112,13 +115,13 @@ class QueryChat(QueryChatBase[IntoFrameT]): The tools can be overridden per-client by passing a different `tools` parameter to the `.client()` method. - data_description - Description of the data in plain text or Markdown. If a pathlib.Path - object is passed, querychat will read the contents of the path into a - string with `.read_text()`. - categorical_threshold - Threshold for determining if a column is categorical based on number of - unique values. + data_dict + A :class:`~querychat.DataDict` instance, or a path (``str`` or + ``pathlib.Path``) to a YAML file, that provides rich per-table and + per-column metadata. When set, documented columns use the dict's + ``values``, ``range``, and ``description`` fields instead of querying + the data source for statistics, which speeds up schema generation and + improves LLM context. Supersedes ``data_description``. extra_instructions Additional instructions for the chat model. If a pathlib.Path object is passed, querychat will read the contents of the path into a string with @@ -133,23 +136,30 @@ class QueryChat(QueryChatBase[IntoFrameT]): `data_source.get_schema()` - `{{data_description}}`: The optional data description provided - `{{extra_instructions}}`: Any additional instructions provided + categorical_threshold + Threshold for determining if a column is categorical based on number of + unique values. + data_description + Optional plain-text or Markdown description of the data, as a string or + file path. Superseded by ``data_dict`` for new code. 
""" @overload def __init__( self: QueryChat[Any], - data_source: None, - table_name: str, + data_source: None = None, + table_name: str | None = None, *, id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -162,10 +172,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -178,10 +189,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -194,10 +206,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -210,25 +223,27 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None, - table_name: str, + data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None = None, + table_name: str | None = None, *, id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ): super().__init__( data_source, @@ -237,12 +252,12 @@ def __init__( client=client, tools=tools, data_description=data_description, + data_dict=data_dict, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, ) - # Use table_name for ID since data_source might be None - self.id = id or f"querychat_{table_name}" + self.id = id or (f"querychat_{table_name}" if table_name else "querychat") def app( self, *, bookmark_store: Literal["url", "server", "disable"] = "url" @@ -267,9 +282,9 @@ def app( A Shiny App object that can be run with `app.run()` or served with `shiny run`. """ - data_source = self._require_data_source("app") + self._require_initialized("app") enable_bookmarking = bookmark_store != "disable" - table_name = data_source.table_name + first_table_name = next(iter(self._data_sources)) def app_ui(request): return ui.page_sidebar( @@ -291,33 +306,47 @@ def app_ui(request): style="max-height: 33%;", ), ui.card( - ui.card_header(bs_icon("table"), " Data"), + ui.card_header( + bs_icon("table"), + " Data β€” ", + ui.output_text("data_card_header_text", inline=True), + ), ui.output_data_frame("dt"), ), - title=ui.span("querychat with ", ui.code(table_name)), + title=ui.span("querychat with ", ui.code(first_table_name)), class_="bslib-page-dashboard", fillable=True, ) def app_server(input: Inputs, output: Outputs, session: Session): + self._mark_server_initialized() if enable_bookmarking: session.bookmark.exclude.append("reset_query") vals = mod_server( self.id, - data_source=data_source, + data_sources=dict(self._data_sources), + 
executor=self._require_query_executor("server"), greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable_bookmarking, tools=self.tools, ) + @reactive.calc + def active_table_name() -> str: + return vals.current_table() or first_table_name + + @render.text + def data_card_header_text(): + return active_table_name() + @render.text def query_title(): - return vals.title() or "SQL Query" + return vals.table(active_table_name()).title() or "SQL Query" @render.ui def ui_reset(): - req(vals.sql()) + req(vals.table(active_table_name()).sql()) return ui.input_action_button( "reset_query", "Reset Query", @@ -327,17 +356,20 @@ def ui_reset(): @reactive.effect @reactive.event(input.reset_query) def _(): - vals.sql.set(None) - vals.title.set(None) + name = active_table_name() + # TableAccessor is read-only; mutation requires direct TableState access + vals._tables[name].sql.set(None) + vals._tables[name].title.set(None) @render.data_frame def dt(): # Collect lazy sources (LazyFrame, Ibis Table) to eager DataFrame - return as_narwhals(vals.df()) + return as_narwhals(vals.table(active_table_name()).df()) @render.ui def sql_output(): - sql_value = vals.sql() or f"SELECT * FROM {table_name}" + name = active_table_name() + sql_value = vals.table(name).sql() or f"SELECT * FROM {name}" sql_code = f"```sql\n{sql_value}\n```" return output_markdown_stream( "sql_code", @@ -412,7 +444,6 @@ def ui(self, *, id: Optional[str] = None, **kwargs): def server( self, *, - data_source: Optional[IntoFrame | sqlalchemy.Engine | ibis.Table] = None, client: str | chatlas.Chat | MISSING_TYPE = MISSING, enable_bookmarking: bool = False, id: Optional[str] = None, @@ -427,10 +458,6 @@ def server( Parameters ---------- - data_source - Optional data source to use. If provided, sets the data_source property - before initializing server logic. This is useful for the deferred pattern - where data_source is not known at initialization time. 
client Optional chat client to use for this session. If provided, overrides any client set at initialization time for this call only. This is useful @@ -497,10 +524,7 @@ def title(): ".server() must be called within an active Shiny session (i.e., within the server function). " ) - if data_source is not None: - self.data_source = data_source - - resolved_data_source = self._require_data_source("server") + self._require_initialized("server") resolved_client_spec = self._client_spec if isinstance(client, MISSING_TYPE) else client def create_session_client(**kwargs) -> chatlas.Chat: @@ -508,9 +532,11 @@ def create_session_client(**kwargs) -> chatlas.Chat: client_spec=resolved_client_spec, **kwargs ) + self._mark_server_initialized() return mod_server( id or self.id, - data_source=resolved_data_source, + data_sources=dict(self._data_sources), + executor=self._require_query_executor("server"), greeting=self.greeting, client=create_session_client, enable_bookmarking=enable_bookmarking, @@ -589,13 +615,13 @@ def data_table(): If `client` is not provided, querychat consults the `QUERYCHAT_CLIENT` environment variable. If that is not set, it defaults to `"openai"`. - data_description - Description of the data in plain text or Markdown. If a pathlib.Path - object is passed, querychat will read the contents of the path into a - string with `.read_text()`. - categorical_threshold - Threshold for determining if a column is categorical based on number of - unique values. + data_dict + A :class:`~querychat.DataDict` instance, or a path (``str`` or + ``pathlib.Path``) to a YAML file, that provides rich per-table and + per-column metadata. When set, documented columns use the dict's + ``values``, ``range``, and ``description`` fields instead of querying + the data source for statistics, which speeds up schema generation and + improves LLM context. Supersedes ``data_description``. extra_instructions Additional instructions for the chat model. 
If a pathlib.Path object is passed, querychat will read the contents of the path into a string with @@ -610,6 +636,12 @@ def data_table(): `data_source.get_schema()` - `{{data_description}}`: The optional data description provided - `{{extra_instructions}}`: Any additional instructions provided + categorical_threshold + Threshold for determining if a column is categorical based on number of + unique values. + data_description + Optional plain-text or Markdown description of the data, as a string or + file path. Superseded by ``data_dict`` for new code. """ @@ -619,17 +651,18 @@ def data_table(): @overload def __init__( self: QueryChatExpress[Any], - data_source: None, - table_name: str, + data_source: None = None, + table_name: str | None = None, *, id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -643,10 +676,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... 
@@ -660,10 +694,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -677,10 +712,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -694,26 +730,28 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... 
def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None, - table_name: str, + data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None = None, + table_name: str | None = None, *, id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, enable_bookmarking: Literal["auto", True, False] = "auto", ): # Sanity check: Express should always have a (stub/real) session @@ -731,11 +769,12 @@ def __init__( client=client, tools=tools, data_description=data_description, + data_dict=data_dict, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, ) - self.id = id or f"querychat_{table_name}" + self.id = id or (f"querychat_{table_name}" if table_name else "querychat") # Determine bookmarking setting # During stub session: detect from app_opts and cache in class variable @@ -753,12 +792,34 @@ def __init__( else: enable = enable_bookmarking + self._enable_bookmarking = enable + self._vals: ServerValues[IntoFrameT] | None = None + + def _ensure_server_started(self) -> None: + """ + Start the Shiny module server if not already started. + + Called lazily from ui()/sidebar() and the reactive accessors so that + module-level add_table() calls (which happen after __init__ but before + sidebar()/ui()) can complete before server initialization locks the + table set. 
+ """ + if self._server_initialized: + return + session = get_current_session() + if isinstance(session, ExpressStubSession): + return + if not self._data_sources: + return + self._require_initialized("_ensure_server_started") + self._mark_server_initialized() self._vals = mod_server( self.id, - data_source=self._data_source, + data_sources=dict(self._data_sources), + executor=self._require_query_executor("_ensure_server_started"), greeting=self.greeting, client=self._create_session_client, - enable_bookmarking=enable, + enable_bookmarking=self._enable_bookmarking, tools=self.tools, ) @@ -821,7 +882,18 @@ def ui(self, *, id: Optional[str] = None, **kwargs): A UI component. """ - return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), greeting=self.greeting, **kwargs) + result = mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), greeting=self.greeting, **kwargs) + self._ensure_server_started() + return result + + def _require_vals(self) -> ServerValues[IntoFrameT]: + self._ensure_server_started() + if self._vals is None: + raise RuntimeError( + "QueryChat server is not initialized. " + "Ensure add_table() is called and sidebar()/ui() has been rendered." + ) + return self._vals def df(self) -> IntoFrameT: """ @@ -836,7 +908,7 @@ def df(self) -> IntoFrameT: data source. """ - return self._vals.df() + return self._require_vals().df() @overload def sql(self, query: None = None) -> str | None: ... @@ -863,9 +935,9 @@ def sql(self, query: Optional[str] = None) -> str | None | bool: """ if query is None: - return self._vals.sql() + return self._require_vals().sql() else: - return self._vals.sql.set(query) + return self._require_vals().sql.set(query) @overload def title(self, value: None = None) -> str | None: ... 
@@ -897,6 +969,58 @@ def title(self, value: Optional[str] = None) -> str | None | bool: """ if value is None: - return self._vals.title() + return self._require_vals().title() else: - return self._vals.title.set(value) + return self._require_vals().title.set(value) + + def table(self, name: str) -> TableAccessor: + """ + Get a per-table accessor with reactive state. + + Parameters + ---------- + name + Table name (must match a name passed to ``add_table()``). + + Returns + ------- + TableAccessor + Accessor with ``df()``, ``sql()``, and ``title()`` backed by + per-session reactive state. + + Examples + -------- + ```python + from querychat.express import QueryChat + from shiny.express import render + + qc = QueryChat(orders, "orders") + qc.add_table(customers, "customers") + qc.sidebar() + + @render.data_frame + def orders_table(): + return qc.table("orders").df() + + @render.data_frame + def customers_table(): + return qc.table("customers").df() + ``` + + """ + return self._require_vals().table(name) + + def current_table(self) -> str | None: + """ + Reactively read the name of the most recently queried table. + + Returns ``None`` if no query has run yet in this session. Useful for + auto-switching a tabbed UI to the active table. + + Returns + ------- + str or None + Table name, or ``None``. 
+ + """ + return self._require_vals().current_table() diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 95d4882e6..f0b3e8d4d 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -12,7 +12,8 @@ from shiny import module, reactive, ui -from ._querychat_core import GREETING_PROMPT +from ._querychat_core import GREETING_PROMPT, warn_multi_table_flat_accessor +from ._table_accessor import TableAccessor from ._viz_altair_widget import AltairWidget from ._viz_ggsql import execute_ggsql from ._viz_utils import has_viz_tool, preload_viz_deps_server, preload_viz_deps_ui @@ -25,6 +26,7 @@ from shiny import Inputs, Outputs, Session from ._datasource import DataSource + from ._query_executor import QueryExecutor from ._viz_tools import VisualizeData from .types import UpdateDashboardData @@ -56,6 +58,16 @@ def __getattr__(self, _name: str): ServerClient = chatlas.Chat | _DeferredStubChatClient +@dataclass +class TableState(Generic[IntoFrameT]): + """Per-table reactive state.""" + + sql: ReactiveStringOrNone + title: ReactiveStringOrNone + df: Callable[[], IntoFrameT] + + + @module.ui def mod_ui(*, preload_viz: bool = False, greeting: str | None = None, **kwargs): css_path = Path(__file__).parent / "static" / "css" / "styles.css" @@ -80,43 +92,115 @@ def mod_ui(*, preload_viz: bool = False, greeting: str | None = None, **kwargs): ) -@dataclass +class _MultiTableWarnReactive: + """Proxy that warns once per session and delegates to the primary table's reactive value.""" + + def __init__( + self, + primary: ReactiveStringOrNone, + accessor_name: str, + primary_table: str, + table_list: str, + ) -> None: + self._primary = primary + self._accessor_name = accessor_name + self._primary_table = primary_table + self._table_list = table_list + self._warned = False + + def _warn(self) -> None: + if not self._warned: + self._warned = True + warn_multi_table_flat_accessor( + self._accessor_name, 
self._primary_table, self._table_list, stacklevel=4 + ) + + def __call__(self) -> str | None: + self._warn() + return self._primary.get() + + def get(self) -> str | None: + self._warn() + return self._primary.get() + + def set(self, value: str | None) -> None: + self._primary.set(value) + + class ServerValues(Generic[IntoFrameT]): """ Session-specific reactive values and client returned by QueryChat.server(). - This dataclass contains all the session-specific reactive state for a QueryChat - instance. Each session gets its own ServerValues to ensure proper isolation + Each session gets its own ServerValues to ensure proper isolation between concurrent sessions. Attributes ---------- df - A reactive Calc that returns the current filtered data frame or lazy frame. - If the data source is lazy, returns a LazyFrame. If no SQL query has been - set, this returns the unfiltered data from the data source. - Call it like `.df()` to reactively read the current data. + Reactive Calc returning the current filtered data frame. + With multiple tables, warns and defaults to the primary table; use ``.table('name').df()``. sql - A reactive Value containing the current SQL query string. Access the value - by calling `.sql()`, or set it with `.sql.set("SELECT ...")`. - Returns `None` if no query has been set. + Reactive Value for the current SQL query string. + With multiple tables, warns and defaults to the primary table; use ``.table('name').sql``. title - A reactive Value containing the current title for the query. The LLM - provides this title when generating a new SQL query. Access it with - `.title()`, or set it with `.title.set("...")`. Returns - `None` if no title has been set. + Reactive Value for the current title. + With multiple tables, warns and defaults to the primary table; use ``.table('name').title``. + tables + Per-table reactive state dict. Keys are table names. client - Session chat client value. 
- For real sessions this is a `chatlas.Chat` created by the client - factory. For deferred stub sessions (where `data_source` is not set - yet), this is a placeholder client that raises when accessed. + Session chat client. + current_table + The name of the most recently queried table, or ``None`` if no query + has been run yet. Call ``.current_table()`` to read reactively. """ - df: Callable[[], IntoFrameT] - sql: ReactiveStringOrNone - title: ReactiveStringOrNone - client: ServerClient + def __init__( + self, + *, + df: Callable[[], IntoFrameT], + sql: ReactiveStringOrNone, + title: ReactiveStringOrNone, + tables: dict[str, TableState[IntoFrameT]], + client: ServerClient, + data_sources: dict[str, DataSource[IntoFrameT]], + current_table: ReactiveStringOrNone, + ): + self.df = df + self.sql = sql + self.title = title + self._tables = tables + self.client = client + self._data_sources = data_sources + self._current_table_rv = current_table + + def table(self, name: str) -> TableAccessor: + """ + Get a per-table accessor with reactive state. + + Parameters + ---------- + name + Table name. + + Returns + ------- + TableAccessor + Accessor with df(), sql(), title() backed by per-session state. + + """ + if name not in self._tables: + available = ", ".join(f"'{n}'" for n in self._tables) + raise ValueError(f"Table '{name}' not found. 
Available: {available}") + return TableAccessor(name, self._data_sources[name], state=self._tables[name]) + + def table_names(self) -> list[str]: + """Return the names of all registered tables.""" + return list(self._tables.keys()) + + def current_table(self) -> str | None: + """Return the name of the most recently queried table, or None (reactive).""" + return self._current_table_rv.get() @module.server @@ -125,15 +209,13 @@ def mod_server( output: Outputs, session: Session, *, - data_source: DataSource[IntoFrameT] | None, + data_sources: dict[str, DataSource[IntoFrameT]] | None, + executor: QueryExecutor | None, greeting: str | None, client: Callable[..., chatlas.Chat], enable_bookmarking: bool, tools: set[str] | None = None, ) -> ServerValues[IntoFrameT]: - # Reactive values to store state - sql = ReactiveStringOrNone(None) - title = ReactiveStringOrNone(None) # Holds a generated greeting so it can be saved and restored on bookmark. # Static greetings live in the UI (chat_ui(greeting=)) and persist already. 
# Workaround for posit-dev/shinychat#253: shinychat does not bookmark @@ -145,13 +227,36 @@ def mod_server( if not callable(client): raise TypeError("mod_server() requires a callable client factory.") - def update_dashboard(data: UpdateDashboardData): - sql.set(data["query"]) - title.set(data["title"]) + table_states: dict[str, TableState[IntoFrameT]] = {} + _current_table: ReactiveStringOrNone = ReactiveStringOrNone(None) + + def _make_table_state( + source: DataSource[IntoFrameT], exec: QueryExecutor + ) -> TableState[IntoFrameT]: + table_sql = ReactiveStringOrNone(None) + table_title = ReactiveStringOrNone(None) + + @reactive.calc + def filtered_df() -> IntoFrameT: + query = table_sql.get() + if query: + return exec.execute_query(query) + return source.get_data() - def reset_dashboard(): - sql.set(None) - title.set(None) + return TableState(sql=table_sql, title=table_title, df=filtered_df) + + def update_dashboard(data: UpdateDashboardData): + table_name = data["table"] + if table_name in table_states: + table_states[table_name].sql.set(data["query"]) + table_states[table_name].title.set(data["title"]) + _current_table.set(table_name) + + def reset_dashboard(table_name: str): + if table_name in table_states: + table_states[table_name].sql.set(None) + table_states[table_name].title.set(None) + _current_table.set(table_name) viz_widgets: list[VizWidgetEntry] = [] @@ -167,43 +272,42 @@ def build_chat_client() -> chatlas.Chat: ) # Short-circuit for stub sessions (e.g. 
1st run of an Express app) - # data_source may be None during stub session for deferred pattern + # data_sources may be None during stub session for deferred pattern if session.is_stub_session(): # Mock the error that would otherwise occur in a real session def _stub_df(): raise RuntimeError("RuntimeError: No current reactive context") stub_client = ( - _DeferredStubChatClient() if data_source is None else build_chat_client() + _DeferredStubChatClient() if data_sources is None else build_chat_client() ) return ServerValues( df=_stub_df, - sql=sql, - title=title, + sql=ReactiveStringOrNone(None), + title=ReactiveStringOrNone(None), + tables={}, client=stub_client, + data_sources=data_sources or {}, + current_table=ReactiveStringOrNone(None), ) - # Real session requires data_source - if data_source is None: + # Real session requires data_sources and executor + if data_sources is None or executor is None: raise RuntimeError( - "data_source must be set before the real session. " - "Set it via the data_source property before users connect." + "At least one table must be registered before the session starts. " + "Call add_table() before server(), or pass the data to the QueryChat constructor." ) + for name, source in data_sources.items(): + table_states[name] = _make_table_state(source, executor) + # Build the session-specific chat client through QueryChat.client(...). 
chat = build_chat_client() if has_viz_tool(tools): preload_viz_deps_server() - # Execute query when SQL changes - @reactive.calc - def filtered_df(): - query = sql.get() - df = data_source.get_data() if not query else data_source.execute_query(query) - return df - # Chat UI logic chat_ui = shinychat.Chat(CHAT_ID) ctrl = chatlas.StreamController() @@ -258,17 +362,15 @@ async def _handle_greeting_requested(): @reactive.event(input.chat_update) def _(): update = input.chat_update() - if update is None: - return - if not isinstance(update, dict): + if update is None or not isinstance(update, dict): return - - new_query = update.get("query") - new_title = update.get("title") - if new_query is not None: - sql.set(new_query) - if new_title is not None: - title.set(new_title) + table_name = update.get("table", "") + new_query = update.get("query") or None # "" β†’ None (reset) + new_title = update.get("title") or None + if table_name and table_name in table_states: + table_states[table_name].sql.set(new_query) + table_states[table_name].title.set(new_title) + _current_table.set(table_name) if enable_bookmarking: chat_ui.enable_bookmarking(chat) @@ -277,8 +379,9 @@ def _(): @session.bookmark.on_bookmark def _on_bookmark(x: BookmarkState) -> None: vals = x.values - vals["querychat_sql"] = sql.get() - vals["querychat_title"] = title.get() + for name, state in table_states.items(): + vals[f"querychat_sql_{name}"] = state.sql.get() + vals[f"querychat_title_{name}"] = state.title.get() greeting_val = current_greeting.get() if greeting_val is not None: vals["querychat_greeting"] = greeting_val @@ -288,10 +391,16 @@ def _on_bookmark(x: BookmarkState) -> None: @session.bookmark.on_restore async def _on_restore(x: RestoreState) -> None: vals = x.values - if "querychat_sql" in vals: - sql.set(vals["querychat_sql"]) - if "querychat_title" in vals: - title.set(vals["querychat_title"]) + last_restored: str | None = None + for name, state in table_states.items(): + if 
f"querychat_sql_{name}" in vals: + state.sql.set(vals[f"querychat_sql_{name}"]) + if vals[f"querychat_sql_{name}"] is not None: + last_restored = name + if f"querychat_title_{name}" in vals: + state.title.set(vals[f"querychat_title_{name}"]) + if last_restored is not None: + _current_table.set(last_restored) if "querychat_greeting" in vals: current_greeting.set(vals["querychat_greeting"]) await chat_ui.set_greeting( @@ -301,11 +410,45 @@ async def _on_restore(x: RestoreState) -> None: ) if "querychat_viz_widgets" in vals: restored = restore_viz_widgets( - data_source, vals["querychat_viz_widgets"] + executor, vals["querychat_viz_widgets"] ) viz_widgets[:] = restored - return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) + if len(table_states) == 1: + only_state = next(iter(table_states.values())) + return ServerValues( + df=only_state.df, + sql=only_state.sql, + title=only_state.title, + tables=table_states, + client=chat, + data_sources=data_sources, + current_table=_current_table, + ) + + primary_name = next(iter(table_states)) + primary_state = table_states[primary_name] + table_list = ", ".join(f"'{n}'" for n in table_states) + + df_warned = False + + @reactive.calc + def _multi_table_df() -> IntoFrameT: + nonlocal df_warned + if not df_warned: + df_warned = True + warn_multi_table_flat_accessor("df", primary_name, table_list) + return primary_state.df() + + return ServerValues( + df=_multi_table_df, + sql=_MultiTableWarnReactive(primary_state.sql, "sql", primary_name, table_list), # type: ignore[arg-type] + title=_MultiTableWarnReactive(primary_state.title, "title", primary_name, table_list), # type: ignore[arg-type] + tables=table_states, + client=chat, + data_sources=data_sources, + current_table=_current_table, + ) class GreetWarning(Warning): @@ -313,7 +456,7 @@ class GreetWarning(Warning): def restore_viz_widgets( - data_source: DataSource[IntoFrameT], + executor: QueryExecutor, saved_widgets: list[VizWidgetEntry], ) -> 
list[VizWidgetEntry]: """Re-execute ggsql queries, register widgets, and return restored entries.""" @@ -327,7 +470,7 @@ def restore_viz_widgets( ggsql_str = entry["ggsql"] try: validated = validate(ggsql_str) - spec = execute_ggsql(data_source, validated) + spec = execute_ggsql(executor, validated) altair_widget = AltairWidget.from_ggsql(spec, widget_id=widget_id) register_widget(widget_id, altair_widget.widget) restored.append(entry) diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index b68a6effc..b8093b83f 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -13,6 +13,7 @@ create_app_state, stream_response, ) +from ._table_accessor import TableAccessor from ._ui_assets import STREAMLIT_JS, SUGGESTION_CSS from ._utils import as_narwhals @@ -25,6 +26,50 @@ import sqlalchemy from narwhals.stable.v1.typing import IntoFrame + from ._data_dict import DataDict + + +class StreamlitTableAccessor(TableAccessor): + """Per-table accessor for Streamlit QueryChat. Returned by ``qc.table(name)``.""" + + def __init__(self, querychat: QueryChat, table_name: str) -> None: + # Bypass TableAccessor.__init__ β€” this subclass owns df/sql/title entirely + # via session state, so _state is never used. + self._table_name = table_name + self._data_source = querychat._data_sources[table_name] + self._querychat_ref = querychat + + def df(self) -> Any: + """ + Get the current filtered data for this table. + + Returns the full dataset when no SQL filter is active. 
+ """ + qc = self._querychat_ref + state = qc._get_state() + ts = state._table_states.get(self._table_name, {}) + sql = ts.get("sql") + data_source = qc._data_sources[self._table_name] + if sql: + try: + executor = qc._require_query_executor("table.df") + return executor.execute_query(sql) + except Exception: + return data_source.get_data() + return data_source.get_data() + + def sql(self) -> str | None: + """Return the current SQL filter for this table, or None.""" + qc = self._querychat_ref + state = qc._get_state() + return state._table_states.get(self._table_name, {}).get("sql") + + def title(self) -> str | None: + """Return the current filter title for this table, or None.""" + qc = self._querychat_ref + state = qc._get_state() + return state._table_states.get(self._table_name, {}).get("title") + class QueryChat(QueryChatBase[IntoFrameT]): """ @@ -61,16 +106,17 @@ class QueryChat(QueryChatBase[IntoFrameT]): @overload def __init__( self: QueryChat[Any], - data_source: None, - table_name: str, + data_source: None = None, + table_name: str | None = None, *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -82,10 +128,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -97,10 +144,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -112,10 +160,11 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... @overload @@ -127,24 +176,26 @@ def __init__( greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ) -> None: ... 
def __init__( self, - data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None, - table_name: str, + data_source: IntoFrame | sqlalchemy.Engine | ibis.Table | None = None, + table_name: str | None = None, *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("filter", "query"), - data_description: Optional[str | Path] = None, - categorical_threshold: int = 20, + data_dict: DataDict | str | Path | None = None, extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, + categorical_threshold: int = 20, + data_description: Optional[str | Path] = None, ): super().__init__( data_source, @@ -153,25 +204,27 @@ def __init__( client=client, tools=tools, data_description=data_description, + data_dict=data_dict, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, ) - self._state_key = f"_querychat_{table_name}" + self._state_key = f"_querychat_{table_name}" if table_name else "_querychat" def _get_state(self) -> AppState: """Get or create session state.""" - data_source = self._require_data_source("_get_state") + self._require_initialized("_get_state") import streamlit as st if self._state_key not in st.session_state: st.session_state[self._state_key] = create_app_state( - data_source, - lambda update_cb, reset_cb: self.client( + data_sources=dict(self._data_sources), + client_factory=lambda update_cb, reset_cb: self.client( update_dashboard=update_cb, reset_dashboard=reset_cb, ), - self.greeting, + greeting=self.greeting, + query_executor=self._require_query_executor("_get_state"), ) return st.session_state[self._state_key] @@ -182,11 +235,18 @@ def app(self) -> None: Configures the page, renders chat in sidebar, and displays SQL query and data table in the main area. 
""" - data_source = self._require_data_source("app") + self._require_initialized("app") + if len(self._data_sources) > 1: + table_list = ", ".join(f"'{n}'" for n in self._data_sources) + raise RuntimeError( + f"app() does not support multiple tables ({table_list}). " + "Build a custom layout using sidebar(), ui(), and table('name') instead." + ) import streamlit as st + table_name = next(iter(self._data_sources)) st.set_page_config( - page_title=f"querychat with {data_source.table_name}", + page_title=f"querychat with {table_name}", layout="wide", initial_sidebar_state="expanded", ) @@ -270,16 +330,32 @@ def df(self) -> IntoFrameT: eager sources, a LazyFrame for Polars lazy sources, or an Ibis Table for Ibis sources. Callers needing an eager DataFrame should collect the result (e.g., via ``as_narwhals(qc.df())``). + + Raises ``AttributeError`` when multiple tables are registered; + use ``qc.table('name').df()`` instead. """ + self._require_single_table("df") # Cast is safe because get_current_data() returns the same type as the data source return cast("IntoFrameT", self._get_state().get_current_data()) def sql(self) -> str | None: - """Get the current SQL query, or None if using default.""" + """ + Get the current SQL query, or None if using default. + + Raises ``AttributeError`` when multiple tables are registered; + use ``qc.table('name').sql()`` instead. + """ + self._require_single_table("sql") return self._get_state().sql def title(self) -> str | None: - """Get the current query title, or None if using default.""" + """ + Get the current query title, or None if using default. + + Raises ``AttributeError`` when multiple tables are registered; + use ``qc.table('name').title()`` instead. 
+ """ + self._require_single_table("title") return self._get_state().title def reset(self) -> None: @@ -305,12 +381,13 @@ def reset(self) -> None: def _render_main_content(self) -> None: """Render the main content area (SQL + data table).""" - data_source = self._require_data_source("_render_main_content") + self._require_initialized("_render_main_content") import streamlit as st state = self._get_state() + table_name = next(iter(self._data_sources)) - st.title(f"querychat with `{data_source.table_name}`") + st.title(f"querychat with `{table_name}`") st.subheader(state.title or "SQL Query") @@ -331,3 +408,24 @@ def _render_main_content(self) -> None: df.to_native(), use_container_width=True, height=400, hide_index=True ) st.caption(f"Data has {df.shape[0]} rows and {df.shape[1]} columns.") + + def table(self, name: str) -> StreamlitTableAccessor: + """ + Return a per-table accessor for the given table name. + + Parameters + ---------- + name + Table name (must be one of ``qc.table_names()``). + + Raises + ------ + ValueError + If ``name`` is not a registered table. + + """ + if name not in self._data_sources: + raise ValueError( + f"Table '{name}' not found. 
Available tables: {self.table_names()}" + ) + return StreamlitTableAccessor(self, name) diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index f690a0696..a7e17d6cc 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -1,16 +1,16 @@ from __future__ import annotations -import re +import warnings from pathlib import Path from typing import TYPE_CHECKING import chevron +import yaml from ._viz_utils import has_viz_tool -_SCHEMA_TAG_RE = re.compile(r"\{\{[{#^/]?\s*schema\b") - if TYPE_CHECKING: + from ._data_dict import DataDict from ._datasource import DataSource @@ -19,47 +19,106 @@ class QueryChatSystemPrompt: def __init__( self, - prompt_template: str | Path, - data_source: DataSource, + *, + prompt_template: str | Path | None, + data_source: DataSource | None = None, + data_sources: dict[str, DataSource] | None = None, data_description: str | Path | None = None, extra_instructions: str | Path | None = None, - categorical_threshold: int = 10, + categorical_threshold: int = 20, + data_dicts: list[DataDict] | None = None, ): - """ - Initialize with prompt components. 
- - Args: - prompt_template: Mustache template string or path to template file - data_source: DataSource instance for schema generation - data_description: Optional data context (string or path) - extra_instructions: Optional custom LLM instructions (string or path) - categorical_threshold: Threshold for categorical column detection - - """ - if isinstance(prompt_template, Path): - self.template = prompt_template.read_text() - else: - self.template = prompt_template - - if isinstance(data_description, Path): - self.data_description = data_description.read_text() + if data_sources is not None: + self._data_sources = data_sources + elif data_source is not None: + self._data_sources = {data_source.table_name: data_source} else: - self.data_description = data_description + raise ValueError("Either data_source or data_sources must be provided") - if isinstance(extra_instructions, Path): - self.extra_instructions = extra_instructions.read_text() - else: - self.extra_instructions = extra_instructions + self._data_dicts: list[DataDict] = data_dicts or [] - if _SCHEMA_TAG_RE.search(self.template): - self.schema = data_source.get_schema( - categorical_threshold=categorical_threshold + if len(self._data_sources) > 1 and not self._data_dicts: + warnings.warn( + "Multiple tables registered without a data_dict. 
" + "Providing a data_dict with table descriptions and relationships " + "gives the LLM better context for multi-table queries.", + UserWarning, + stacklevel=3, ) - else: - self.schema = "" + if prompt_template is None: + prompt_template = Path(__file__).parent / "prompts" / "prompt.md" + self.template = ( + prompt_template.read_text() + if isinstance(prompt_template, Path) + else prompt_template + ) + + self.data_description = ( + data_description.read_text() + if isinstance(data_description, Path) + else data_description + ) + self.extra_instructions = ( + extra_instructions.read_text() + if isinstance(extra_instructions, Path) + else extra_instructions + ) self.categorical_threshold = categorical_threshold - self.data_source = data_source + + def _generate_tables_overview(self) -> str: + lines = [] + for name, source in self._data_sources.items(): + desc: str | None = source.get_data_description() or None + if desc and not self.data_description: + lines.append(f"- {name}: {desc}") + else: + lines.append(f"- {name}") + return "\n".join(lines) + + def _generate_data_dicts_yaml(self) -> str: + def escape_attr(val: str) -> str: + return val.replace('"', """) + + blocks: list[str] = [] + all_claimed: set[str] = set() + + for dd in self._data_dicts: + d = dd.to_prompt_dict() + # Name and description belong in the XML tag, not the YAML body + d.pop("name", None) + d.pop("description", None) + + claimed = {n for n in self._data_sources if n in dd.tables} + all_claimed.update(claimed) + if "tables" in d: + d["tables"] = { + n: v for n, v in d["tables"].items() if n in self._data_sources + } + if not d["tables"]: + del d["tables"] + + attrs = f'name="{escape_attr(dd.name)}"' if dd.name else "" + if dd.description: + attrs += f' description="{escape_attr(dd.description)}"' + + body = yaml.dump(d, default_flow_style=False, allow_unicode=True, sort_keys=False).rstrip() if d else "" + blocks.append(f"\n{body}\n" if body else f"") + + unclaimed = [n for n in self._data_sources 
if n not in all_claimed] + if unclaimed: + tables: dict = {} + for name in unclaimed: + desc = ( + self._data_sources[name].get_data_description() or None + if not self.data_description + else None + ) + tables[name] = {"description": desc} if desc else None + yaml_str = yaml.dump({"tables": tables}, default_flow_style=False, allow_unicode=True, sort_keys=False).rstrip() + blocks.append(f"\n{yaml_str}\n") + + return "\n\n".join(blocks) def render(self, tools: set[str] | None) -> str: """ @@ -72,20 +131,24 @@ def render(self, tools: set[str] | None) -> str: Fully rendered system prompt string """ - db_type = self.data_source.get_db_type() - is_duck_db = db_type.lower() == "duckdb" + first_source = next(iter(self._data_sources.values())) + db_type = first_source.get_db_type() + has_dicts = bool(self._data_dicts) context = { "db_type": db_type, - "is_duck_db": is_duck_db, - "semantic_views": self.data_source.get_semantic_views_description(), - "schema": self.schema, + "is_duck_db": db_type.lower() == "duckdb", + "semantic_views": first_source.get_semantic_views_description(), + "has_data_dicts": has_dicts, + "data_dicts": self._generate_data_dicts_yaml() if has_dicts else "", + "tables_overview": "" if has_dicts else self._generate_tables_overview(), "data_description": self.data_description, "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, "has_tool_visualize": has_viz_tool(tools), "include_query_guidelines": len(tools or ()) > 0, + "multi_table": len(self._data_sources) > 1, } prompts_dir = str(Path(__file__).parent / "prompts") @@ -95,3 +158,10 @@ def render(self, tools: set[str] | None) -> str: partials_path=prompts_dir, partials_ext="md", ) + + @property + def data_source(self) -> DataSource: + """Return single data source for backwards compatibility.""" + if len(self._data_sources) == 1: + return next(iter(self._data_sources.values())) + raise 
ValueError("Multiple data sources present; use _data_sources instead") diff --git a/pkg-py/src/querychat/_table_accessor.py b/pkg-py/src/querychat/_table_accessor.py new file mode 100644 index 000000000..f2dbb380e --- /dev/null +++ b/pkg-py/src/querychat/_table_accessor.py @@ -0,0 +1,62 @@ +"""TableAccessor class for accessing per-table state and data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ._datasource import DataSource + + +class TableAccessor: + """ + Accessor for a specific table's reactive state and data. + + Returned by ``qc_vals.table("name")`` in Shiny server callbacks, and by + ``qc.table("name")`` in Streamlit. Provides ``df()``, ``sql()``, and + ``title()`` backed by per-session reactive state. + + Parameters + ---------- + table_name + The name of the table this accessor represents. + data_source + The DataSource for this table. + state + Per-table reactive state, wired up by the framework. + + """ + + def __init__( + self, + table_name: str, + data_source: DataSource, + *, + state: Any, + ): + self._table_name = table_name + self._data_source = data_source + self._state = state + + @property + def table_name(self) -> str: + """The name of this table.""" + return self._table_name + + @property + def data_source(self) -> DataSource: + """The data source for this table.""" + return self._data_source + + def df(self) -> Any: + """Return the current filtered data for this table (reactive).""" + return self._state.df() + + def sql(self) -> str | None: + """Return the current SQL filter for this table (reactive).""" + return self._state.sql.get() + + def title(self) -> str | None: + """Return the current filter title for this table (reactive).""" + return self._state.title.get() + diff --git a/pkg-py/src/querychat/_viz_ggsql.py b/pkg-py/src/querychat/_viz_ggsql.py index 076b4f3b3..9e166c151 100644 --- a/pkg-py/src/querychat/_viz_ggsql.py +++ b/pkg-py/src/querychat/_viz_ggsql.py @@ -10,20 
+10,20 @@ if TYPE_CHECKING: import ggsql - from ._datasource import DataSource + from ._query_executor import QueryExecutor -def execute_ggsql(data_source: DataSource, validated: ggsql.Validated) -> ggsql.Spec: +def execute_ggsql(executor: QueryExecutor, validated: ggsql.Validated) -> ggsql.Spec: """ - Execute a pre-validated ggsql query against a DataSource, returning a Spec. + Execute a pre-validated ggsql query against a QueryExecutor, returning a Spec. - Executes the SQL portion through DataSource (preserving database pushdown), + Executes the SQL portion through the executor (preserving database pushdown), then feeds the result into a ggsql DuckDBReader to produce a Spec. Parameters ---------- - data_source - The querychat DataSource to execute the SQL portion against. + executor + The querychat QueryExecutor to execute the SQL portion against. validated A pre-validated ggsql query (from ``ggsql.validate()``). @@ -47,7 +47,7 @@ def execute_ggsql(data_source: DataSource, validated: ggsql.Validated) -> ggsql. "result." ) - pl_df = to_polars(data_source.execute_query(validated.sql())) + pl_df = to_polars(executor.execute_query(validated.sql())) reader = DuckDBReader("duckdb://memory") table = extract_visualise_table(visual) diff --git a/pkg-py/src/querychat/_viz_tools.py b/pkg-py/src/querychat/_viz_tools.py index 19833eb77..c57b297c0 100644 --- a/pkg-py/src/querychat/_viz_tools.py +++ b/pkg-py/src/querychat/_viz_tools.py @@ -28,7 +28,7 @@ import altair as alt from ipywidgets.widgets.widget import Widget - from ._datasource import DataSource + from ._query_executor import QueryExecutor class VisualizeData(TypedDict): @@ -56,18 +56,22 @@ class VisualizeData(TypedDict): def tool_visualize( - data_source: DataSource, + executor: QueryExecutor, update_fn: Callable[[VisualizeData], None], + *, + multi_table: bool = False, ) -> Tool: """ Create a tool that executes a ggsql query and renders the visualization. 
Parameters ---------- - data_source - The data source to query against + executor + The query executor to query against update_fn Callback function to call with VisualizeData when visualization succeeds + multi_table + Whether multiple tables are registered. Returns ------- @@ -75,10 +79,11 @@ def tool_visualize( A tool that can be registered with chatlas """ - impl = visualize_impl(data_source, update_fn) + impl = visualize_impl(executor, update_fn) impl.__doc__ = read_prompt_template( "tool-visualize.md", - db_type=data_source.get_db_type(), + db_type=executor.get_db_type(), + multi_table=multi_table, ) return Tool.from_func( @@ -147,7 +152,7 @@ def __init__( def visualize_impl( - data_source: DataSource, + executor: QueryExecutor, update_fn: Callable[[VisualizeData], None], ) -> Callable[[str, str], ContentToolResult]: """Create the visualize implementation function.""" @@ -173,7 +178,7 @@ def visualize( "\n".join(error["message"] for error in validated.errors()) ) - spec = execute_ggsql(data_source, validated) + spec = execute_ggsql(executor, validated) raw_chart = VegaLiteWriter().render_chart(spec) altair_widget = AltairWidget(copy.deepcopy(raw_chart)) diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 00272cada..08b0c93f7 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -1,20 +1,26 @@ You are a data dashboard chatbot that operates in a sidebar interface. 
Your role is to help users interact with their data through filtering, sorting, and answering questions.{{#has_tool_visualize}} You can also help them explore data visually.{{/has_tool_visualize}} -You have access to a {{db_type}} SQL database with the following schema: +You have access to a {{db_type}} SQL database with the following tables: - -{{schema}} - +{{#has_data_dicts}} +{{{data_dicts}}} -{{#data_description}} -Here is additional information about the data: +{{/has_data_dicts}} +{{^has_data_dicts}} + +{{{tables_overview}}} + +{{/has_data_dicts}} +{{#data_description}} {{data_description}} + {{/data_description}} +Always call `querychat_get_schema` before writing SQL against any table you haven't retrieved schema for in this conversation. Do not infer column names from table names, variable names, or the system prompt alone — verify the actual schema first, then write the query. -For security reasons, you may only query this specific table. +For security reasons, you may only query {{#multi_table}}these specific tables{{/multi_table}}{{^multi_table}}this specific table{{/multi_table}}. {{#include_query_guidelines}} ## SQL Query Guidelines @@ -82,18 +88,19 @@ You can handle these types of requests: When the user asks you to filter or sort the dashboard, e.g. "Show me..." or "Which ____ have the highest ____?" 
or "Filter to only include ____": - Write a {{db_type}} SQL SELECT query -- Call `querychat_update_dashboard` with the query and a descriptive title -- The query MUST return all columns from the schema (you can use `SELECT *`) +- Call `querychat_update_dashboard` with the query, table name, and a descriptive title +- You MUST specify the `table` parameter to indicate which table to filter +- The query MUST return all columns from the specified table's schema (you can use `SELECT *`) - Use a single SQL query even if complex (subqueries and CTEs are fine) - Optimize for **readability over efficiency** - Include SQL comments to explain complex logic - No confirmation messages are needed: the user will see your query in the dashboard. -The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `querychat_reset_dashboard()`. +The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `querychat_reset_dashboard` with the relevant `table`. **Filtering Example:** User: "Show only rows where sales are above average" -Tool Call: `querychat_update_dashboard({query: "SELECT * FROM table WHERE sales > (SELECT AVG(sales) FROM table)", title: "Above average sales"})` +Tool Call: `querychat_update_dashboard({query: "SELECT * FROM sales_data WHERE sales > (SELECT AVG(sales) FROM sales_data)", table: "sales_data", title: "Above average sales"})` Response: "" No further response needed, the user will see the updated dashboard. @@ -125,7 +132,7 @@ You can create visualizations using the `querychat_visualize` tool, which uses g #### Visualization best practices -The database schema in this prompt includes column names, types, and summary statistics. 
{{#has_tool_query}}If that context isn't sufficient for a confident visualization — e.g., you're unsure about value distributions, need to check for NULLs, or want to gauge row counts before choosing a chart type — use the `querychat_query` tool to inspect the data before visualizing. Always pass `collapsed=true` for these preparatory queries so the chart remains the focal point of the response.{{/has_tool_query}} +Use the `querychat_get_schema` tool to retrieve column names, types, and summary statistics for a table before writing visualization queries. {{#has_tool_query}}If that context isn't sufficient for a confident visualization — e.g., you're unsure about value distributions, need to check for NULLs, or want to gauge row counts before choosing a chart type — use the `querychat_query` tool to inspect the data before visualizing. Always pass `collapsed=true` for these preparatory queries so the chart remains the focal point of the response.{{/has_tool_query}} Follow the principles below to produce clear, interpretable charts. diff --git a/pkg-py/src/querychat/prompts/tool-get-schema.md b/pkg-py/src/querychat/prompts/tool-get-schema.md new file mode 100644 index 000000000..f34697223 --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-get-schema.md @@ -0,0 +1,18 @@ +Retrieve full column details for a table + +Returns column names, types, value ranges, categorical values, and descriptions for the specified table. + +**When to use this tool:** + +- Before writing any SQL query involving a table you have not yet inspected +- When you are unsure which table is most relevant to the user's request — call this tool on candidate tables to understand their contents before deciding + +Parameters +---------- +table_name + The name of the table to retrieve schema for. Must match one of the table names shown in the system prompt. 
+ +Returns +------- +: + Full column details for the specified table, including column names, types, value ranges, categorical values, and descriptions. diff --git a/pkg-py/src/querychat/prompts/tool-query.md b/pkg-py/src/querychat/prompts/tool-query.md index 65bc7d899..bd8954738 100644 --- a/pkg-py/src/querychat/prompts/tool-query.md +++ b/pkg-py/src/querychat/prompts/tool-query.md @@ -25,6 +25,11 @@ Always use SQL for counting, averaging, summing, and other calculations—NEVER - When using `collapsed=false`, avoid duplicating the same rows/values in both the tool result and your response text - Do not reproduce large result sets in your response — summarize the key takeaways instead +{{#multi_table}} + +**Multi-table queries:** Your schema includes multiple tables. You can reference any table in your queries and use JOINs when the data spans tables. Use the relationships described in the schema to determine join conditions. + +{{/multi_table}} Parameters ---------- query : diff --git a/pkg-py/src/querychat/prompts/tool-reset-dashboard.md b/pkg-py/src/querychat/prompts/tool-reset-dashboard.md index 7d78b4b43..a44d0a7f1 100644 --- a/pkg-py/src/querychat/prompts/tool-reset-dashboard.md +++ b/pkg-py/src/querychat/prompts/tool-reset-dashboard.md @@ -2,9 +2,14 @@ Reset the dashboard to its original state Resets the dashboard to use the original unfiltered dataset and clears any custom title. -If the user asks to reset the dashboard, simply call this tool with no other response. The reset action will be obvious to the user. -If the user asks to start over, call this tool and then provide a new set of suggestions for next steps. Include suggestions that encourage exploration of the data in new directions. 
+If the user asks to start over, call this tool with the relevant table name and then provide a new set of suggestions for next steps. Include suggestions that encourage exploration of the data in new directions. + +Parameters +---------- +table + The name of the table to reset. Returns ------- diff --git a/pkg-py/src/querychat/prompts/tool-update-dashboard.md b/pkg-py/src/querychat/prompts/tool-update-dashboard.md index dae9861c0..809c3b447 100644 --- a/pkg-py/src/querychat/prompts/tool-update-dashboard.md +++ b/pkg-py/src/querychat/prompts/tool-update-dashboard.md @@ -2,6 +2,8 @@ Filter and sort the dashboard data This tool executes a {{db_type}} SQL SELECT query to filter or sort the data used in the dashboard. +The `table` parameter specifies which table to filter. Use the table name exactly as shown in the schema. + **When to use:** Call this tool whenever the user requests filtering, sorting, or data manipulation on the dashboard with questions like "Show me..." or "Which records have...". This tool is appropriate for any request that involves showing a subset of the data or reordering it. **When not to use:** Do NOT use this tool for general questions about the data that can be answered with a single value or summary statistic. For those questions, use the `querychat_query` tool instead. @@ -14,8 +16,15 @@ This tool executes a {{db_type}} SQL SELECT query to filter or sort the data use - Assume the user will only see the original columns in the dataset +{{#multi_table}} + +**Multi-table filters:** When filtering a table, you may reference other tables in WHERE clauses, subqueries, or CTEs (e.g., filtering orders by a condition on customers). The result must still return all columns of the target table specified by the `table` parameter. + +{{/multi_table}} Parameters ---------- +table : + The name of the table to filter. Must match exactly one of the table names from the schema. 
query : A {{db_type}} SQL SELECT query that MUST return all existing schema columns (use SELECT * or explicitly list all columns). May include additional computed columns, subqueries, CTEs, WHERE clauses, ORDER BY, and any {{db_type}}-supported SQL functions. title : diff --git a/pkg-py/src/querychat/prompts/tool-visualize.md b/pkg-py/src/querychat/prompts/tool-visualize.md index c43f4da4d..4475cd3ef 100644 --- a/pkg-py/src/querychat/prompts/tool-visualize.md +++ b/pkg-py/src/querychat/prompts/tool-visualize.md @@ -10,6 +10,11 @@ Render a ggsql query (SQL with a VISUALISE clause) as an Altair chart displayed - Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead. - If a visualization fails, read the error message carefully and retry with a corrected query. Common fixes: correcting column names, adding `SCALE DISCRETE` for integer categories, moving SQL expressions out of `VISUALISE` into the `SELECT` clause, and using `DRAW range` for interval-style marks instead of deprecated `errorbar`.{{#has_tool_query}} If the error persists, fall back to `querychat_query` for a tabular answer.{{/has_tool_query}} +{{#multi_table}} + +**Multi-table queries:** The SELECT portion of your ggsql query can reference any table from the schema and use JOINs. 
+ +{{/multi_table}} Parameters ---------- ggsql : diff --git a/pkg-py/src/querychat/static/js/querychat.js b/pkg-py/src/querychat/static/js/querychat.js index 18d6b4f45..58f7e3ba4 100644 --- a/pkg-py/src/querychat/static/js/querychat.js +++ b/pkg-py/src/querychat/static/js/querychat.js @@ -9,11 +9,11 @@ if (!chatContainer) return; const chatId = chatContainer.id; - const { query, title } = event.target.dataset; + const { query, title, table } = event.target.dataset; window.Shiny.setInputValue( chatId + "_update", - { query, title }, + { query, title, table }, { priority: "event" } ); }); diff --git a/pkg-py/src/querychat/static/js/schema-display.js b/pkg-py/src/querychat/static/js/schema-display.js new file mode 100644 index 000000000..b21376c6c --- /dev/null +++ b/pkg-py/src/querychat/static/js/schema-display.js @@ -0,0 +1,154 @@ +/* Generated file. Source: js/src/schema-display.js. Do not edit directly. */ + +"use strict"; +(() => { + // src/schema-display.js + var lastDisplay = null; + var lastDisplayTime = 0; + var BATCH_MS = 1e3; + var activePanel = null; + function parseColumnsJson(json) { + return JSON.parse(json).map((col) => ({ + name: col.name, + type: col.sql_type, + units: col.units || null, + description: col.description || null, + constraints: col.constraints && col.constraints.length > 0 ? col.constraints.join(", ") : null, + range: col.min_val != null && col.max_val != null ? `${col.min_val} to ${col.max_val}` : null, + categories: col.categories && col.categories.length > 0 ? 
col.categories.map((v) => `'${v}'`).join(", ") : null + })); + } + function esc(s) { + return String(s).replace(/&/g, "&").replace(//g, ">").replace(/"/g, """); + } + var TH = "padding:0.35em 0.75em;text-align:left;white-space:nowrap;font-weight:600;border-bottom:2px solid var(--bs-border-color,#dee2e6);background:var(--bs-tertiary-bg,#f8f9fa);position:sticky;top:0;z-index:1;"; + var TD_MONO = "padding:0.3em 0.75em;white-space:nowrap;font-family:var(--bs-font-monospace,monospace);font-size:0.875em;border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));"; + var TD_WRAP = "padding:0.3em 0.75em;max-width:22em;overflow-wrap:break-word;border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));"; + var TD_NOWRAP = "padding:0.3em 0.75em;white-space:nowrap;border-bottom:1px solid var(--bs-border-color-translucent,rgba(0,0,0,.08));"; + function renderTable(columns) { + const rows = columns.map((col) => { + let typeCell = esc(col.type); + if (col.units) { + typeCell += ` [${esc(col.units)}]`; + } + const details = col.range ? esc(col.range) : col.categories ? esc(col.categories) : ""; + return `${esc(col.name)}${typeCell}${col.description ? esc(col.description) : ""}${col.constraints ? esc(col.constraints) : ""}${details}`; + }).join(""); + return `${rows}
ColumnTypeDescriptionConstraintsRange / Values
`; + } + var PANEL_STYLE = "position:fixed;z-index:9999;background:var(--bs-body-bg,#fff);color:var(--bs-body-color,#212529);border:1px solid var(--bs-border-color,#dee2e6);border-radius:var(--bs-border-radius,0.375rem);box-shadow:0 4px 16px rgba(0,0,0,.15);overflow:auto;max-height:min(420px,60vh);"; + function positionPanel(btn, panel) { + const rect = btn.getBoundingClientRect(); + const vw = window.innerWidth; + const vh = window.innerHeight; + const pw = Math.min(Math.max(360, vw * 0.55), vw - 16); + panel.style.width = `${pw}px`; + panel.style.left = `${Math.max(8, Math.min(rect.left, vw - pw - 8))}px`; + const spaceBelow = vh - rect.bottom - 8; + const spaceAbove = rect.top - 8; + if (spaceBelow >= 120 || spaceBelow >= spaceAbove) { + panel.style.top = `${rect.bottom + 4}px`; + } else { + const panelH = Math.min(420, spaceAbove); + panel.style.top = `${Math.max(8, rect.top - panelH - 4)}px`; + } + } + function closePanel() { + if (activePanel) { + activePanel.panel.hidden = true; + activePanel.btn.setAttribute("aria-expanded", "false"); + activePanel = null; + } + } + document.addEventListener("click", closePanel); + document.addEventListener("keydown", (e) => { + if (e.key === "Escape") closePanel(); + }); + window.addEventListener( + "scroll", + (e) => { + if (activePanel && !activePanel.panel.contains( + /** @type {Node} */ + e.target + )) { + closePanel(); + } + }, + true + ); + window.addEventListener("resize", closePanel); + function createBtn(tableName, columnsJson) { + const columns = parseColumnsJson(columnsJson); + const btn = document.createElement("button"); + btn.type = "button"; + btn.style.cssText = "background:none;border:none;padding:0;color:inherit;text-decoration:underline dotted;cursor:pointer;font-size:inherit;border-radius:2px;"; + btn.textContent = tableName; + btn.setAttribute("aria-label", `Show schema for ${tableName}`); + btn.setAttribute("aria-expanded", "false"); + btn.setAttribute("aria-haspopup", "dialog"); + const panel = 
document.createElement("div"); + panel.setAttribute("role", "dialog"); + panel.setAttribute("aria-label", `${tableName} schema`); + panel.style.cssText = PANEL_STYLE; + panel.hidden = true; + panel.innerHTML = renderTable(columns); + document.body.appendChild(panel); + btn.addEventListener("click", (e) => { + e.stopPropagation(); + if (activePanel && activePanel.panel === panel) { + closePanel(); + return; + } + closePanel(); + positionPanel(btn, panel); + panel.hidden = false; + btn.setAttribute("aria-expanded", "true"); + activePanel = { btn, panel }; + }); + panel.addEventListener("click", (e) => e.stopPropagation()); + return btn; + } + var style = document.createElement("style"); + style.textContent = ".qc-schema-display button:focus-visible{outline:2px solid currentColor;outline-offset:2px;border-radius:2px}"; + document.head.appendChild(style); + function processCollector(sentinel) { + const now = Date.now(); + const tableName = sentinel.dataset.table; + const btn = createBtn(tableName, sentinel.dataset.schemaJson); + if (lastDisplay && document.contains(lastDisplay) && now - lastDisplayTime < BATCH_MS) { + lastDisplay.appendChild(document.createTextNode(", ")); + lastDisplay.appendChild(btn); + sentinel.remove(); + } else { + const p = document.createElement("p"); + p.className = "qc-schema-display"; + p.style.cssText = "color:var(--bs-secondary-color,#6c757d);font-size:0.875em;margin:0.1rem 0;"; + p.appendChild(document.createTextNode("\u{1F50D} Fetched schemas: ")); + p.appendChild(btn); + sentinel.replaceWith(p); + lastDisplay = p; + } + lastDisplayTime = now; + } + new MutationObserver((mutations) => { + for (const { addedNodes } of mutations) { + for (const node of addedNodes) { + if (node.nodeType !== 1) continue; + if ( + /** @type {Element} */ + node.classList.contains("qc-schema-collector") + ) { + processCollector( + /** @type {HTMLElement} */ + node + ); + } else { + node.querySelectorAll(".qc-schema-collector").forEach((el) => processCollector( 
+ /** @type {HTMLElement} */ + el + )); + } + } + } + }).observe(document.body, { subtree: true, childList: true }); +})(); diff --git a/pkg-py/src/querychat/streamlit.py b/pkg-py/src/querychat/streamlit.py index ddd895bfb..5b82bc3fc 100644 --- a/pkg-py/src/querychat/streamlit.py +++ b/pkg-py/src/querychat/streamlit.py @@ -1,5 +1,5 @@ """Streamlit integration for querychat.""" -from ._streamlit import QueryChat +from ._streamlit import QueryChat, StreamlitTableAccessor -__all__ = ["QueryChat"] +__all__ = ["QueryChat", "StreamlitTableAccessor"] diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 48a17b5cc..b446b1819 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -1,10 +1,18 @@ from __future__ import annotations +import html +import json +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable -from chatlas import ContentToolResult, Tool -from shinychat.types import ToolResultDisplay +from chatlas import ContentToolRequest, ContentToolResult, Tool +from htmltools import HTMLDependency, TagList, tags +from pydantic import Field +from shinychat import message_content_chunk +from shinychat.types import ChatMessage, ToolResultDisplay +from .__version import __version__ +from ._datasource import ColumnMeta, format_schema from ._icons import bs_icon from ._utils import ( as_narwhals, @@ -16,6 +24,8 @@ from ._viz_tools import tool_visualize __all__ = [ + "GetSchemaResult", + "tool_get_schema", "tool_query", "tool_reset_dashboard", "tool_update_dashboard", @@ -23,9 +33,129 @@ ] if TYPE_CHECKING: - from collections.abc import Callable + from ._data_dict import DataDict + from ._query_executor import QueryExecutor - from ._datasource import DataSource + +ResetDashboardCallback = Callable[[str], None] + + +class GetSchemaResult(ContentToolResult): + """Tool result that carries schema text and structured column metadata for a single table.""" + + table_name: 
str + columns: list[ColumnMeta] = Field(default_factory=list) + + +def _col_to_dict(col: ColumnMeta) -> dict[str, Any]: + return { + "name": col.name, + "sql_type": col.sql_type, + "units": col.units, + "description": col.description, + "min_val": str(col.min_val) if col.min_val is not None else None, + "max_val": str(col.max_val) if col.max_val is not None else None, + "categories": col.categories, + "constraints": col.constraints, + } + + +_orig_request_handler = message_content_chunk.dispatch(ContentToolRequest) + + +@message_content_chunk.register +def _(request: ContentToolRequest) -> ChatMessage: + if request.name == "querychat_get_schema": + return ChatMessage(content="") + return _orig_request_handler(request) + + +@message_content_chunk.register +def _(message: GetSchemaResult) -> ChatMessage: + columns_json = json.dumps([_col_to_dict(c) for c in message.columns]) + content = TagList( + tags.span( + class_="qc-schema-collector", + data_table=message.table_name, + data_schema=str(message.value), + data_schema_json=columns_json, + style="display:none", + ), + _schema_dep(), + ) + return ChatMessage(content=content) + + +def _schema_dep() -> HTMLDependency: + return HTMLDependency( + "querychat-schema-display", + __version__, + source={"package": "querychat", "subdir": "static"}, + script=[{"src": "js/schema-display.js"}], + ) + + +def _get_schema_impl( + data_dicts: list[DataDict], + executor: QueryExecutor, + table_names: list[str], + categorical_threshold: int, +) -> Callable[[str], ContentToolResult]: + def get_schema(table_name: str) -> ContentToolResult: + if table_name not in table_names: + available = ", ".join(table_names) + error = f"Table '{table_name}' not found. 
Available: {available}" + return ContentToolResult(value=error, error=Exception(error)) + + dd = next((d for d in data_dicts if table_name in d.tables), None) + if dd is not None: + columns = dd.get_table_schema(table_name, executor, categorical_threshold) + else: + columns = executor.get_column_details(table_name, categorical_threshold) + + schema_text = format_schema(table_name, columns) + return GetSchemaResult(value=schema_text, table_name=table_name, columns=columns) + + return get_schema + + +def tool_get_schema( + data_dicts: list[DataDict], + executor: QueryExecutor, + table_names: list[str], + categorical_threshold: int, +) -> Tool: + """ + Create a tool that retrieves full column details for a table. + + Parameters + ---------- + data_dicts + Data dictionaries with enriched column metadata. The first dict that + covers a requested table is used; tables not covered by any dict fall + back to live statistics from the executor. + executor + The query executor to use for schema introspection. + table_names + List of valid table names. + categorical_threshold + Maximum number of unique values before a text column is treated as + free-form rather than categorical. + + Returns + ------- + Tool + A tool that can be registered with chatlas. + + """ + impl = _get_schema_impl(data_dicts, executor, table_names, categorical_threshold) + description = read_prompt_template("tool-get-schema.md") + impl.__doc__ = description + return Tool.from_func( + impl, + name="querychat_get_schema", + annotations={"title": "Get Schema"}, + ) @runtime_checkable @@ -52,6 +182,8 @@ class UpdateDashboardData(TypedDict): Attributes ---------- + table + The name of the table being filtered. query The SQL query string to execute for filtering/sorting the dashboard. 
title @@ -66,6 +198,7 @@ class UpdateDashboardData(TypedDict): def log_update(data: UpdateDashboardData): + print(f"Table: {data['table']}") print(f"Executing: {data['query']}") print(f"Title: {data['title']}") @@ -77,35 +210,45 @@ def log_update(data: UpdateDashboardData): """ + table: str query: str title: str def _update_dashboard_impl( - data_source: DataSource, + executor: QueryExecutor, + table_names: list[str], update_fn: Callable[[UpdateDashboardData], None], -) -> Callable[[str, str], ContentToolResult]: +) -> Callable[[str, str, str], ContentToolResult]: """Create the implementation function for updating the dashboard.""" - def update_dashboard(query: str, title: str) -> ContentToolResult: + def update_dashboard(table: str, query: str, title: str) -> ContentToolResult: error = None markdown = f"```sql\n{query}\n```" value = "Dashboard updated. Use `query` tool to review results, if needed." + # Validate table exists + if table not in table_names: + available = ", ".join(table_names) + error = f"Table '{table}' not found. 
Available: {available}" + markdown += f"\n\n> Error: {error}" + return ContentToolResult(value=markdown, error=Exception(error)) + try: # Test the query but don't execute it yet - data_source.test_query(query, require_all_columns=True) + executor.test_query(query, table_name=table, require_all_columns=True) # Add Apply Filter button button_html = f"""""" # Call the callback with TypedDict data on success - update_fn({"query": query, "title": title}) + update_fn({"table": table, "query": query, "title": title}) except Exception as e: error = truncate_error(str(e)) @@ -130,30 +273,38 @@ def update_dashboard(query: str, title: str) -> ContentToolResult: def tool_update_dashboard( - data_source: DataSource, + executor: QueryExecutor, + table_names: list[str], update_fn: Callable[[UpdateDashboardData], None], + *, + multi_table: bool = False, ) -> Tool: """ - Create a tool that modifies the data presented in the dashboard based on the SQL query. + Create a tool that modifies the data presented in the dashboard. Parameters ---------- - data_source - The data source to query against + executor + The query executor to validate queries against. + table_names + List of valid table names for validation. update_fn - Callback function to call with UpdateDashboardData when update succeeds + Callback function to call with UpdateDashboardData when update succeeds. + multi_table + Whether multiple tables are registered. Returns ------- Tool - A tool that can be registered with chatlas + A tool that can be registered with chatlas. 
""" - impl = _update_dashboard_impl(data_source, update_fn) + impl = _update_dashboard_impl(executor, table_names, update_fn) description = read_prompt_template( "tool-update-dashboard.md", - db_type=data_source.get_db_type(), + db_type=executor.get_db_type(), + multi_table=multi_table, ) impl.__doc__ = description @@ -165,17 +316,27 @@ def tool_update_dashboard( def _reset_dashboard_impl( - reset_fn: Callable[[], None], -) -> Callable[[], ContentToolResult]: + reset_fn: ResetDashboardCallback, + table_names: list[str] | None, +) -> Callable[[str], ContentToolResult]: """Create the implementation function for resetting the dashboard.""" - def reset_dashboard() -> ContentToolResult: + def reset_dashboard(table: str) -> ContentToolResult: + if table_names is not None and table not in table_names: + available = ", ".join(table_names) + error = f"Table '{table}' not found. Available: {available}" + return ContentToolResult( + value=error, + error=Exception(error), + ) + # Call the callback to reset - reset_fn() + reset_fn(table) # Add Reset Filter button - button_html = """