Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions sqlit/domains/query/completion/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
get_all_functions,
get_all_keywords,
get_current_word,
get_identifier_namespaces,
get_last_token_info,
get_names_for_namespace,
is_inside_string,
remove_comments,
remove_string_literals,
Expand Down Expand Up @@ -256,8 +258,11 @@ def get_completions(

# Schema.table prefix → suggest tables after schema name
# Pattern: FROM/JOIN schema. or schema.partial
if re.search(r"\b(FROM|JOIN)\s+\w+\.\w*$", clean_before, re.IGNORECASE):
return fuzzy_match(current_word, tables)
namespace_match = re.search(r"\b(?:FROM|JOIN)\s+(\w+)\.\w*$", clean_before, re.IGNORECASE)
if namespace_match:
namespace = namespace_match.group(1)
names = get_names_for_namespace(namespace, tables)
return fuzzy_match(current_word, names if names else tables)

# ANY/ALL/SOME ( → suggest SELECT for subquery
if re.search(r"\b(ANY|ALL|SOME)\s*\(\s*\w*$", clean_before, re.IGNORECASE):
Expand Down Expand Up @@ -326,6 +331,7 @@ def get_completions(

for suggestion in suggestions:
if suggestion.type == SuggestionType.TABLE:
results.extend(get_identifier_namespaces(tables))
results.extend(tables)
results.extend(cte_names)

Expand Down
78 changes: 78 additions & 0 deletions sqlit/domains/query/completion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,84 @@ def fuzzy_match(text: str, candidates: list[str], max_results: int = 50) -> list
return [r[2] for r in results[:max_results]]


def split_identifier_parts(identifier: str) -> list[str]:
"""Split a possibly quoted qualified identifier into unquoted parts."""
parts: list[str] = []
current: list[str] = []
quote: str | None = None
bracketed = False

for char in identifier:
if bracketed:
if char == "]":
bracketed = False
else:
current.append(char)
continue

if quote:
if char == quote:
quote = None
else:
current.append(char)
continue

if char == "[":
bracketed = True
continue
if char in {'"', "`"}:
quote = char
continue
if char == ".":
part = "".join(current).strip()
if part:
parts.append(part)
current = []
continue

current.append(char)

part = "".join(current).strip()
if part:
parts.append(part)
return parts


def get_identifier_namespaces(identifiers: list[str]) -> list[str]:
"""Return first qualifier parts from schema/database-qualified names."""
namespaces: list[str] = []
seen: set[str] = set()
for identifier in identifiers:
parts = split_identifier_parts(identifier)
if len(parts) < 2:
continue
namespace = parts[0]
key = namespace.lower()
if key not in seen:
seen.add(key)
namespaces.append(namespace)
return namespaces


def get_names_for_namespace(namespace: str, identifiers: list[str]) -> list[str]:
"""Return final identifier parts that belong to the requested namespace."""
namespace_lower = namespace.lower()
names: list[str] = []
seen: set[str] = set()

for identifier in identifiers:
parts = split_identifier_parts(identifier)
if len(parts) < 2 or parts[0].lower() != namespace_lower:
continue
name = parts[-1]
key = name.lower()
if key not in seen:
seen.add(key)
names.append(name)

return names


def extract_table_refs(sql: str) -> list[TableRef]:
"""Extract table references and aliases from SQL.

Expand Down
20 changes: 10 additions & 10 deletions sqlit/domains/query/ui/mixins/autocomplete_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,12 @@ def work() -> None:
try:
entries: list[tuple[str, list[tuple[str, tuple[str, str, str | None]]]]] = []
for schema_name, table_name in tables:
display_name = dialect.format_table_name(schema_name, table_name)
if single_db:
full_name = table_name
full_name = display_name
else:
full_name = dialect.qualified_name(database, schema_name, table_name)

display_name = dialect.format_table_name(schema_name, table_name)
metadata = [
(display_name.lower(), (schema_name, table_name, database)),
(table_name.lower(), (schema_name, table_name, database)),
Expand Down Expand Up @@ -486,12 +486,12 @@ def work() -> None:
try:
entries: list[tuple[str, list[tuple[str, tuple[str, str, str | None]]]]] = []
for schema_name, view_name in views:
display_name = dialect.format_table_name(schema_name, view_name)
if single_db:
full_name = view_name
full_name = display_name
else:
full_name = dialect.qualified_name(database, schema_name, view_name)

display_name = dialect.format_table_name(schema_name, view_name)
metadata = [
(display_name.lower(), (schema_name, view_name, database)),
(view_name.lower(), (schema_name, view_name, database)),
Expand Down Expand Up @@ -799,16 +799,16 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any:
db_arg = self._get_metadata_db_arg(database)
tables = await run_db_call(inspector.get_tables, connection, db_arg)
for schema_name, table_name in tables:
display_name = dialect.format_table_name(schema_name, table_name)
# Use simple name if we have a default database, full qualifier otherwise
if len(databases) == 1:
# Single database - use simple table name
schema_cache["tables"].append(table_name)
# Single database - omit only the provider's default schema
schema_cache["tables"].append(display_name)
else:
# Multiple databases - use qualified identifier
full_name = dialect.qualified_name(database, schema_name, table_name)
schema_cache["tables"].append(full_name)
# Keep metadata for column loading (multiple keys for flexible lookup)
display_name = dialect.format_table_name(schema_name, table_name)
table_metadata[display_name.lower()] = (schema_name, table_name, database)
table_metadata[table_name.lower()] = (schema_name, table_name, database)
if database:
Expand All @@ -820,16 +820,16 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any:
# Get views
views = await run_db_call(inspector.get_views, connection, db_arg)
for schema_name, view_name in views:
display_name = dialect.format_table_name(schema_name, view_name)
# Use simple name if we have a default database, full qualifier otherwise
if len(databases) == 1:
# Single database - use simple view name
schema_cache["views"].append(view_name)
# Single database - omit only the provider's default schema
schema_cache["views"].append(display_name)
else:
# Multiple databases - use qualified identifier
full_name = dialect.qualified_name(database, schema_name, view_name)
schema_cache["views"].append(full_name)
# Keep metadata for column loading (multiple keys for flexible lookup)
display_name = dialect.format_table_name(schema_name, view_name)
table_metadata[display_name.lower()] = (schema_name, view_name, database)
table_metadata[view_name.lower()] = (schema_name, view_name, database)
if database:
Expand Down
158 changes: 158 additions & 0 deletions tests/integration/test_autocomplete_postgresql_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Integration tests for PostgreSQL schema-qualified autocomplete."""

from __future__ import annotations

import os
import tempfile

import pytest

from sqlit.domains.shell.app.main import SSMSTUI
from tests.fixtures.postgres import (
POSTGRES_HOST,
POSTGRES_PASSWORD,
POSTGRES_PORT,
POSTGRES_USER,
)
from tests.helpers import ConnectionConfig
from tests.integration.browsing_base import wait_for_condition


@pytest.fixture
def temp_config_dir():
"""Create a temporary config directory for tests."""
with tempfile.TemporaryDirectory(prefix="sqlit-test-") as tmpdir:
original = os.environ.get("SQLIT_CONFIG_DIR")
os.environ["SQLIT_CONFIG_DIR"] = tmpdir
yield tmpdir
if original:
os.environ["SQLIT_CONFIG_DIR"] = original
else:
os.environ.pop("SQLIT_CONFIG_DIR", None)


@pytest.fixture
def postgres_schema_table(postgres_server_ready: bool, postgres_db: str):
"""Create a non-default schema table that requires schema qualification."""
if not postgres_server_ready:
pytest.skip("PostgreSQL is not available")

try:
import psycopg2
except ImportError:
pytest.skip("psycopg2 is not installed")

conn = psycopg2.connect(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
database=postgres_db,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
connect_timeout=10,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("DROP SCHEMA IF EXISTS test CASCADE")
cursor.execute("CREATE SCHEMA test")
cursor.execute("""
CREATE TABLE test.hello_world (
id INTEGER PRIMARY KEY,
greeting TEXT NOT NULL
)
""")
cursor.execute("INSERT INTO test.hello_world (id, greeting) VALUES (1, 'hello')")
conn.close()

yield

conn = psycopg2.connect(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
database=postgres_db,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
connect_timeout=10,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("DROP SCHEMA IF EXISTS test CASCADE")
conn.close()


def _suggestions(app: SSMSTUI, sql: str) -> list[str]:
return app._get_autocomplete_suggestions(sql, len(sql))


@pytest.mark.asyncio
async def test_postgresql_schema_table_autocomplete(
postgres_server_ready: bool,
postgres_db: str,
postgres_schema_table,
temp_config_dir: str,
) -> None:
"""Autocomplete should offer schema names and schema-local tables."""
if not postgres_server_ready:
pytest.skip("PostgreSQL is not available")

config = ConnectionConfig(
name="test-postgres-schema-autocomplete",
db_type="postgresql",
server=POSTGRES_HOST,
port=str(POSTGRES_PORT),
database=postgres_db,
username=POSTGRES_USER,
password=POSTGRES_PASSWORD,
)

app = SSMSTUI()
async with app.run_test(size=(120, 40)) as pilot:
await pilot.pause(0.1)

app.connections = [config]
app.refresh_tree()
await wait_for_condition(
pilot,
lambda: len(app.object_tree.root.children) > 0,
timeout_seconds=5.0,
description="tree to be populated with connections",
)

app.connect_to_server(config)
await wait_for_condition(
pilot,
lambda: app.current_connection is not None,
timeout_seconds=15.0,
description="connection to be established",
)

load_schema_async = getattr(app, "_load_schema_cache_async", None)
if callable(load_schema_async):
await load_schema_async()

await wait_for_condition(
pilot,
lambda: "test.hello_world" in getattr(app, "_table_metadata", {}),
timeout_seconds=20.0,
description="schema-qualified table metadata to be loaded",
)

from_schema_prefix = _suggestions(app, "SELECT * FROM tes")
assert "test" in from_schema_prefix
assert "hello_world" not in from_schema_prefix

from_schema_dot = _suggestions(app, "SELECT * FROM test.")
assert "hello_world" in from_schema_dot
assert "Loading..." not in from_schema_dot

alias_columns = _suggestions(app, "SELECT * FROM test.hello_world h WHERE h.")
if alias_columns == ["Loading..."]:
await wait_for_condition(
pilot,
lambda: bool(app._schema_cache.get("columns", {}).get("test.hello_world")),
timeout_seconds=10.0,
description="schema-qualified columns to load",
)
alias_columns = _suggestions(app, "SELECT * FROM test.hello_world h WHERE h.")

assert {"id", "greeting"}.issubset({item.lower() for item in alias_columns})
assert "Loading..." not in alias_columns
Loading
Loading