Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pkg-py/src/querychat/_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from ._querychat_base import TOOL_GROUPS, StateDictQueryChat
from ._querychat_core import (
GREETING_PROMPT,
AppStateDict,
build_greeting_prompt,
create_app_state,
stream_response,
)
Expand Down Expand Up @@ -98,6 +98,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -114,6 +115,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -130,6 +132,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -146,6 +149,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -162,6 +166,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

def __init__(
Expand All @@ -177,6 +182,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
):
super().__init__(
data_source,
Expand All @@ -189,6 +195,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
greeting_tables=greeting_tables,
)

@property
Expand Down Expand Up @@ -288,7 +295,12 @@ def initialize_greeting(state_dict: AppStateDict):

if not state.initialize_greeting_if_preset():
greeting = ""
for chunk in stream_response(state.client, GREETING_PROMPT):
greeting_prompt = build_greeting_prompt(
state.data_sources,
self._categorical_threshold,
self.greeting_tables,
)
for chunk in stream_response(state.client, greeting_prompt):
greeting += chunk
state.set_greeting(greeting)

Expand Down
11 changes: 9 additions & 2 deletions pkg-py/src/querychat/_querychat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
validate_source_group_compatibility,
)
from ._querychat_core import (
GREETING_PROMPT,
AppState,
AppStateDict,
build_greeting_prompt,
create_app_state,
warn_multi_table_flat_accessor,
)
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
):
self._data_dicts: list[DataDict] = _normalize_data_dicts(data_dict)

Expand All @@ -113,6 +114,7 @@ def __init__(
self._client_spec: str | chatlas.Chat | None = client
self._client_console = None

self.greeting_tables: list[str] | bool | None = greeting_tables
self._system_prompt: QueryChatSystemPrompt | None = None

if data_source is not None:
Expand Down Expand Up @@ -322,7 +324,12 @@ def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str:
chat = create_client(self._client_spec)
if self._system_prompt is not None:
chat.system_prompt = self._system_prompt.render(self.tools)
return str(chat.chat(GREETING_PROMPT, echo=echo))
prompt = build_greeting_prompt(
self._data_sources,
self._categorical_threshold,
self.greeting_tables,
)
return str(chat.chat(prompt, echo=echo))

def console(
self,
Expand Down
79 changes: 75 additions & 4 deletions pkg-py/src/querychat/_querychat_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
from __future__ import annotations

__all__ = [
"GREETING_PROMPT",
"GREETING_MARKER",
"GREETING_PROMPT", # backward compat alias — value unchanged; callers relying on the full prompt should use build_greeting_prompt() instead
"AppState",
"AppStateDict",
"ClientFactory",
"build_greeting_prompt",
"create_app_state",
"stream_response",
"stream_response_async",
]

import warnings
from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, TypedDict, Union

Expand All @@ -23,11 +25,23 @@

from .tools import UpdateDashboardData

GREETING_PROMPT: str = (
GREETING_MARKER: str = "<!-- querychat:greeting -->\n"
"""Sentinel prepended to every greeting prompt turn; used to skip it in display."""

GREETING_BASE_TEXT: str = (
"Please give me a friendly greeting. "
"Include a few sample suggestions grouped under ##### headings, "
"using the suggestion card format from your instructions."
)

GREETING_EXPLORE_ADDENDUM: str = (
" Include at least one suggestion encouraging the user to explore what "
"data and questions are available — for example, asking which tables or "
"columns exist, or what kinds of analysis are possible."
)

# Keep the old name as an alias so any external code that imported it still works.
GREETING_PROMPT: str = GREETING_BASE_TEXT
"""Prompt used to generate the initial greeting message."""

if TYPE_CHECKING:
Expand All @@ -46,6 +60,63 @@
"""Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks."""


def build_greeting_prompt(
data_sources: Mapping[str, DataSource],
categorical_threshold: int,
greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag
) -> str:
"""
Build a greeting prompt, optionally embedding table schema.

Parameters
----------
data_sources
All registered data sources keyed by table name.
categorical_threshold
Passed to ``get_schema`` for each included table.
greeting_tables
Which tables to include schema for:
- ``None``: auto — single table gets full schema; multi-table gets none.
- ``True``: all tables.
- ``False``: no tables (explorer hint only).
- list of names: only the named tables.

"""
table_names = resolve_greeting_tables(data_sources, greeting_tables)

if table_names:
schema_sections: list[str] = []
multi = len(table_names) > 1
for name in table_names:
source = data_sources[name]
schema = source.get_schema(categorical_threshold=categorical_threshold)
section = f"Table '{name}':\n{schema}" if multi else schema
schema_sections.append(section)
schema_block = "\n\n".join(schema_sections)
body = f"<schema>\n{schema_block}\n</schema>\n\n{GREETING_BASE_TEXT}"
else:
body = GREETING_BASE_TEXT + GREETING_EXPLORE_ADDENDUM

return f"{GREETING_MARKER}{body}"


def resolve_greeting_tables(
data_sources: Mapping[str, DataSource],
greeting_tables: list[str] | bool | None, # noqa: FBT001 — bool is a sentinel, not a flag
) -> list[str]:
"""Return the list of table names whose schema to include in the greeting."""
if greeting_tables is True:
return list(data_sources.keys())
if greeting_tables is False:
return []
if isinstance(greeting_tables, list):
return [t for t in greeting_tables if t in data_sources]
# Auto (None): single table → include schema; multi-table → no schema
if len(data_sources) == 1:
return list(data_sources.keys())
return []


def warn_multi_table_flat_accessor(
accessor_name: str, primary_table: str, table_list: str, stacklevel: int = 3
) -> None:
Expand Down Expand Up @@ -246,7 +317,7 @@ def get_display_messages(self) -> list[DisplayMessage]:
if text_parts:
text = "\n\n".join(text_parts)
# Skip the greeting prompt - it's an internal message
if turn.role == "user" and text == GREETING_PROMPT:
if turn.role == "user" and text.startswith(GREETING_MARKER):
continue
messages.append({"role": turn.role, "content": text})

Expand Down
20 changes: 20 additions & 0 deletions pkg-py/src/querychat/_shiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -177,6 +178,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -194,6 +196,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -211,6 +214,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

@overload
Expand All @@ -228,6 +232,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
) -> None: ...

def __init__(
Expand All @@ -244,6 +249,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
):
super().__init__(
data_source,
Expand All @@ -256,6 +262,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
greeting_tables=greeting_tables,
)
self.id = id or (f"querychat_{table_name}" if table_name else "querychat")

Expand Down Expand Up @@ -330,6 +337,8 @@ def app_server(input: Inputs, output: Outputs, session: Session):
client=self._create_session_client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
greeting_tables=self.greeting_tables,
categorical_threshold=self._categorical_threshold,
)

@reactive.calc
Expand Down Expand Up @@ -541,6 +550,8 @@ def create_session_client(**kwargs) -> chatlas.Chat:
client=create_session_client,
enable_bookmarking=enable_bookmarking,
tools=self.tools,
greeting_tables=self.greeting_tables,
categorical_threshold=self._categorical_threshold,
)


Expand Down Expand Up @@ -663,6 +674,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...

Expand All @@ -681,6 +693,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...

Expand All @@ -699,6 +712,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...

Expand All @@ -717,6 +731,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...

Expand All @@ -735,6 +750,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
) -> None: ...

Expand All @@ -752,6 +768,7 @@ def __init__(
prompt_template: Optional[str | Path] = None,
categorical_threshold: int = 20,
data_description: Optional[str | Path] = None,
greeting_tables: list[str] | bool | None = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
):
# Sanity check: Express should always have a (stub/real) session
Expand All @@ -773,6 +790,7 @@ def __init__(
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
prompt_template=prompt_template,
greeting_tables=greeting_tables,
)
self.id = id or (f"querychat_{table_name}" if table_name else "querychat")

Expand Down Expand Up @@ -821,6 +839,8 @@ def _ensure_server_started(self) -> None:
client=self._create_session_client,
enable_bookmarking=self._enable_bookmarking,
tools=self.tools,
greeting_tables=self.greeting_tables,
categorical_threshold=self._categorical_threshold,
)

def sidebar(
Expand Down
Loading
Loading