diff --git a/sqlit/domains/query/completion/completion.py b/sqlit/domains/query/completion/completion.py index a923e45e..c60cf00b 100644 --- a/sqlit/domains/query/completion/completion.py +++ b/sqlit/domains/query/completion/completion.py @@ -21,7 +21,9 @@ get_all_functions, get_all_keywords, get_current_word, + get_identifier_namespaces, get_last_token_info, + get_names_for_namespace, is_inside_string, remove_comments, remove_string_literals, @@ -256,8 +258,11 @@ def get_completions( # Schema.table prefix → suggest tables after schema name # Pattern: FROM/JOIN schema. or schema.partial - if re.search(r"\b(FROM|JOIN)\s+\w+\.\w*$", clean_before, re.IGNORECASE): - return fuzzy_match(current_word, tables) + namespace_match = re.search(r"\b(?:FROM|JOIN)\s+(\w+)\.\w*$", clean_before, re.IGNORECASE) + if namespace_match: + namespace = namespace_match.group(1) + names = get_names_for_namespace(namespace, tables) + return fuzzy_match(current_word, names if names else tables) # ANY/ALL/SOME ( → suggest SELECT for subquery if re.search(r"\b(ANY|ALL|SOME)\s*\(\s*\w*$", clean_before, re.IGNORECASE): @@ -326,6 +331,7 @@ def get_completions( for suggestion in suggestions: if suggestion.type == SuggestionType.TABLE: + results.extend(get_identifier_namespaces(tables)) results.extend(tables) results.extend(cte_names) diff --git a/sqlit/domains/query/completion/core.py b/sqlit/domains/query/completion/core.py index 12039c41..8b18a563 100644 --- a/sqlit/domains/query/completion/core.py +++ b/sqlit/domains/query/completion/core.py @@ -171,6 +171,84 @@ def fuzzy_match(text: str, candidates: list[str], max_results: int = 50) -> list return [r[2] for r in results[:max_results]] +def split_identifier_parts(identifier: str) -> list[str]: + """Split a possibly quoted qualified identifier into unquoted parts.""" + parts: list[str] = [] + current: list[str] = [] + quote: str | None = None + bracketed = False + + for char in identifier: + if bracketed: + if char == "]": + bracketed = False + else: + current.append(char) + continue + + if quote: + if char == quote: + quote = None + else: + current.append(char) + continue + + if char == "[": + bracketed = True + continue + if char in {'"', "`"}: + quote = char + continue + if char == ".": + part = "".join(current).strip() + if part: + parts.append(part) + current = [] + continue + + current.append(char) + + part = "".join(current).strip() + if part: + parts.append(part) + return parts + + +def get_identifier_namespaces(identifiers: list[str]) -> list[str]: + """Return first qualifier parts from schema/database-qualified names.""" + namespaces: list[str] = [] + seen: set[str] = set() + for identifier in identifiers: + parts = split_identifier_parts(identifier) + if len(parts) < 2: + continue + namespace = parts[0] + key = namespace.lower() + if key not in seen: + seen.add(key) + namespaces.append(namespace) + return namespaces + + +def get_names_for_namespace(namespace: str, identifiers: list[str]) -> list[str]: + """Return final identifier parts that belong to the requested namespace.""" + namespace_lower = namespace.lower() + names: list[str] = [] + seen: set[str] = set() + + for identifier in identifiers: + parts = split_identifier_parts(identifier) + if len(parts) < 2 or parts[0].lower() != namespace_lower: + continue + name = parts[-1] + key = name.lower() + if key not in seen: + seen.add(key) + names.append(name) + + return names + + def extract_table_refs(sql: str) -> list[TableRef]: """Extract table references and aliases from SQL. diff --git a/sqlit/domains/query/ui/mixins/autocomplete_schema.py b/sqlit/domains/query/ui/mixins/autocomplete_schema.py index bc8b7f4b..1da9f2b1 100644 --- a/sqlit/domains/query/ui/mixins/autocomplete_schema.py +++ b/sqlit/domains/query/ui/mixins/autocomplete_schema.py @@ -403,12 +403,12 @@ def work() -> None: try: entries: list[tuple[str, list[tuple[str, tuple[str, str, str | None]]]]] = [] for schema_name, table_name in tables: + display_name = dialect.format_table_name(schema_name, table_name) if single_db: - full_name = table_name + full_name = display_name else: full_name = dialect.qualified_name(database, schema_name, table_name) - display_name = dialect.format_table_name(schema_name, table_name) metadata = [ (display_name.lower(), (schema_name, table_name, database)), (table_name.lower(), (schema_name, table_name, database)), @@ -486,12 +486,12 @@ def work() -> None: try: entries: list[tuple[str, list[tuple[str, tuple[str, str, str | None]]]]] = [] for schema_name, view_name in views: + display_name = dialect.format_table_name(schema_name, view_name) if single_db: - full_name = view_name + full_name = display_name else: full_name = dialect.qualified_name(database, schema_name, view_name) - display_name = dialect.format_table_name(schema_name, view_name) metadata = [ (display_name.lower(), (schema_name, view_name, database)), (view_name.lower(), (schema_name, view_name, database)), @@ -799,16 +799,16 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any: db_arg = self._get_metadata_db_arg(database) tables = await run_db_call(inspector.get_tables, connection, db_arg) for schema_name, table_name in tables: + display_name = dialect.format_table_name(schema_name, table_name) # Use simple name if we have a default database, full qualifier otherwise if len(databases) == 1: - # Single database - use simple table name - schema_cache["tables"].append(table_name) + # Single database - omit only the provider's default schema + schema_cache["tables"].append(display_name) else: # Multiple databases - use qualified identifier full_name = dialect.qualified_name(database, schema_name, table_name) schema_cache["tables"].append(full_name) # Keep metadata for column loading (multiple keys for flexible lookup) - display_name = dialect.format_table_name(schema_name, table_name) table_metadata[display_name.lower()] = (schema_name, table_name, database) table_metadata[table_name.lower()] = (schema_name, table_name, database) if database: @@ -820,16 +820,16 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any: # Get views views = await run_db_call(inspector.get_views, connection, db_arg) for schema_name, view_name in views: + display_name = dialect.format_table_name(schema_name, view_name) # Use simple name if we have a default database, full qualifier otherwise if len(databases) == 1: - # Single database - use simple view name - schema_cache["views"].append(view_name) + # Single database - omit only the provider's default schema + schema_cache["views"].append(display_name) else: # Multiple databases - use qualified identifier full_name = dialect.qualified_name(database, schema_name, view_name) schema_cache["views"].append(full_name) # Keep metadata for column loading (multiple keys for flexible lookup) - display_name = dialect.format_table_name(schema_name, view_name) table_metadata[display_name.lower()] = (schema_name, view_name, database) table_metadata[view_name.lower()] = (schema_name, view_name, database) if database: diff --git a/tests/integration/test_autocomplete_postgresql_schema.py b/tests/integration/test_autocomplete_postgresql_schema.py new file mode 100644 index 00000000..fcc481ca --- /dev/null +++ b/tests/integration/test_autocomplete_postgresql_schema.py @@ -0,0 +1,158 @@ +"""Integration tests for PostgreSQL schema-qualified autocomplete.""" + +from __future__ import annotations + +import os +import tempfile + +import pytest + +from sqlit.domains.shell.app.main import SSMSTUI +from tests.fixtures.postgres import ( + POSTGRES_HOST, + POSTGRES_PASSWORD, + POSTGRES_PORT, + POSTGRES_USER, +) +from tests.helpers import ConnectionConfig +from tests.integration.browsing_base import wait_for_condition + + +@pytest.fixture +def temp_config_dir(): + """Create a temporary config directory for tests.""" + with tempfile.TemporaryDirectory(prefix="sqlit-test-") as tmpdir: + original = os.environ.get("SQLIT_CONFIG_DIR") + os.environ["SQLIT_CONFIG_DIR"] = tmpdir + yield tmpdir + if original: + os.environ["SQLIT_CONFIG_DIR"] = original + else: + os.environ.pop("SQLIT_CONFIG_DIR", None) + + +@pytest.fixture +def postgres_schema_table(postgres_server_ready: bool, postgres_db: str): + """Create a non-default schema table that requires schema qualification.""" + if not postgres_server_ready: + pytest.skip("PostgreSQL is not available") + + try: + import psycopg2 + except ImportError: + pytest.skip("psycopg2 is not installed") + + conn = psycopg2.connect( + host=POSTGRES_HOST, + port=POSTGRES_PORT, + database=postgres_db, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + connect_timeout=10, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute("DROP SCHEMA IF EXISTS test CASCADE") + cursor.execute("CREATE SCHEMA test") + cursor.execute(""" + CREATE TABLE test.hello_world ( + id INTEGER PRIMARY KEY, + greeting TEXT NOT NULL + ) + """) + cursor.execute("INSERT INTO test.hello_world (id, greeting) VALUES (1, 'hello')") + conn.close() + + yield + + conn = psycopg2.connect( + host=POSTGRES_HOST, + port=POSTGRES_PORT, + database=postgres_db, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + connect_timeout=10, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute("DROP SCHEMA IF EXISTS test CASCADE") + conn.close() + + +def _suggestions(app: SSMSTUI, sql: str) -> list[str]: + return app._get_autocomplete_suggestions(sql, len(sql)) + + +@pytest.mark.asyncio +async def test_postgresql_schema_table_autocomplete( + postgres_server_ready: bool, + postgres_db: str, + postgres_schema_table, + temp_config_dir: str, +) -> None: + """Autocomplete should offer schema names and schema-local tables.""" + if not postgres_server_ready: + pytest.skip("PostgreSQL is not available") + + config = ConnectionConfig( + name="test-postgres-schema-autocomplete", + db_type="postgresql", + server=POSTGRES_HOST, + port=str(POSTGRES_PORT), + database=postgres_db, + username=POSTGRES_USER, + password=POSTGRES_PASSWORD, + ) + + app = SSMSTUI() + async with app.run_test(size=(120, 40)) as pilot: + await pilot.pause(0.1) + + app.connections = [config] + app.refresh_tree() + await wait_for_condition( + pilot, + lambda: len(app.object_tree.root.children) > 0, + timeout_seconds=5.0, + description="tree to be populated with connections", + ) + + app.connect_to_server(config) + await wait_for_condition( + pilot, + lambda: app.current_connection is not None, + timeout_seconds=15.0, + description="connection to be established", + ) + + load_schema_async = getattr(app, "_load_schema_cache_async", None) + if callable(load_schema_async): + await load_schema_async() + + await wait_for_condition( + pilot, + lambda: "test.hello_world" in getattr(app, "_table_metadata", {}), + timeout_seconds=20.0, + description="schema-qualified table metadata to be loaded", + ) + + from_schema_prefix = _suggestions(app, "SELECT * FROM tes") + assert "test" in from_schema_prefix + assert "hello_world" not in from_schema_prefix + + from_schema_dot = _suggestions(app, "SELECT * FROM test.") + assert "hello_world" in from_schema_dot + assert "Loading..." not in from_schema_dot + + alias_columns = _suggestions(app, "SELECT * FROM test.hello_world h WHERE h.") + if alias_columns == ["Loading..."]: + await wait_for_condition( + pilot, + lambda: bool(app._schema_cache.get("columns", {}).get("test.hello_world")), + timeout_seconds=10.0, + description="schema-qualified columns to load", + ) + alias_columns = _suggestions(app, "SELECT * FROM test.hello_world h WHERE h.") + + assert {"id", "greeting"}.issubset({item.lower() for item in alias_columns}) + assert "Loading..." not in alias_columns diff --git a/tests/unit/test_autocomplete_multidb.py b/tests/unit/test_autocomplete_multidb.py index e34946cb..1e3b208a 100644 --- a/tests/unit/test_autocomplete_multidb.py +++ b/tests/unit/test_autocomplete_multidb.py @@ -15,9 +15,6 @@ from __future__ import annotations -import pytest - - # -------------------------------------------------------------------------- # Bug 1: Dialect.qualified_name # -------------------------------------------------------------------------- @@ -153,3 +150,65 @@ def test_known_table_ref_still_triggers_loading_on_first_call() -> None: assert result == ["Loading..."] assert "customers" in host.load_calls + + +# -------------------------------------------------------------------------- +# Bug 3: schema-qualified PostgreSQL autocomplete +# -------------------------------------------------------------------------- + + +def test_schema_name_completes_from_schema_qualified_tables() -> None: + """A non-default PostgreSQL schema should be offered after FROM. + + The schema cache stores non-default schema tables as schema.table. Typing + the schema prefix should offer the schema name, not only bare table names + that fail at execution time. + """ + from sqlit.domains.query.completion import get_completions + + sql = "SELECT * FROM tes" + result = get_completions( + sql, + len(sql), + tables=["public_users", "test.hello_world"], + columns={}, + ) + + assert "test" in result + assert "hello_world" not in result + + +def test_schema_dot_completes_tables_inside_schema() -> None: + """Typing schema. should offer tables from that schema as insertable names.""" + from sqlit.domains.query.completion import get_completions + + sql = "SELECT * FROM test." + result = get_completions( + sql, + len(sql), + tables=["public_users", "test.hello_world", "test.audit_log", "other.hello_world"], + columns={}, + ) + + assert result == ["hello_world", "audit_log"] + + +def test_schema_qualified_table_ref_loads_columns() -> None: + """schema.table alias lookup should use the qualified metadata key.""" + from sqlit.domains.query.ui.mixins.autocomplete_suggestions import AutocompleteSuggestionsMixin + + host = _SchemaHost( + tables=["test.hello_world"], + metadata={ + "test.hello_world": ("test", "hello_world", "appdb"), + "hello_world": ("test", "hello_world", "appdb"), + }, + ) + host._build_alias_map = AutocompleteSuggestionsMixin._build_alias_map.__get__(host) + get_suggestions = AutocompleteSuggestionsMixin._get_autocomplete_suggestions.__get__(host) + + sql = "SELECT * FROM test.hello_world h WHERE h." + result = get_suggestions(sql, len(sql)) + + assert result == ["Loading..."] + assert host.load_calls == ["test.hello_world"]