diff --git a/README.md b/README.md index 985f1fa..f4d6b75 100644 --- a/README.md +++ b/README.md @@ -187,6 +187,37 @@ sqlit connections list sqlit connections delete "MyConnection" ``` +### Shell completion + +`sqlit` ships tab-completion for bash, zsh, and fish (subcommands, flags, +database providers, and your saved connection names). It's powered by +[`argcomplete`](https://github.com/kislyuk/argcomplete), an optional extra: + +```bash +pipx inject sqlit-tui argcomplete # or: pip install 'sqlit-tui[completion]' +``` + +Then enable it for your shell: + +```bash +# bash — add to ~/.bashrc +eval "$(sqlit completion bash)" + +# zsh — add to ~/.zshrc +eval "$(sqlit completion zsh)" + +# fish — add to ~/.config/fish/config.fish +sqlit completion fish | source +``` + +Restart your shell (or re-source the rc file) and press ``: + +```bash +sqlit conn # → connections / connect / ... +sqlit connect # → postgresql, mysql, sqlite, ... +sqlit query --connection # → your saved connection names +``` + ## Keybindings | Key | Action | diff --git a/pyproject.toml b/pyproject.toml index 8ff1771..8d86525 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] +completion = ["argcomplete>=3.0"] all = [ "psycopg2-binary>=2.9.0", "mssql-python>=1.1.0", diff --git a/sqlit/cli.py b/sqlit/cli.py index f5b838f..f65214f 100644 --- a/sqlit/cli.py +++ b/sqlit/cli.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# PYTHON_ARGCOMPLETE_OK """sqlit - A terminal UI for SQL databases.""" from __future__ import annotations @@ -12,6 +13,7 @@ from pathlib import Path from typing import Any +from sqlit.domains.connections.cli.completion import complete_connection_names from sqlit.domains.connections.cli.helpers import add_schema_arguments, build_connection_config_from_args from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig, DatabaseType from sqlit.domains.connections.providers.catalog import get_provider_schema, get_supported_db_types @@ -464,6 +466,33 @@ def _build_runtime( ) +def _cmd_completion(shell: str) -> int: + """Print the argcomplete activation snippet for the given shell.""" + try: + import argcomplete + except ImportError: + print( + "Shell completion requires the 'argcomplete' package.\n" + "Install it with: pip install 'sqlit-tui[completion]'", + file=sys.stderr, + ) + return 1 + + try: + snippet = argcomplete.shellcode(["sqlit"], shell=shell) # type: ignore[attr-defined] + except Exception as exc: # pragma: no cover - depends on argcomplete version + print( + f"Could not generate completion for {shell!r}: {exc}\n" + "Your argcomplete version may not support this shell; " + "try upgrading: pip install -U argcomplete", + file=sys.stderr, + ) + return 1 + + print(snippet) + return 0 + + def main() -> int: """Entry point for the CLI.""" startup_mark = time.perf_counter() @@ -645,7 +674,7 @@ def main() -> int: "--connection", metavar="NAME", help="Connect to a saved connection by name (opens TUI with only this connection)", - ) + ).completer = complete_connection_names # type: ignore[attr-defined] subparsers = parser.add_subparsers(dest="command", help="Available commands") @@ -702,7 +731,9 @@ def main() -> int: ) edit_parser = conn_subparsers.add_parser("edit", help="Edit an existing connection") - edit_parser.add_argument("connection_name", help="Name of connection to edit") + edit_parser.add_argument( + "connection_name", help="Name of connection to edit" + ).completer = complete_connection_names # type: ignore[attr-defined] edit_parser.add_argument("--name", "-n", help="New connection name") edit_parser.add_argument("--server", "-s", help="Server address") edit_parser.add_argument("--host", help="Alias for --server (e.g. Cloudflare D1 Account ID)") @@ -731,7 +762,9 @@ def main() -> int: ) delete_parser = conn_subparsers.add_parser("delete", help="Delete a connection") - delete_parser.add_argument("connection_name", help="Name of connection to delete") + delete_parser.add_argument( + "connection_name", help="Name of connection to delete" + ).completer = complete_connection_names # type: ignore[attr-defined] connect_parser = subparsers.add_parser("connect", help="Temporary connection (not saved)") connect_provider_parsers = connect_parser.add_subparsers(dest="provider", metavar="PROVIDER") @@ -753,7 +786,9 @@ def main() -> int: ) query_parser = subparsers.add_parser("query", help="Execute a SQL query") - query_parser.add_argument("--connection", "-c", required=True, help="Connection name to use") + query_parser.add_argument( + "--connection", "-c", required=True, help="Connection name to use" + ).completer = complete_connection_names # type: ignore[attr-defined] query_parser.add_argument("--database", "-d", help="Database to query (overrides connection default)") query_parser.add_argument("--query", "-q", help="SQL query to execute") query_parser.add_argument("--file", "-f", help="SQL file to execute") @@ -794,7 +829,7 @@ def main() -> int: "--connection", "-c", help="Target a specific saved connection (omit for global)", - ) + ).completer = complete_connection_names # type: ignore[attr-defined] alerts_set.add_argument( "--database", "-d", @@ -808,20 +843,54 @@ def main() -> int: "--connection", "-c", help="Connection whose override to clear", - ) + ).completer = complete_connection_names # type: ignore[attr-defined] alerts_unset.add_argument( "--database", "-d", help="Database whose override to clear (requires --connection)", ) + completion_parser = subparsers.add_parser( + "completion", + help="Print a shell completion script for sqlit", + description=( + "Print the shell completion activation snippet for sqlit.\n" + "Requires the optional 'argcomplete' dependency: pip install 'sqlit-tui[completion]'\n\n" + " bash: eval \"$(sqlit completion bash)\" # add to ~/.bashrc\n" + " zsh: eval \"$(sqlit completion zsh)\" # add to ~/.zshrc\n" + " fish: sqlit completion fish | source # add to ~/.config/fish/config.fish" + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + completion_parser.add_argument( + "shell", + choices=["bash", "zsh", "fish"], + help="Shell to print the completion script for", + ) + log_startup_step("cli_parser_end") + # Shell completion: argcomplete inspects the fully-built parser and exits + # here (before runtime/services are built and the TUI is imported) when a + # completion request is in progress. No-op when argcomplete isn't installed + # or we're running normally. + try: + import argcomplete + except ImportError: + pass + else: + argcomplete.autocomplete(parser) + with startup_span("cli_parse_args"): args = parser.parse_args(filtered_argv[1:]) # Skip program name _resolve_stdin_secrets(args) log_startup_step("cli_parse_end") + # `sqlit completion ` prints an activation snippet and exits without + # touching the runtime or services. + if args.command == "completion": + return _cmd_completion(args.shell) + with startup_span("runtime_build"): runtime = _build_runtime(args, startup_mark, project_dir=project_dir) diff --git a/sqlit/domains/connections/cli/completion.py b/sqlit/domains/connections/cli/completion.py new file mode 100644 index 0000000..88a046c --- /dev/null +++ b/sqlit/domains/connections/cli/completion.py @@ -0,0 +1,56 @@ +"""Shell-completion helpers for the sqlit CLI. + +These run inside the argcomplete completion subprocess on every press, +so they must stay cheap: read the connections JSON directly rather than +building the services stack (which imports asyncio and the credentials +service). See ``sqlit.cli`` for where the completers are attached. +""" + +from __future__ import annotations + +from typing import Any + + +def _connection_names() -> list[str]: + """Return saved connection names from the global config, best-effort. + + Reads ``CONFIG_DIR/connections.json`` directly and mirrors the payload + shape handled by ``ConnectionStore._unpack_connections_payload`` (a bare + list for the legacy v1 format, or a ``{"connections": [...]}`` dict for + v2+). Any error returns an empty list — completion must never raise. + """ + try: + import json + + from sqlit.shared.core.store import CONFIG_DIR + + path = CONFIG_DIR / "connections.json" + if not path.is_file(): + return [] + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return [] + + if isinstance(data, dict): + raw = data.get("connections") + else: + raw = data + if not isinstance(raw, list): + return [] + + names: list[str] = [] + for entry in raw: + if isinstance(entry, dict): + name = entry.get("name") + if isinstance(name, str) and name: + names.append(name) + return names + + +def complete_connection_names(prefix: str, **_: Any) -> list[str]: + """argcomplete completer for arguments that take a saved connection name. + + argcomplete invokes completers with extra keyword arguments (``action``, + ``parser``, ``parsed_args``); we accept and ignore them. + """ + return [name for name in _connection_names() if name.startswith(prefix)] diff --git a/tests/cli/test_completion.py b/tests/cli/test_completion.py new file mode 100644 index 0000000..86ec9d4 --- /dev/null +++ b/tests/cli/test_completion.py @@ -0,0 +1,139 @@ +"""Tests for shell tab-completion support (issue #247).""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +from sqlit.domains.connections.cli import completion + + +def _write_connections(config_dir: Path, payload: object) -> None: + config_dir.mkdir(parents=True, exist_ok=True) + (config_dir / "connections.json").write_text(json.dumps(payload), encoding="utf-8") + + +# -------------------------------------------------------------------------- +# Unit tests for the connection-name completer +# -------------------------------------------------------------------------- + + +def test_complete_connection_names_v2_format(tmp_path: Path, monkeypatch): + _write_connections( + tmp_path, + {"version": 2, "connections": [{"name": "prod-pg"}, {"name": "prod-mysql"}, {"name": "staging"}]}, + ) + # The completer reads CONFIG_DIR lazily from the store module. + monkeypatch.setattr("sqlit.shared.core.store.CONFIG_DIR", tmp_path) + + assert completion.complete_connection_names("prod") == ["prod-pg", "prod-mysql"] + assert set(completion.complete_connection_names("")) == {"prod-pg", "prod-mysql", "staging"} + assert completion.complete_connection_names("zzz") == [] + + +def test_complete_connection_names_legacy_list_format(tmp_path: Path, monkeypatch): + _write_connections(tmp_path, [{"name": "legacy-one"}, {"name": "legacy-two"}]) + monkeypatch.setattr("sqlit.shared.core.store.CONFIG_DIR", tmp_path) + + assert set(completion.complete_connection_names("")) == {"legacy-one", "legacy-two"} + + +def test_complete_connection_names_missing_file_is_safe(tmp_path: Path, monkeypatch): + monkeypatch.setattr("sqlit.shared.core.store.CONFIG_DIR", tmp_path / "nope") + assert completion.complete_connection_names("") == [] + + +def test_complete_connection_names_malformed_json_is_safe(tmp_path: Path, monkeypatch): + (tmp_path).mkdir(parents=True, exist_ok=True) + (tmp_path / "connections.json").write_text("{not valid json", encoding="utf-8") + monkeypatch.setattr("sqlit.shared.core.store.CONFIG_DIR", tmp_path) + assert completion.complete_connection_names("") == [] + + +# -------------------------------------------------------------------------- +# `sqlit completion ` subcommand +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("shell", ["bash", "zsh", "fish"]) +def test_completion_subcommand_prints_script(shell: str): + pytest.importorskip("argcomplete") + result = subprocess.run( + [sys.executable, "-m", "sqlit.cli", "completion", shell], + capture_output=True, + text=True, + ) + assert result.returncode == 0, result.stderr + assert result.stdout.strip(), "expected a non-empty completion script" + assert "sqlit" in result.stdout + + +def test_completion_subcommand_rejects_unknown_shell(): + result = subprocess.run( + [sys.executable, "-m", "sqlit.cli", "completion", "powershell"], + capture_output=True, + text=True, + ) + assert result.returncode == 2 + assert "invalid choice" in result.stderr + + +# -------------------------------------------------------------------------- +# End-to-end argcomplete protocol +# -------------------------------------------------------------------------- + + +def _run_completion(comp_line: str, config_dir: Path) -> list[str]: + """Drive the argcomplete protocol and return the emitted candidates. + + argcomplete writes the newline/IFS-separated candidates to fd 8, so we use + a bash redirection (`8>file`) to capture them. + """ + out_file = config_dir / "_comp_out" + inner = f"exec {sys.executable!s} -m sqlit.cli 8>{out_file!s} 9>/dev/null 2>/dev/null" + env = { + **os.environ, + "SQLIT_CONFIG_DIR": str(config_dir), + "_ARGCOMPLETE": "1", + "_ARGCOMPLETE_SHELL": "bash", + "_ARGCOMPLETE_COMP_WORDBREAKS": " \t\n\"'><=;|&(:", + "COMP_LINE": comp_line, + "COMP_POINT": str(len(comp_line)), + "COMP_TYPE": "9", + } + subprocess.run(["bash", "-c", inner], env=env) + if not out_file.exists(): + return [] + raw = out_file.read_text(encoding="utf-8", errors="replace") + # Candidates are separated by the IFS argcomplete uses (\013) or whitespace. + parts = raw.replace("\013", "\n").split("\n") + return [p.strip() for p in parts if p.strip()] + + +def test_argcomplete_completes_subcommands(tmp_path: Path): + pytest.importorskip("argcomplete") + candidates = _run_completion("sqlit ", tmp_path) + for expected in ("connections", "connect", "query", "alerts", "completion"): + assert expected in candidates, f"{expected!r} missing from {candidates}" + + +def test_argcomplete_completes_providers_for_connect(tmp_path: Path): + pytest.importorskip("argcomplete") + candidates = _run_completion("sqlit connect ", tmp_path) + for expected in ("postgresql", "mysql", "sqlite"): + assert expected in candidates + + +def test_argcomplete_completes_saved_connection_names(tmp_path: Path): + pytest.importorskip("argcomplete") + _write_connections( + tmp_path, + {"version": 2, "connections": [{"name": "prod-pg"}, {"name": "prod-mysql"}, {"name": "staging"}]}, + ) + candidates = _run_completion("sqlit query --connection prod", tmp_path) + assert set(candidates) == {"prod-pg", "prod-mysql"}