diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py
index 18a87cb8..e6434b63 100644
--- a/pkg-py/src/querychat/_gradio.py
+++ b/pkg-py/src/querychat/_gradio.py
@@ -12,8 +12,8 @@
from ._querychat_base import TOOL_GROUPS, StateDictQueryChat
from ._querychat_core import (
- GREETING_PROMPT,
AppStateDict,
+ build_greeting_prompt,
create_app_state,
stream_response,
)
@@ -98,6 +98,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -114,6 +115,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -130,6 +132,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -146,6 +149,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -162,6 +166,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
def __init__(
@@ -177,6 +182,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
):
super().__init__(
data_source,
@@ -189,6 +195,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
+ greeting_tables=greeting_tables,
)
@property
@@ -288,7 +295,12 @@ def initialize_greeting(state_dict: AppStateDict):
if not state.initialize_greeting_if_preset():
greeting = ""
- for chunk in stream_response(state.client, GREETING_PROMPT):
+ greeting_prompt = build_greeting_prompt(
+ state.data_sources,
+ self._categorical_threshold,
+ self.greeting_tables,
+ )
+ for chunk in stream_response(state.client, greeting_prompt):
greeting += chunk
state.set_greeting(greeting)
diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py
index d3b4c612..0cccfa5b 100644
--- a/pkg-py/src/querychat/_querychat_base.py
+++ b/pkg-py/src/querychat/_querychat_base.py
@@ -32,9 +32,9 @@
validate_source_group_compatibility,
)
from ._querychat_core import (
- GREETING_PROMPT,
AppState,
AppStateDict,
+ build_greeting_prompt,
create_app_state,
warn_multi_table_flat_accessor,
)
@@ -91,6 +91,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
):
self._data_dicts: list[DataDict] = _normalize_data_dicts(data_dict)
@@ -113,6 +114,7 @@ def __init__(
self._client_spec: str | chatlas.Chat | None = client
self._client_console = None
+ self.greeting_tables: list[str] | bool | None = greeting_tables
self._system_prompt: QueryChatSystemPrompt | None = None
if data_source is not None:
@@ -322,7 +324,12 @@ def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str:
chat = create_client(self._client_spec)
if self._system_prompt is not None:
chat.system_prompt = self._system_prompt.render(self.tools)
- return str(chat.chat(GREETING_PROMPT, echo=echo))
+ prompt = build_greeting_prompt(
+ self._data_sources,
+ self._categorical_threshold,
+ self.greeting_tables,
+ )
+ return str(chat.chat(prompt, echo=echo))
def console(
self,
diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py
index 87d110ab..650b0c57 100644
--- a/pkg-py/src/querychat/_querychat_core.py
+++ b/pkg-py/src/querychat/_querychat_core.py
@@ -3,17 +3,19 @@
from __future__ import annotations
__all__ = [
- "GREETING_PROMPT",
+ "GREETING_MARKER",
+ "GREETING_PROMPT", # backward compat alias — value unchanged; callers relying on the full prompt should use build_greeting_prompt() instead
"AppState",
"AppStateDict",
"ClientFactory",
+ "build_greeting_prompt",
"create_app_state",
"stream_response",
"stream_response_async",
]
import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, TypedDict, Union
@@ -23,11 +25,23 @@
from .tools import UpdateDashboardData
-GREETING_PROMPT: str = (
+GREETING_MARKER: str = "\n"
+"""Sentinel prepended to every greeting prompt turn; used to skip it in display."""
+
+GREETING_BASE_TEXT: str = (
"Please give me a friendly greeting. "
"Include a few sample suggestions grouped under ##### headings, "
"using the suggestion card format from your instructions."
)
+
+GREETING_EXPLORE_ADDENDUM: str = (
+ " Include at least one suggestion encouraging the user to explore what "
+ "data and questions are available — for example, asking which tables or "
+ "columns exist, or what kinds of analysis are possible."
+)
+
+# Keep the old name as an alias so any external code that imported it still works.
+GREETING_PROMPT: str = GREETING_BASE_TEXT
"""Prompt used to generate the initial greeting message."""
if TYPE_CHECKING:
@@ -46,6 +60,63 @@
"""Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks."""
+def build_greeting_prompt(
+ data_sources: Mapping[str, DataSource],
+ categorical_threshold: int,
+ greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag
+) -> str:
+ """
+ Build a greeting prompt, optionally embedding table schema.
+
+ Parameters
+ ----------
+ data_sources
+ All registered data sources keyed by table name.
+ categorical_threshold
+ Passed to ``get_schema`` for each included table.
+ greeting_tables
+ Which tables to include schema for:
+ - ``None``: auto — single table gets full schema; multi-table gets none.
+ - ``True``: all tables.
+ - ``False``: no tables (explorer hint only).
+ - list of names: only the named tables.
+
+ """
+ table_names = resolve_greeting_tables(data_sources, greeting_tables)
+
+ if table_names:
+ schema_sections: list[str] = []
+ multi = len(table_names) > 1
+ for name in table_names:
+ source = data_sources[name]
+ schema = source.get_schema(categorical_threshold=categorical_threshold)
+ section = f"Table '{name}':\n{schema}" if multi else schema
+ schema_sections.append(section)
+ schema_block = "\n\n".join(schema_sections)
+ body = f"\n{schema_block}\n\n\n{GREETING_BASE_TEXT}"
+ else:
+ body = GREETING_BASE_TEXT + GREETING_EXPLORE_ADDENDUM
+
+ return f"{GREETING_MARKER}{body}"
+
+
+def resolve_greeting_tables(
+ data_sources: Mapping[str, DataSource],
+ greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag
+) -> list[str]:
+ """Return the list of table names whose schema to include in the greeting."""
+ if greeting_tables is True:
+ return list(data_sources.keys())
+ if greeting_tables is False:
+ return []
+ if isinstance(greeting_tables, list):
+ return [t for t in greeting_tables if t in data_sources]
+ # Auto (None): single table → include schema; multi-table → no schema
+ if len(data_sources) == 1:
+ return list(data_sources.keys())
+ return []
+
+
def warn_multi_table_flat_accessor(
accessor_name: str, primary_table: str, table_list: str, stacklevel: int = 3
) -> None:
@@ -246,7 +317,7 @@ def get_display_messages(self) -> list[DisplayMessage]:
if text_parts:
text = "\n\n".join(text_parts)
# Skip the greeting prompt - it's an internal message
- if turn.role == "user" and text == GREETING_PROMPT:
+ if turn.role == "user" and text.startswith(GREETING_MARKER):
continue
messages.append({"role": turn.role, "content": text})
diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py
index f616671f..b342642c 100644
--- a/pkg-py/src/querychat/_shiny.py
+++ b/pkg-py/src/querychat/_shiny.py
@@ -160,6 +160,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -177,6 +178,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -194,6 +196,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -211,6 +214,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -228,6 +232,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
def __init__(
@@ -244,6 +249,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
):
super().__init__(
data_source,
@@ -256,6 +262,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
+ greeting_tables=greeting_tables,
)
self.id = id or (f"querychat_{table_name}" if table_name else "querychat")
@@ -330,6 +337,8 @@ def app_server(input: Inputs, output: Outputs, session: Session):
client=self._create_session_client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
+ greeting_tables=self.greeting_tables,
+ categorical_threshold=self._categorical_threshold,
)
@reactive.calc
@@ -541,6 +550,8 @@ def create_session_client(**kwargs) -> chatlas.Chat:
client=create_session_client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
+ greeting_tables=self.greeting_tables,
+ categorical_threshold=self._categorical_threshold,
)
@@ -663,6 +674,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...
@@ -681,6 +693,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...
@@ -699,6 +712,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...
@@ -717,6 +731,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...
@@ -735,6 +750,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...
@@ -752,6 +768,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
):
# Sanity check: Express should always have a (stub/real) session
@@ -773,6 +790,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
+ greeting_tables=greeting_tables,
)
self.id = id or (f"querychat_{table_name}" if table_name else "querychat")
@@ -821,6 +839,8 @@ def _ensure_server_started(self) -> None:
client=self._create_session_client,
enable_bookmarking=self._enable_bookmarking,
tools=self.tools,
+ greeting_tables=self.greeting_tables,
+ categorical_threshold=self._categorical_threshold,
)
def sidebar(
diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py
index f0b3e8d4..138ee7e3 100644
--- a/pkg-py/src/querychat/_shiny_module.py
+++ b/pkg-py/src/querychat/_shiny_module.py
@@ -12,7 +12,7 @@
from shiny import module, reactive, ui
-from ._querychat_core import GREETING_PROMPT, warn_multi_table_flat_accessor
+from ._querychat_core import build_greeting_prompt, warn_multi_table_flat_accessor
from ._table_accessor import TableAccessor
from ._viz_altair_widget import AltairWidget
from ._viz_ggsql import execute_ggsql
@@ -215,6 +215,8 @@ def mod_server(
client: Callable[..., chatlas.Chat],
enable_bookmarking: bool,
tools: set[str] | None = None,
+ greeting_tables: list[str] | bool | None = None,
+ categorical_threshold: int = 20,
) -> ServerValues[IntoFrameT]:
# Holds a generated greeting so it can be saved and restored on bookmark.
# Static greetings live in the UI (chat_ui(greeting=)) and persist already.
@@ -347,8 +349,13 @@ async def _handle_greeting_requested():
GreetWarning,
stacklevel=2,
)
+ greeting_prompt = build_greeting_prompt(
+ data_sources,
+ categorical_threshold,
+ greeting_tables,
+ )
greeting_client = client(tools=None)
- stream = await greeting_client.stream_async(GREETING_PROMPT, echo="none")
+ stream = await greeting_client.stream_async(greeting_prompt, echo="none")
await chat_ui.set_greeting(
shinychat.chat_greeting(stream, dismissible=False)
)
diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py
index b8093b83..b1cd6244 100644
--- a/pkg-py/src/querychat/_streamlit.py
+++ b/pkg-py/src/querychat/_streamlit.py
@@ -8,8 +8,8 @@
from ._querychat_base import TOOL_GROUPS, QueryChatBase
from ._querychat_core import (
- GREETING_PROMPT,
AppState,
+ build_greeting_prompt,
create_app_state,
stream_response,
)
@@ -117,6 +117,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -133,6 +134,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -149,6 +151,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -165,6 +168,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
@overload
@@ -181,6 +185,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
) -> None: ...
def __init__(
@@ -196,6 +201,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
+ greeting_tables: list[str] | bool | None = None,
):
super().__init__(
data_source,
@@ -208,6 +214,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
+ greeting_tables=greeting_tables,
)
self._state_key = f"_querychat_{table_name}" if table_name else "_querychat"
@@ -291,7 +298,12 @@ def ui(self) -> None:
with st.chat_message("assistant"):
placeholder = st.empty()
placeholder.markdown("*Preparing your data assistant...*")
- for chunk in stream_response(state.client, GREETING_PROMPT):
+ greeting_prompt = build_greeting_prompt(
+ state.data_sources,
+ self._categorical_threshold,
+ self.greeting_tables,
+ )
+ for chunk in stream_response(state.client, greeting_prompt):
greeting += chunk
placeholder.markdown(greeting, unsafe_allow_html=True)
state.set_greeting(greeting)
diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py
index 3293e356..6b4f2e1e 100644
--- a/pkg-py/tests/test_querychat.py
+++ b/pkg-py/tests/test_querychat.py
@@ -7,6 +7,7 @@
import pytest
from querychat import QueryChat
from querychat._datasource import IbisSource, PolarsLazySource
+from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt
@pytest.fixture(autouse=True)
@@ -194,3 +195,155 @@ def test_querychat_with_ibis_table():
assert executed["name"].iloc[0] == "Bob"
finally:
conn.disconnect()
+
+
+def test_build_greeting_prompt_single_table_includes_schema(sample_df):
+ """Single table with greeting_tables=None loads schema automatically."""
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+
+ source = DataFrameSource(nw.from_native(sample_df), "people")
+ prompt = build_greeting_prompt(
+ data_sources={"people": source},
+ categorical_threshold=20,
+ greeting_tables=None,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" in prompt
+ assert "people" in prompt # schema content references table
+
+
+def test_build_greeting_prompt_multi_table_no_signal_omits_schema():
+ """Multi-table with greeting_tables=None omits schema, adds explorer hint."""
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+
+ sources = {
+ "orders": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "orders"),
+ "customers": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "customers"),
+ }
+ prompt = build_greeting_prompt(
+ data_sources=sources,
+ categorical_threshold=20,
+ greeting_tables=None,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" not in prompt
+ # Should encourage exploration of available data
+ assert "available" in prompt.lower() or "explore" in prompt.lower()
+
+
+def test_build_greeting_prompt_explicit_tables_includes_only_those():
+ """greeting_tables list loads schema for specified tables only."""
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+
+ sources = {
+ "orders": DataFrameSource(nw.from_native(pd.DataFrame({"amount": [10.0]})), "orders"),
+ "customers": DataFrameSource(nw.from_native(pd.DataFrame({"name": ["Alice"]})), "customers"),
+ }
+ prompt = build_greeting_prompt(
+ data_sources=sources,
+ categorical_threshold=20,
+ greeting_tables=["orders"],
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" in prompt
+ assert "orders" in prompt
+ # customers schema should not appear (only the orders schema section)
+ # The prompt should contain "orders" but we check for column content distinction
+ assert "amount" in prompt # orders column
+ assert "name" not in prompt # customers column excluded
+
+
+def test_build_greeting_prompt_true_includes_all_tables():
+ """greeting_tables=True loads schema for all tables."""
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+
+ sources = {
+ "orders": DataFrameSource(nw.from_native(pd.DataFrame({"amount": [10.0]})), "orders"),
+ "customers": DataFrameSource(nw.from_native(pd.DataFrame({"name": ["Alice"]})), "customers"),
+ }
+ prompt = build_greeting_prompt(
+ data_sources=sources,
+ categorical_threshold=20,
+ greeting_tables=True,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" in prompt
+ assert "amount" in prompt # orders column
+ assert "name" in prompt # customers column
+
+
+def test_build_greeting_prompt_false_omits_schema_adds_explorer():
+ """greeting_tables=False skips schema entirely and adds explorer hint."""
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+
+ source = DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "t")
+ prompt = build_greeting_prompt(
+ data_sources={"t": source},
+ categorical_threshold=20,
+ greeting_tables=False,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" not in prompt
+ assert "available" in prompt.lower() or "explore" in prompt.lower()
+
+
+def test_generate_greeting_embeds_schema_for_single_table(sample_df):
+ """generate_greeting() sends schema-embedded prompt for a single table."""
+ qc = QueryChat(data_source=sample_df, table_name="people")
+ seen: dict[str, str] = {}
+
+ def fake_chat(self, prompt, *args, **kwargs):
+ seen["prompt"] = prompt
+ seen["system_prompt"] = self.system_prompt or ""
+ return "Hello!"
+
+ with patch("chatlas.Chat.chat", fake_chat):
+ result = qc.generate_greeting()
+
+ assert result == "Hello!"
+ assert seen["prompt"].startswith(GREETING_MARKER)
+ assert "" in seen["prompt"]
+
+
+def test_generate_greeting_omits_schema_for_multi_table_default():
+ """generate_greeting() with no greeting_tables omits schema for multi-table."""
+ qc = QueryChat()
+ qc.add_table(pd.DataFrame({"a": [1]}), "t1")
+ qc.add_table(pd.DataFrame({"b": [2]}), "t2")
+ seen: dict[str, str] = {}
+
+ def fake_chat(self, prompt, *args, **kwargs):
+ seen["prompt"] = prompt
+ return "Hello!"
+
+ with patch("chatlas.Chat.chat", fake_chat):
+ qc.generate_greeting()
+
+ assert seen["prompt"].startswith(GREETING_MARKER)
+ assert "" not in seen["prompt"]
+
+
+def test_generate_greeting_respects_greeting_tables_param(sample_df):
+ """generate_greeting() includes schema for tables named in greeting_tables."""
+ qc = QueryChat(greeting_tables=["people"])
+ qc.add_table(sample_df, "people")
+ qc.add_table(pd.DataFrame({"x": [1]}), "other")
+ seen: dict[str, str] = {}
+
+ def fake_chat(self, prompt, *args, **kwargs):
+ seen["prompt"] = prompt
+ return "Hi!"
+
+ with patch("chatlas.Chat.chat", fake_chat):
+ qc.generate_greeting()
+
+ assert seen["prompt"].startswith(GREETING_MARKER)
+ assert "" in seen["prompt"]
+ # only people schema present: people has 'name', 'age'; other has 'x'
+ assert "age" in seen["prompt"]
+ assert "x" not in seen["prompt"]
diff --git a/pkg-py/tests/test_shiny_module.py b/pkg-py/tests/test_shiny_module.py
index ff3abc78..1190ffb8 100644
--- a/pkg-py/tests/test_shiny_module.py
+++ b/pkg-py/tests/test_shiny_module.py
@@ -2,9 +2,11 @@
from __future__ import annotations
+import inspect
import os
from unittest.mock import patch
+import pandas as pd
import pytest
from shiny import ui
@@ -21,6 +23,11 @@ def set_dummy_api_key():
del os.environ["OPENAI_API_KEY"]
+@pytest.fixture
+def sample_df():
+ return pd.DataFrame({"id": [1, 2, 3], "value": [10, 20, 30]})
+
+
def _fake_chat_ui(*args, **kwargs):
"""Return a real Tag so htmltools accepts it; stash kwargs for inspection."""
_fake_chat_ui.last_kwargs = kwargs
@@ -48,3 +55,92 @@ def test_mod_ui_allow_attachments_can_be_overridden():
assert _fake_chat_ui.last_kwargs.get("allow_attachments") is False
+def _unwrap_module_server(wrapped):
+ """
+ Retrieve the original function from a @module.server-decorated callable.
+
+ Shiny's @module.server decorator stores the original function as a cell in
+ the wrapper's __closure__. We locate it by scanning the closure for callables
+ whose signature includes the expected module server parameters.
+ """
+ if wrapped.__closure__ is None:
+ return wrapped
+ for cell in wrapped.__closure__:
+ try:
+ val = cell.cell_contents
+ except ValueError:
+ continue
+ if callable(val) and "data_sources" in inspect.signature(val).parameters:
+ return val
+ return wrapped
+
+
+def test_mod_server_accepts_greeting_tables_and_categorical_threshold():
+ """mod_server must expose greeting_tables and categorical_threshold params."""
+ from querychat._shiny_module import mod_server
+
+ fn = _unwrap_module_server(mod_server)
+ params = inspect.signature(fn).parameters
+ assert "greeting_tables" in params, "mod_server missing greeting_tables param"
+ assert "categorical_threshold" in params, "mod_server missing categorical_threshold param"
+ assert params["greeting_tables"].default is None
+ assert params["categorical_threshold"].default == 20
+
+
+def test_build_greeting_prompt_imported_into_shiny_module(sample_df):
+ """
+ Checks that build_greeting_prompt is imported into _shiny_module (so the
+ module uses the shared implementation) and that calling it directly with a
+ single-table source produces a schema-embedded prompt.
+
+ Note: triggering the actual greeting reactive event requires a live Shiny
+ session, which is not available in unit tests.
+ """
+ import narwhals.stable.v1 as nw
+ import querychat._shiny_module as shiny_module
+ from querychat._datasource import DataFrameSource
+ from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt
+
+ # Verify build_greeting_prompt is accessible in _shiny_module
+ assert hasattr(shiny_module, "build_greeting_prompt"), (
+ "build_greeting_prompt must be imported into _shiny_module"
+ )
+ assert shiny_module.build_greeting_prompt is build_greeting_prompt
+
+ # Verify single-table auto-mode produces a schema-containing prompt
+ source = DataFrameSource(nw.from_native(sample_df), "sample")
+ prompt = build_greeting_prompt(
+ data_sources={"sample": source},
+ categorical_threshold=20,
+ greeting_tables=None,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" in prompt
+
+
+def test_build_greeting_prompt_called_with_multi_table_no_greeting_tables():
+ """
+ With two tables and greeting_tables=None, build_greeting_prompt omits schema.
+
+ This mirrors the assertion from the brief: multi-table with no greeting_tables
+ → GREETING_MARKER prefix present, absent.
+ """
+ import narwhals.stable.v1 as nw
+ from querychat._datasource import DataFrameSource
+ from querychat._querychat_core import GREETING_MARKER, build_greeting_prompt
+
+ sources = {
+ "orders": DataFrameSource(nw.from_native(pd.DataFrame({"id": [1]})), "orders"),
+ "customers": DataFrameSource(
+ nw.from_native(pd.DataFrame({"id": [1]})), "customers"
+ ),
+ }
+ prompt = build_greeting_prompt(
+ data_sources=sources,
+ categorical_threshold=20,
+ greeting_tables=None,
+ )
+ assert prompt.startswith(GREETING_MARKER)
+ assert "" not in prompt
+
+
diff --git a/pkg-r/R/QueryChat.R b/pkg-r/R/QueryChat.R
index 576d9abc..efdff4b4 100644
--- a/pkg-r/R/QueryChat.R
+++ b/pkg-r/R/QueryChat.R
@@ -229,6 +229,8 @@ QueryChat <- R6::R6Class(
public = list(
#' @field greeting The greeting message displayed to users.
greeting = NULL,
+ #' @field greeting_tables Controls which tables' schema to include in greeting.
+ greeting_tables = NULL,
#' @field id ID for the QueryChat instance.
id = NULL,
#' @field id_override Whether the ID was explicitly set by the user.
@@ -295,6 +297,7 @@ QueryChat <- R6::R6Class(
...,
id = NULL,
greeting = NULL,
+ greeting_tables = NULL,
client = NULL,
tools = c("filter", "query"),
data_description = NULL,
@@ -342,6 +345,7 @@ QueryChat <- R6::R6Class(
greeting <- read_utf8(greeting)
}
self$greeting <- greeting
+ self$greeting_tables <- greeting_tables
# Track whether id was explicitly set
self$id_override <- id
@@ -900,6 +904,8 @@ QueryChat <- R6::R6Class(
greeting = self$greeting,
client = create_session_client,
tools = self$tools,
+ greeting_tables = self$greeting_tables,
+ categorical_threshold = private$.categorical_threshold,
enable_bookmarking = enable_bookmarking
)
result
@@ -914,7 +920,12 @@ QueryChat <- R6::R6Class(
generate_greeting = function(echo = c("none", "output")) {
private$require_initialized("$generate_greeting")
chat <- private$create_session_client()
- as.character(chat$chat(GREETING_PROMPT, echo = echo))
+ greeting_prompt <- build_greeting_prompt(
+ private$.data_sources,
+ private$.categorical_threshold,
+ self$greeting_tables
+ )
+ as.character(chat$chat(greeting_prompt, echo = echo))
},
#' @description
diff --git a/pkg-r/R/querychat_module.R b/pkg-r/R/querychat_module.R
index cb1c7ff8..9857b88a 100644
--- a/pkg-r/R/querychat_module.R
+++ b/pkg-r/R/querychat_module.R
@@ -43,6 +43,8 @@ mod_server <- function(
greeting,
client,
tools,
+ greeting_tables = NULL,
+ categorical_threshold = 20,
enable_bookmarking = FALSE
) {
shiny::moduleServer(id, function(input, output, session) {
@@ -150,13 +152,20 @@ mod_server <- function(
)
return()
}
- cli::cli_warn(c(
- "No {.arg greeting} provided to {.fn QueryChat}. Using the LLM {.arg client} to generate one now.",
- "i" = "For faster startup, lower cost, and determinism, consider providing a {.arg greeting} to {.fn QueryChat}.",
- "i" = "You can use your {.help querychat::QueryChat} object's {.fn $generate_greeting} method to generate a greeting."
- ))
+ cli::cli_warn(
+ c(
+ "No {.arg greeting} provided to {.fn QueryChat}. Using the LLM {.arg client} to generate one now.",
+ "i" = "For faster startup, lower cost, and determinism, consider providing a {.arg greeting} to {.fn QueryChat}.",
+ "i" = "You can use your {.help querychat::QueryChat} object's {.fn $generate_greeting} method to generate a greeting."
+ )
+ )
greeting_client <- client(tools = NULL)
- stream <- greeting_client$stream_async(GREETING_PROMPT)
+ greeting_prompt <- build_greeting_prompt(
+ data_sources,
+ categorical_threshold,
+ greeting_tables
+ )
+ stream <- greeting_client$stream_async(greeting_prompt)
p <- shinychat::chat_set_greeting(
"chat",
shinychat::chat_greeting(stream, dismissible = FALSE)
@@ -322,13 +331,68 @@ mod_server <- function(
})
}
-# TODO: Make this dependent on enabled tools
-GREETING_PROMPT <- paste(
+GREETING_MARKER <- "\n"
+
+GREETING_BASE_TEXT <- paste(
"Please give me a friendly greeting.",
"Include a few sample suggestions grouped under ##### headings,",
"using the suggestion card format from your instructions."
)
+GREETING_EXPLORE_ADDENDUM <- paste(
+ "Include at least one suggestion encouraging the user to explore what",
+ "data and questions are available - for example, asking which tables or",
+ "columns exist, or what kinds of analysis are possible."
+)
+
+# TODO: Make this dependent on enabled tools
+GREETING_PROMPT <- GREETING_BASE_TEXT
+
+build_greeting_prompt <- function(
+ data_sources,
+ categorical_threshold,
+ greeting_tables
+) {
+ table_names <- resolve_greeting_tables(data_sources, greeting_tables)
+
+ if (length(table_names) > 0) {
+ multi <- length(table_names) > 1
+ schema_sections <- character()
+ for (name in table_names) {
+ source <- data_sources[[name]]
+ schema <- source$get_schema(categorical_threshold = categorical_threshold)
+ section <- if (multi) paste0("Table '", name, "':\n", schema) else schema
+ schema_sections <- c(schema_sections, section)
+ }
+ schema_block <- paste(schema_sections, collapse = "\n\n")
+ body <- paste0(
+ "\n",
+ schema_block,
+ "\n\n\n",
+ GREETING_BASE_TEXT
+ )
+ } else {
+ body <- paste(GREETING_BASE_TEXT, GREETING_EXPLORE_ADDENDUM)
+ }
+
+ paste0(GREETING_MARKER, body)
+}
+
+resolve_greeting_tables <- function(data_sources, greeting_tables) {
+ all_names <- names(data_sources)
+ if (isTRUE(greeting_tables)) {
+ return(all_names)
+ }
+ if (identical(greeting_tables, FALSE)) {
+ return(character())
+ }
+ if (is.character(greeting_tables)) {
+ return(intersect(greeting_tables, all_names))
+ }
+ # Auto (NULL): single table -> include; multi-table -> none
+ if (length(all_names) == 1L) all_names else character()
+}
+
# A list of records (named lists) bookmarked to the URL comes back from Shiny's
# decoder as a data.frame, because jsonlite simplifies a JSON array of objects
# (simplifyDataFrame = TRUE). Rebuild the list-of-lists shape row by row,
@@ -339,13 +403,15 @@ restore_record_list <- function(x) {
return(NULL)
}
if (is.data.frame(x)) {
- return(lapply(seq_len(nrow(x)), function(i) {
- row <- as.list(x[i, , drop = FALSE])
- row <- lapply(row, function(v) {
- if (length(v) == 1 && is.na(v)) NULL else v
+ return(
+ lapply(seq_len(nrow(x)), function(i) {
+ row <- as.list(x[i, , drop = FALSE])
+ row <- lapply(row, function(v) {
+ if (length(v) == 1 && is.na(v)) NULL else v
+ })
+ row[!vapply(row, is.null, logical(1))]
})
- row[!vapply(row, is.null, logical(1))]
- }))
+ )
}
as.list(x)
}
diff --git a/pkg-r/man/QueryChat.Rd b/pkg-r/man/QueryChat.Rd
index c6f759de..9a61ea7f 100644
--- a/pkg-r/man/QueryChat.Rd
+++ b/pkg-r/man/QueryChat.Rd
@@ -98,6 +98,8 @@ qc <- QueryChat$new(con, "mtcars")
\describe{
\item{\code{greeting}}{The greeting message displayed to users.}
+ \item{\code{greeting_tables}}{Controls which tables' schema to include in greeting.}
+
\item{\code{id}}{ID for the QueryChat instance.}
\item{\code{id_override}}{Whether the ID was explicitly set by the user.}
@@ -148,6 +150,7 @@ qc <- QueryChat$new(con, "mtcars")
...,
id = NULL,
greeting = NULL,
+ greeting_tables = NULL,
client = NULL,
tools = c("filter", "query"),
data_description = NULL,
@@ -180,6 +183,7 @@ character string (in Markdown format) or a file path. If not provided,
a greeting will be generated at the start of each conversation using
the LLM, which adds latency and cost. Use \verb{$generate_greeting()} to
create a greeting to save and reuse.}
+ \item{\code{greeting_tables}}{Controls which tables' schema to include in greeting.}
\item{\code{client}}{Optional chat client. Can be:
\itemize{
\item An \link[ellmer:Chat]{ellmer::Chat} object
diff --git a/pkg-r/tests/testthat/test-QueryChat.R b/pkg-r/tests/testthat/test-QueryChat.R
index 99313737..fdc6578d 100644
--- a/pkg-r/tests/testthat/test-QueryChat.R
+++ b/pkg-r/tests/testthat/test-QueryChat.R
@@ -657,10 +657,12 @@ describe("QueryChat$client()", {
})
test_that("QueryChat$generate_greeting() generates a greeting using the LLM client", {
+ skip_if_no_dataframe_engine()
client <- mock_ellmer_chat_client(
public = list(
chat = function(message, ...) {
- expect_equal(message, GREETING_PROMPT)
+ expect_true(startsWith(message, querychat:::GREETING_MARKER))
+ expect_match(message, "") # single table → schema included
"Welcome! This is a mock response for testing."
}
)
@@ -668,7 +670,6 @@ test_that("QueryChat$generate_greeting() generates a greeting using the LLM clie
test_df <- new_test_df()
- # Create a mock client that returns a fixed greeting
qc <- QueryChat$new(test_df, client = client)
withr::defer(qc$cleanup())
@@ -676,6 +677,72 @@ test_that("QueryChat$generate_greeting() generates a greeting using the LLM clie
expect_equal(greeting, "Welcome! This is a mock response for testing.")
})
+test_that("generate_greeting() sends schema-embedded prompt for single table", {
+ skip_if_no_dataframe_engine()
+ client <- mock_ellmer_chat_client(
+ public = list(
+ chat = function(message, ...) {
+ expect_true(startsWith(message, querychat:::GREETING_MARKER))
+ expect_match(message, "")
+ "Hello!"
+ }
+ )
+ )
+ test_df <- new_test_df()
+ qc <- QueryChat$new(test_df, client = client)
+ withr::defer(qc$cleanup())
+
+ greeting <- qc$generate_greeting()
+ expect_equal(greeting, "Hello!")
+})
+
+test_that("generate_greeting() omits schema for multi-table with no greeting_tables", {
+ skip_if_no_dataframe_engine()
+ client <- mock_ellmer_chat_client(
+ public = list(
+ chat = function(message, ...) {
+ expect_true(startsWith(message, querychat:::GREETING_MARKER))
+ expect_false(grepl("", message))
+ "Hello!"
+ }
+ )
+ )
+ qc <- QueryChat$new(
+ data_source = NULL,
+ table_name = "placeholder",
+ client = client
+ )
+ withr::defer(qc$cleanup())
+ qc$add_table(new_test_df(), "t1")
+ qc$add_table(data.frame(x = 1), "t2")
+
+ qc$generate_greeting()
+})
+
+test_that("generate_greeting() includes schema for tables in greeting_tables", {
+ skip_if_no_dataframe_engine()
+ client <- mock_ellmer_chat_client(
+ public = list(
+ chat = function(message, ...) {
+ expect_match(message, "amount")
+ expect_false(grepl("\\bname\\b", message))
+ "Hello!"
+ }
+ )
+ )
+ qc <- QueryChat$new(
+ data_source = NULL,
+ table_name = "placeholder",
+ client = client,
+ greeting_tables = "t1"
+ )
+ withr::defer(qc$cleanup())
+ qc$add_table(data.frame(amount = c(1, 2)), "t1")
+ qc$add_table(data.frame(name = c("A")), "t2")
+
+ qc$generate_greeting()
+})
+
test_that("QueryChat$server() errors when called outside Shiny context", {
withr::local_envvar(OPENAI_API_KEY = "boop")
diff --git a/pkg-r/tests/testthat/test-querychat_module.R b/pkg-r/tests/testthat/test-querychat_module.R
index 040db90f..157944ca 100644
--- a/pkg-r/tests/testthat/test-querychat_module.R
+++ b/pkg-r/tests/testthat/test-querychat_module.R
@@ -348,3 +348,41 @@ test_that("restored viz widgets survive a second bookmark cycle", {
}
)
})
+
+test_that("build_greeting_prompt includes schema for single table", {
+ skip_if_no_dataframe_engine()
+ source <- local_data_frame_source(new_test_df(), "people")
+ data_sources <- list(people = source)
+
+ prompt <- querychat:::build_greeting_prompt(data_sources, 20, NULL)
+
+ expect_true(startsWith(prompt, querychat:::GREETING_MARKER))
+ expect_match(prompt, "")
+})
+
+test_that("build_greeting_prompt omits schema for multi-table with no signal", {
+ skip_if_no_dataframe_engine()
+ s1 <- local_data_frame_source(new_test_df(), "t1")
+ s2 <- local_data_frame_source(new_test_df(), "t2")
+ data_sources <- list(t1 = s1, t2 = s2)
+
+ prompt <- querychat:::build_greeting_prompt(data_sources, 20, NULL)
+
+ expect_true(startsWith(prompt, querychat:::GREETING_MARKER))
+ expect_false(grepl("", prompt))
+ # Should mention exploration
+ expect_match(prompt, "available|explore", ignore.case = TRUE)
+})
+
+test_that("build_greeting_prompt includes schema for explicit greeting_tables", {
+ skip_if_no_dataframe_engine()
+ s1 <- local_data_frame_source(data.frame(amount = c(10, 20)), "orders")
+ s2 <- local_data_frame_source(data.frame(name = c("A")), "customers")
+ data_sources <- list(orders = s1, customers = s2)
+
+ prompt <- querychat:::build_greeting_prompt(data_sources, 20, "orders")
+
+ expect_match(prompt, "")
+ expect_match(prompt, "amount")
+ expect_false(grepl("name", prompt))
+})