diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index 18a87cb8..e6434b63 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -12,8 +12,8 @@ from ._querychat_base import TOOL_GROUPS, StateDictQueryChat from ._querychat_core import ( - GREETING_PROMPT, AppStateDict, + build_greeting_prompt, create_app_state, stream_response, ) @@ -98,6 +98,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -114,6 +115,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -130,6 +132,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -146,6 +149,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -162,6 +166,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... def __init__( @@ -177,6 +182,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ): super().__init__( data_source, @@ -189,6 +195,7 @@ def __init__( categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, + greeting_tables=greeting_tables, ) @property @@ -288,7 +295,12 @@ def initialize_greeting(state_dict: AppStateDict): if not state.initialize_greeting_if_preset(): greeting = "" - for chunk in stream_response(state.client, GREETING_PROMPT): + greeting_prompt = build_greeting_prompt( + state.data_sources, + self._categorical_threshold, + self.greeting_tables, + ) + for chunk in stream_response(state.client, greeting_prompt): greeting += chunk state.set_greeting(greeting) diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index d3b4c612..0cccfa5b 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -32,9 +32,9 @@ validate_source_group_compatibility, ) from ._querychat_core import ( - GREETING_PROMPT, AppState, AppStateDict, + build_greeting_prompt, create_app_state, warn_multi_table_flat_accessor, ) @@ -91,6 +91,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ): self._data_dicts: list[DataDict] = _normalize_data_dicts(data_dict) @@ -113,6 +114,7 @@ def __init__( self._client_spec: str | chatlas.Chat | None = client self._client_console = None + self.greeting_tables: list[str] | bool | None = greeting_tables self._system_prompt: QueryChatSystemPrompt | None = None if data_source is not None: @@ -322,7 +324,12 @@ def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: chat = create_client(self._client_spec) if self._system_prompt is not None: chat.system_prompt = self._system_prompt.render(self.tools) - return str(chat.chat(GREETING_PROMPT, echo=echo)) + prompt = build_greeting_prompt( + self._data_sources, + self._categorical_threshold, + self.greeting_tables, + ) + return str(chat.chat(prompt, echo=echo)) def console( self, diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index 87d110ab..650b0c57 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -3,17 +3,19 @@ from __future__ import annotations __all__ = [ - "GREETING_PROMPT", + "GREETING_MARKER", + "GREETING_PROMPT", # backward compat alias — value unchanged; callers relying on the full prompt should use build_greeting_prompt() instead "AppState", "AppStateDict", "ClientFactory", + "build_greeting_prompt", "create_app_state", "stream_response", "stream_response_async", ] import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, TypedDict, Union @@ -23,11 +25,23 @@ from .tools import UpdateDashboardData -GREETING_PROMPT: str = ( +GREETING_MARKER: str = "\n" +"""Sentinel prepended to every greeting prompt turn; used to skip it in display.""" + +GREETING_BASE_TEXT: str = ( "Please give me a friendly greeting. " "Include a few sample suggestions grouped under ##### headings, " "using the suggestion card format from your instructions." ) + +GREETING_EXPLORE_ADDENDUM: str = ( + " Include at least one suggestion encouraging the user to explore what " + "data and questions are available — for example, asking which tables or " + "columns exist, or what kinds of analysis are possible." +) + +# Keep the old name as an alias so any external code that imported it still works. +GREETING_PROMPT: str = GREETING_BASE_TEXT """Prompt used to generate the initial greeting message.""" if TYPE_CHECKING: @@ -46,6 +60,63 @@ """Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks.""" +def build_greeting_prompt( + data_sources: Mapping[str, DataSource], + categorical_threshold: int, + greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag +) -> str: + """ + Build a greeting prompt, optionally embedding table schema. + + Parameters + ---------- + data_sources + All registered data sources keyed by table name. + categorical_threshold + Passed to ``get_schema`` for each included table. + greeting_tables + Which tables to include schema for: + - ``None``: auto — single table gets full schema; multi-table gets none. + - ``True``: all tables. + - ``False``: no tables (explorer hint only). + - list of names: only the named tables. + + """ + table_names = resolve_greeting_tables(data_sources, greeting_tables) + + if table_names: + schema_sections: list[str] = [] + multi = len(table_names) > 1 + for name in table_names: + source = data_sources[name] + schema = source.get_schema(categorical_threshold=categorical_threshold) + section = f"Table '{name}':\n{schema}" if multi else schema + schema_sections.append(section) + schema_block = "\n\n".join(schema_sections) + body = f"\n{schema_block}\n\n\n{GREETING_BASE_TEXT}" + else: + body = GREETING_BASE_TEXT + GREETING_EXPLORE_ADDENDUM + + return f"{GREETING_MARKER}{body}" + + +def resolve_greeting_tables( + data_sources: Mapping[str, DataSource], + greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag +) -> list[str]: + """Return the list of table names whose schema to include in the greeting.""" + if greeting_tables is True: + return list(data_sources.keys()) + if greeting_tables is False: + return [] + if isinstance(greeting_tables, list): + return [t for t in greeting_tables if t in data_sources] + # Auto (None): single table → include schema; multi-table → no schema + if len(data_sources) == 1: + return list(data_sources.keys()) + return [] + + def warn_multi_table_flat_accessor( accessor_name: str, primary_table: str, table_list: str, stacklevel: int = 3 ) -> None: @@ -246,7 +317,7 @@ def get_display_messages(self) -> list[DisplayMessage]: if text_parts: text = "\n\n".join(text_parts) # Skip the greeting prompt - it's an internal message - if turn.role == "user" and text == GREETING_PROMPT: + if turn.role == "user" and text.startswith(GREETING_MARKER): continue messages.append({"role": turn.role, "content": text}) diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index f616671f..b342642c 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -160,6 +160,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -177,6 +178,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -194,6 +196,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -211,6 +214,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -228,6 +232,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... def __init__( @@ -244,6 +249,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ): super().__init__( data_source, @@ -256,6 +262,7 @@ def __init__( categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, + greeting_tables=greeting_tables, ) self.id = id or (f"querychat_{table_name}" if table_name else "querychat") @@ -330,6 +337,8 @@ def app_server(input: Inputs, output: Outputs, session: Session): client=self._create_session_client, enable_bookmarking=enable_bookmarking, tools=self.tools, + greeting_tables=self.greeting_tables, + categorical_threshold=self._categorical_threshold, ) @reactive.calc @@ -541,6 +550,8 @@ def create_session_client(**kwargs) -> chatlas.Chat: client=create_session_client, enable_bookmarking=enable_bookmarking, tools=self.tools, + greeting_tables=self.greeting_tables, + categorical_threshold=self._categorical_threshold, ) @@ -663,6 +674,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -681,6 +693,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -699,6 +712,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -717,6 +731,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -735,6 +750,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ) -> None: ... @@ -752,6 +768,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, enable_bookmarking: Literal["auto", True, False] = "auto", ): # Sanity check: Express should always have a (stub/real) session @@ -773,6 +790,7 @@ def __init__( categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, + greeting_tables=greeting_tables, ) self.id = id or (f"querychat_{table_name}" if table_name else "querychat") @@ -821,6 +839,8 @@ def _ensure_server_started(self) -> None: client=self._create_session_client, enable_bookmarking=self._enable_bookmarking, tools=self.tools, + greeting_tables=self.greeting_tables, + categorical_threshold=self._categorical_threshold, ) def sidebar( diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index f0b3e8d4..138ee7e3 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -12,7 +12,7 @@ from shiny import module, reactive, ui -from ._querychat_core import GREETING_PROMPT, warn_multi_table_flat_accessor +from ._querychat_core import build_greeting_prompt, warn_multi_table_flat_accessor from ._table_accessor import TableAccessor from ._viz_altair_widget import AltairWidget from ._viz_ggsql import execute_ggsql @@ -215,6 +215,8 @@ def mod_server( client: Callable[..., chatlas.Chat], enable_bookmarking: bool, tools: set[str] | None = None, + greeting_tables: list[str] | bool | None = None, + categorical_threshold: int = 20, ) -> ServerValues[IntoFrameT]: # Holds a generated greeting so it can be saved and restored on bookmark. # Static greetings live in the UI (chat_ui(greeting=)) and persist already. @@ -347,8 +349,13 @@ async def _handle_greeting_requested(): GreetWarning, stacklevel=2, ) + greeting_prompt = build_greeting_prompt( + data_sources, + categorical_threshold, + greeting_tables, + ) greeting_client = client(tools=None) - stream = await greeting_client.stream_async(GREETING_PROMPT, echo="none") + stream = await greeting_client.stream_async(greeting_prompt, echo="none") await chat_ui.set_greeting( shinychat.chat_greeting(stream, dismissible=False) ) diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index b8093b83..b1cd6244 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -8,8 +8,8 @@ from ._querychat_base import TOOL_GROUPS, QueryChatBase from ._querychat_core import ( - GREETING_PROMPT, AppState, + build_greeting_prompt, create_app_state, stream_response, ) @@ -117,6 +117,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -133,6 +134,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -149,6 +151,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -165,6 +168,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... @overload @@ -181,6 +185,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ) -> None: ... def __init__( @@ -196,6 +201,7 @@ def __init__( prompt_template: Optional[str | Path] = None, categorical_threshold: int = 20, data_description: Optional[str | Path] = None, + greeting_tables: list[str] | bool | None = None, ): super().__init__( data_source, @@ -208,6 +214,7 @@ def __init__( categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, prompt_template=prompt_template, + greeting_tables=greeting_tables, ) self._state_key = f"_querychat_{table_name}" if table_name else "_querychat" @@ -291,7 +298,12 @@ def ui(self) -> None: with st.chat_message("assistant"): placeholder = st.empty() placeholder.markdown("*Preparing your data assistant...*") - for chunk in stream_response(state.client, GREETING_PROMPT): + greeting_prompt = build_greeting_prompt( + state.data_sources, + self._categorical_threshold, + self.greeting_tables, + ) + for chunk in stream_response(state.client, greeting_prompt): greeting += chunk placeholder.markdown(greeting, unsafe_allow_html=True) state.set_greeting(greeting) diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py index 3293e356..6b4f2e1e 100644 --- a/pkg-py/tests/test_querychat.py +++ b/pkg-py/tests/test_querychat.py @@ -7,6 +7,7 @@ import pytest from querychat import QueryChat from querychat._datasource import IbisSource, PolarsLazySource +from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt @pytest.fixture(autouse=True) @@ -194,3 +195,155 @@ def test_querychat_with_ibis_table(): assert executed["name"].iloc[0] == "Bob" finally: conn.disconnect() + + +def test_build_greeting_prompt_single_table_includes_schema(sample_df): + """Single table with greeting_tables=None loads schema automatically.""" + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + + source = DataFrameSource(nw.from_native(sample_df), "people") + prompt = build_greeting_prompt( + data_sources={"people": source}, + categorical_threshold=20, + greeting_tables=None, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" in prompt + assert "people" in prompt # schema content references table + + +def test_build_greeting_prompt_multi_table_no_signal_omits_schema(): + """Multi-table with greeting_tables=None omits schema, adds explorer hint.""" + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + + sources = { + "orders": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "orders"), + "customers": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "customers"), + } + prompt = build_greeting_prompt( + data_sources=sources, + categorical_threshold=20, + greeting_tables=None, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" not in prompt + # Should encourage exploration of available data + assert "available" in prompt.lower() or "explore" in prompt.lower() + + +def test_build_greeting_prompt_explicit_tables_includes_only_those(): + """greeting_tables list loads schema for specified tables only.""" + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + + sources = { + "orders": DataFrameSource(nw.from_native(pd.DataFrame({"amount": [10.0]})), "orders"), + "customers": DataFrameSource(nw.from_native(pd.DataFrame({"name": ["Alice"]})), "customers"), + } + prompt = build_greeting_prompt( + data_sources=sources, + categorical_threshold=20, + greeting_tables=["orders"], + ) + assert prompt.startswith(GREETING_MARKER) + assert "" in prompt + assert "orders" in prompt + # customers schema should not appear (only the orders schema section) + # The prompt should contain "orders" but we check for column content distinction + assert "amount" in prompt # orders column + assert "name" not in prompt # customers column excluded + + +def test_build_greeting_prompt_true_includes_all_tables(): + """greeting_tables=True loads schema for all tables.""" + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + + sources = { + "orders": DataFrameSource(nw.from_native(pd.DataFrame({"amount": [10.0]})), "orders"), + "customers": DataFrameSource(nw.from_native(pd.DataFrame({"name": ["Alice"]})), "customers"), + } + prompt = build_greeting_prompt( + data_sources=sources, + categorical_threshold=20, + greeting_tables=True, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" in prompt + assert "amount" in prompt # orders column + assert "name" in prompt # customers column + + +def test_build_greeting_prompt_false_omits_schema_adds_explorer(): + """greeting_tables=False skips schema entirely and adds explorer hint.""" + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + + source = DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "t") + prompt = build_greeting_prompt( + data_sources={"t": source}, + categorical_threshold=20, + greeting_tables=False, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" not in prompt + assert "available" in prompt.lower() or "explore" in prompt.lower() + + +def test_generate_greeting_embeds_schema_for_single_table(sample_df): + """generate_greeting() sends schema-embedded prompt for a single table.""" + qc = QueryChat(data_source=sample_df, table_name="people") + seen: dict[str, str] = {} + + def fake_chat(self, prompt, *args, **kwargs): + seen["prompt"] = prompt + seen["system_prompt"] = self.system_prompt or "" + return "Hello!" + + with patch("chatlas.Chat.chat", fake_chat): + result = qc.generate_greeting() + + assert result == "Hello!" + assert seen["prompt"].startswith(GREETING_MARKER) + assert "" in seen["prompt"] + + +def test_generate_greeting_omits_schema_for_multi_table_default(): + """generate_greeting() with no greeting_tables omits schema for multi-table.""" + qc = QueryChat() + qc.add_table(pd.DataFrame({"a": [1]}), "t1") + qc.add_table(pd.DataFrame({"b": [2]}), "t2") + seen: dict[str, str] = {} + + def fake_chat(self, prompt, *args, **kwargs): + seen["prompt"] = prompt + return "Hello!" + + with patch("chatlas.Chat.chat", fake_chat): + qc.generate_greeting() + + assert seen["prompt"].startswith(GREETING_MARKER) + assert "" not in seen["prompt"] + + +def test_generate_greeting_respects_greeting_tables_param(sample_df): + """generate_greeting() includes schema for tables named in greeting_tables.""" + qc = QueryChat(greeting_tables=["people"]) + qc.add_table(sample_df, "people") + qc.add_table(pd.DataFrame({"x": [1]}), "other") + seen: dict[str, str] = {} + + def fake_chat(self, prompt, *args, **kwargs): + seen["prompt"] = prompt + return "Hi!" + + with patch("chatlas.Chat.chat", fake_chat): + qc.generate_greeting() + + assert seen["prompt"].startswith(GREETING_MARKER) + assert "" in seen["prompt"] + # only people schema present: people has 'name', 'age'; other has 'x' + assert "age" in seen["prompt"] + assert "x" not in seen["prompt"] diff --git a/pkg-py/tests/test_shiny_module.py b/pkg-py/tests/test_shiny_module.py index ff3abc78..1190ffb8 100644 --- a/pkg-py/tests/test_shiny_module.py +++ b/pkg-py/tests/test_shiny_module.py @@ -2,9 +2,11 @@ from __future__ import annotations +import inspect import os from unittest.mock import patch +import pandas as pd import pytest from shiny import ui @@ -21,6 +23,11 @@ def set_dummy_api_key(): del os.environ["OPENAI_API_KEY"] +@pytest.fixture +def sample_df(): + return pd.DataFrame({"id": [1, 2, 3], "value": [10, 20, 30]}) + + def _fake_chat_ui(*args, **kwargs): """Return a real Tag so htmltools accepts it; stash kwargs for inspection.""" _fake_chat_ui.last_kwargs = kwargs @@ -48,3 +55,92 @@ def test_mod_ui_allow_attachments_can_be_overridden(): assert _fake_chat_ui.last_kwargs.get("allow_attachments") is False +def _unwrap_module_server(wrapped): + """ + Retrieve the original function from a @module.server-decorated callable. + + Shiny's @module.server decorator stores the original function as a cell in + the wrapper's __closure__. We locate it by scanning the closure for callables + whose signature includes the expected module server parameters. + """ + if wrapped.__closure__ is None: + return wrapped + for cell in wrapped.__closure__: + try: + val = cell.cell_contents + except ValueError: + continue + if callable(val) and "data_sources" in inspect.signature(val).parameters: + return val + return wrapped + + +def test_mod_server_accepts_greeting_tables_and_categorical_threshold(): + """mod_server must expose greeting_tables and categorical_threshold params.""" + from querychat._shiny_module import mod_server + + fn = _unwrap_module_server(mod_server) + params = inspect.signature(fn).parameters + assert "greeting_tables" in params, "mod_server missing greeting_tables param" + assert "categorical_threshold" in params, "mod_server missing categorical_threshold param" + assert params["greeting_tables"].default is None + assert params["categorical_threshold"].default == 20 + + +def test_build_greeting_prompt_imported_into_shiny_module(sample_df): + """ + Checks that build_greeting_prompt is imported into _shiny_module (so the + module uses the shared implementation) and that calling it directly with a + single-table source produces a schema-embedded prompt. + + Note: triggering the actual greeting reactive event requires a live Shiny + session, which is not available in unit tests. + """ + import narwhals.stable.v1 as nw + import querychat._shiny_module as shiny_module + from querychat._datasource import DataFrameSource + from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt + + # Verify build_greeting_prompt is accessible in _shiny_module + assert hasattr(shiny_module, "build_greeting_prompt"), ( + "build_greeting_prompt must be imported into _shiny_module" + ) + assert shiny_module.build_greeting_prompt is build_greeting_prompt + + # Verify single-table auto-mode produces a schema-containing prompt + source = DataFrameSource(nw.from_native(sample_df), "sample") + prompt = build_greeting_prompt( + data_sources={"sample": source}, + categorical_threshold=20, + greeting_tables=None, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" in prompt + + +def test_build_greeting_prompt_called_with_multi_table_no_greeting_tables(): + """ + With two tables and greeting_tables=None, build_greeting_prompt omits schema. + + This mirrors the assertion from the brief: multi-table with no greeting_tables + → GREETING_MARKER prefix present, absent. + """ + import narwhals.stable.v1 as nw + from querychat._datasource import DataFrameSource + from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt + + sources = { + "orders": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "orders"), + "customers": DataFrameSource( + nw.from_native(pd.DataFrame({"id": [1]})), "customers" + ), + } + prompt = build_greeting_prompt( + data_sources=sources, + categorical_threshold=20, + greeting_tables=None, + ) + assert prompt.startswith(GREETING_MARKER) + assert "" not in prompt + + diff --git a/pkg-r/R/QueryChat.R b/pkg-r/R/QueryChat.R index 576d9abc..efdff4b4 100644 --- a/pkg-r/R/QueryChat.R +++ b/pkg-r/R/QueryChat.R @@ -229,6 +229,8 @@ QueryChat <- R6::R6Class( public = list( #' @field greeting The greeting message displayed to users. greeting = NULL, + #' @field greeting_tables Controls which tables' schema to include in greeting. + greeting_tables = NULL, #' @field id ID for the QueryChat instance. id = NULL, #' @field id_override Whether the ID was explicitly set by the user. @@ -295,6 +297,7 @@ QueryChat <- R6::R6Class( ..., id = NULL, greeting = NULL, + greeting_tables = NULL, client = NULL, tools = c("filter", "query"), data_description = NULL, @@ -342,6 +345,7 @@ QueryChat <- R6::R6Class( greeting <- read_utf8(greeting) } self$greeting <- greeting + self$greeting_tables <- greeting_tables # Track whether id was explicitly set self$id_override <- id @@ -900,6 +904,8 @@ QueryChat <- R6::R6Class( greeting = self$greeting, client = create_session_client, tools = self$tools, + greeting_tables = self$greeting_tables, + categorical_threshold = private$.categorical_threshold, enable_bookmarking = enable_bookmarking ) result @@ -914,7 +920,12 @@ QueryChat <- R6::R6Class( generate_greeting = function(echo = c("none", "output")) { private$require_initialized("$generate_greeting") chat <- private$create_session_client() - as.character(chat$chat(GREETING_PROMPT, echo = echo)) + greeting_prompt <- build_greeting_prompt( + private$.data_sources, + private$.categorical_threshold, + self$greeting_tables + ) + as.character(chat$chat(greeting_prompt, echo = echo)) }, #' @description diff --git a/pkg-r/R/querychat_module.R b/pkg-r/R/querychat_module.R index cb1c7ff8..9857b88a 100644 --- a/pkg-r/R/querychat_module.R +++ b/pkg-r/R/querychat_module.R @@ -43,6 +43,8 @@ mod_server <- function( greeting, client, tools, + greeting_tables = NULL, + categorical_threshold = 20, enable_bookmarking = FALSE ) { shiny::moduleServer(id, function(input, output, session) { @@ -150,13 +152,20 @@ mod_server <- function( ) return() } - cli::cli_warn(c( - "No {.arg greeting} provided to {.fn QueryChat}. Using the LLM {.arg client} to generate one now.", - "i" = "For faster startup, lower cost, and determinism, consider providing a {.arg greeting} to {.fn QueryChat}.", - "i" = "You can use your {.help querychat::QueryChat} object's {.fn $generate_greeting} method to generate a greeting." - )) + cli::cli_warn( + c( + "No {.arg greeting} provided to {.fn QueryChat}. Using the LLM {.arg client} to generate one now.", + "i" = "For faster startup, lower cost, and determinism, consider providing a {.arg greeting} to {.fn QueryChat}.", + "i" = "You can use your {.help querychat::QueryChat} object's {.fn $generate_greeting} method to generate a greeting." + ) + ) greeting_client <- client(tools = NULL) - stream <- greeting_client$stream_async(GREETING_PROMPT) + greeting_prompt <- build_greeting_prompt( + data_sources, + categorical_threshold, + greeting_tables + ) + stream <- greeting_client$stream_async(greeting_prompt) p <- shinychat::chat_set_greeting( "chat", shinychat::chat_greeting(stream, dismissible = FALSE) @@ -322,13 +331,68 @@ mod_server <- function( }) } -# TODO: Make this dependent on enabled tools -GREETING_PROMPT <- paste( +GREETING_MARKER <- "\n" + +GREETING_BASE_TEXT <- paste( "Please give me a friendly greeting.", "Include a few sample suggestions grouped under ##### headings,", "using the suggestion card format from your instructions." ) +GREETING_EXPLORE_ADDENDUM <- paste( + "Include at least one suggestion encouraging the user to explore what", + "data and questions are available - for example, asking which tables or", + "columns exist, or what kinds of analysis are possible." +) + +# TODO: Make this dependent on enabled tools +GREETING_PROMPT <- GREETING_BASE_TEXT + +build_greeting_prompt <- function( + data_sources, + categorical_threshold, + greeting_tables +) { + table_names <- resolve_greeting_tables(data_sources, greeting_tables) + + if (length(table_names) > 0) { + multi <- length(table_names) > 1 + schema_sections <- character() + for (name in table_names) { + source <- data_sources[[name]] + schema <- source$get_schema(categorical_threshold = categorical_threshold) + section <- if (multi) paste0("Table '", name, "':\n", schema) else schema + schema_sections <- c(schema_sections, section) + } + schema_block <- paste(schema_sections, collapse = "\n\n") + body <- paste0( + "\n", + schema_block, + "\n\n\n", + GREETING_BASE_TEXT + ) + } else { + body <- paste(GREETING_BASE_TEXT, GREETING_EXPLORE_ADDENDUM) + } + + paste0(GREETING_MARKER, body) +} + +resolve_greeting_tables <- function(data_sources, greeting_tables) { + all_names <- names(data_sources) + if (isTRUE(greeting_tables)) { + return(all_names) + } + if (identical(greeting_tables, FALSE)) { + return(character()) + } + if (is.character(greeting_tables)) { + return(intersect(greeting_tables, all_names)) + } + # Auto (NULL): single table -> include; multi-table -> none + if (length(all_names) == 1L) all_names else character() +} + # A list of records (named lists) bookmarked to the URL comes back from Shiny's # decoder as a data.frame, because jsonlite simplifies a JSON array of objects # (simplifyDataFrame = TRUE). Rebuild the list-of-lists shape row by row, @@ -339,13 +403,15 @@ restore_record_list <- function(x) { return(NULL) } if (is.data.frame(x)) { - return(lapply(seq_len(nrow(x)), function(i) { - row <- as.list(x[i, , drop = FALSE]) - row <- lapply(row, function(v) { - if (length(v) == 1 && is.na(v)) NULL else v + return( + lapply(seq_len(nrow(x)), function(i) { + row <- as.list(x[i, , drop = FALSE]) + row <- lapply(row, function(v) { + if (length(v) == 1 && is.na(v)) NULL else v + }) + row[!vapply(row, is.null, logical(1))] }) - row[!vapply(row, is.null, logical(1))] - })) + ) } as.list(x) } diff --git a/pkg-r/man/QueryChat.Rd b/pkg-r/man/QueryChat.Rd index c6f759de..9a61ea7f 100644 --- a/pkg-r/man/QueryChat.Rd +++ b/pkg-r/man/QueryChat.Rd @@ -98,6 +98,8 @@ qc <- QueryChat$new(con, "mtcars") \describe{ \item{\code{greeting}}{The greeting message displayed to users.} + \item{\code{greeting_tables}}{Controls which tables' schema to include in greeting.} + \item{\code{id}}{ID for the QueryChat instance.} \item{\code{id_override}}{Whether the ID was explicitly set by the user.} @@ -148,6 +150,7 @@ qc <- QueryChat$new(con, "mtcars") ..., id = NULL, greeting = NULL, + greeting_tables = NULL, client = NULL, tools = c("filter", "query"), data_description = NULL, @@ -180,6 +183,7 @@ character string (in Markdown format) or a file path. If not provided, a greeting will be generated at the start of each conversation using the LLM, which adds latency and cost. Use \verb{$generate_greeting()} to create a greeting to save and reuse.} + \item{\code{greeting_tables}}{Controls which tables' schema to include in greeting.} \item{\code{client}}{Optional chat client. Can be: \itemize{ \item An \link[ellmer:Chat]{ellmer::Chat} object diff --git a/pkg-r/tests/testthat/test-QueryChat.R b/pkg-r/tests/testthat/test-QueryChat.R index 99313737..fdc6578d 100644 --- a/pkg-r/tests/testthat/test-QueryChat.R +++ b/pkg-r/tests/testthat/test-QueryChat.R @@ -657,10 +657,12 @@ describe("QueryChat$client()", { }) test_that("QueryChat$generate_greeting() generates a greeting using the LLM client", { + skip_if_no_dataframe_engine() client <- mock_ellmer_chat_client( public = list( chat = function(message, ...) { - expect_equal(message, GREETING_PROMPT) + expect_true(startsWith(message, querychat:::GREETING_MARKER)) + expect_match(message, "") # single table → schema included "Welcome! This is a mock response for testing." } ) @@ -668,7 +670,6 @@ test_that("QueryChat$generate_greeting() generates a greeting using the LLM clie test_df <- new_test_df() - # Create a mock client that returns a fixed greeting qc <- QueryChat$new(test_df, client = client) withr::defer(qc$cleanup()) @@ -676,6 +677,72 @@ test_that("QueryChat$generate_greeting() generates a greeting using the LLM clie expect_equal(greeting, "Welcome! This is a mock response for testing.") }) +test_that("generate_greeting() sends schema-embedded prompt for single table", { + skip_if_no_dataframe_engine() + client <- mock_ellmer_chat_client( + public = list( + chat = function(message, ...) { + expect_true(startsWith(message, querychat:::GREETING_MARKER)) + expect_match(message, "") + "Hello!" + } + ) + ) + test_df <- new_test_df() + qc <- QueryChat$new(test_df, client = client) + withr::defer(qc$cleanup()) + + greeting <- qc$generate_greeting() + expect_equal(greeting, "Hello!") +}) + +test_that("generate_greeting() omits schema for multi-table with no greeting_tables", { + skip_if_no_dataframe_engine() + client <- mock_ellmer_chat_client( + public = list( + chat = function(message, ...) { + expect_true(startsWith(message, querychat:::GREETING_MARKER)) + expect_false(grepl("", message)) + "Hello!" + } + ) + ) + qc <- QueryChat$new( + data_source = NULL, + table_name = "placeholder", + client = client + ) + withr::defer(qc$cleanup()) + qc$add_table(new_test_df(), "t1") + qc$add_table(data.frame(x = 1), "t2") + + qc$generate_greeting() +}) + +test_that("generate_greeting() includes schema for tables in greeting_tables", { + skip_if_no_dataframe_engine() + client <- mock_ellmer_chat_client( + public = list( + chat = function(message, ...) { + expect_match(message, "amount") + expect_false(grepl("\\bname\\b", message)) + "Hello!" + } + ) + ) + qc <- QueryChat$new( + data_source = NULL, + table_name = "placeholder", + client = client, + greeting_tables = "t1" + ) + withr::defer(qc$cleanup()) + qc$add_table(data.frame(amount = c(1, 2)), "t1") + qc$add_table(data.frame(name = c("A")), "t2") + + qc$generate_greeting() +}) + test_that("QueryChat$server() errors when called outside Shiny context", { withr::local_envvar(OPENAI_API_KEY = "boop") diff --git a/pkg-r/tests/testthat/test-querychat_module.R b/pkg-r/tests/testthat/test-querychat_module.R index 040db90f..157944ca 100644 --- a/pkg-r/tests/testthat/test-querychat_module.R +++ b/pkg-r/tests/testthat/test-querychat_module.R @@ -348,3 +348,41 @@ test_that("restored viz widgets survive a second bookmark cycle", { } ) }) + +test_that("build_greeting_prompt includes schema for single table", { + skip_if_no_dataframe_engine() + source <- local_data_frame_source(new_test_df(), "people") + data_sources <- list(people = source) + + prompt <- querychat:::build_greeting_prompt(data_sources, 20, NULL) + + expect_true(startsWith(prompt, querychat:::GREETING_MARKER)) + expect_match(prompt, "") +}) + +test_that("build_greeting_prompt omits schema for multi-table with no signal", { + skip_if_no_dataframe_engine() + s1 <- local_data_frame_source(new_test_df(), "t1") + s2 <- local_data_frame_source(new_test_df(), "t2") + data_sources <- list(t1 = s1, t2 = s2) + + prompt <- querychat:::build_greeting_prompt(data_sources, 20, NULL) + + expect_true(startsWith(prompt, querychat:::GREETING_MARKER)) + expect_false(grepl("", prompt)) + # Should mention exploration + expect_match(prompt, "available|explore", ignore.case = TRUE) +}) + +test_that("build_greeting_prompt includes schema for explicit greeting_tables", { + skip_if_no_dataframe_engine() + s1 <- local_data_frame_source(data.frame(amount = c(10, 20)), "orders") + s2 <- local_data_frame_source(data.frame(name = c("A")), "customers") + data_sources <- list(orders = s1, customers = s2) + + prompt <- querychat:::build_greeting_prompt(data_sources, 20, "orders") + + expect_match(prompt, "") + expect_match(prompt, "amount") + expect_false(grepl("name", prompt)) +})