diff --git a/changelog.md b/changelog.md index f9e9eeca..7fd1c01d 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,11 @@ +Upcoming (TBD) +============== + +Internal +--------- +* Independent case-sensitivity for special command aliases. + + 1.71.0 (2026/05/01) ============== diff --git a/mycli/main.py b/mycli/main.py index 1c0b5e4a..3639d236 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -76,7 +76,7 @@ from mycli.packages.filepaths import dir_path_exists, guess_socket_location from mycli.packages.interactive_utils import confirm_destructive_query from mycli.packages.special.favoritequeries import FavoriteQueries -from mycli.packages.special.main import ArgType +from mycli.packages.special.main import ArgType, SpecialCommandAlias from mycli.packages.sqlresult import SQLResult from mycli.packages.ssh_utils import read_ssh_config from mycli.packages.tabular_output import sql_format @@ -312,39 +312,59 @@ def close(self) -> None: self.sqlexecute.close() def register_special_commands(self) -> None: - special.register_special_command(self.change_db, "use", "use ", "Change to a new database.", aliases=["\\u"]) + special.register_special_command( + self.change_db, + "use", + "use ", + "Change to a new database.", + aliases=[SpecialCommandAlias("\\u", case_sensitive=False)], + ) special.register_special_command( self.manual_reconnect, "connect", "connect [database]", "Reconnect to the server, optionally switching databases.", - aliases=["\\r"], case_sensitive=True, + aliases=[SpecialCommandAlias("\\r", case_sensitive=True)], ) special.register_special_command( - self.refresh_completions, "rehash", "rehash", "Refresh auto-completions.", arg_type=ArgType.NO_QUERY, aliases=["\\#"] + self.refresh_completions, + "rehash", + "rehash", + "Refresh auto-completions.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\#", case_sensitive=False)], ) special.register_special_command( self.change_table_format, "tableformat", "tableformat ", "Change the table format used to output interactive results.", - aliases=["\\T"], case_sensitive=True, + aliases=[SpecialCommandAlias("\\T", case_sensitive=True)], ) special.register_special_command( self.change_redirect_format, "redirectformat", "redirectformat ", "Change the table format used to output redirected results.", - aliases=["\\Tr"], case_sensitive=True, + aliases=[SpecialCommandAlias("\\Tr", case_sensitive=True)], ) special.register_special_command( - self.execute_from_file, "source", "source ", "Execute queries from a file.", aliases=["\\."] + self.execute_from_file, + "source", + "source ", + "Execute queries from a file.", + aliases=[SpecialCommandAlias("\\.", case_sensitive=False)], ) special.register_special_command( - self.change_prompt_format, "prompt", "prompt ", "Change prompt format.", aliases=["\\R"], case_sensitive=True + self.change_prompt_format, + "prompt", + "prompt ", + "Change prompt format.", + case_sensitive=True, + aliases=[SpecialCommandAlias("\\R", case_sensitive=True)], ) def manual_reconnect(self, arg: str = "", **_) -> Generator[SQLResult, None, None]: diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py index 24cfc5ed..9b226b84 100644 --- a/mycli/packages/special/__init__.py +++ b/mycli/packages/special/__init__.py @@ -49,6 +49,7 @@ ) from mycli.packages.special.main import ( CommandNotFound, + SpecialCommandAlias, execute, parse_special_command, register_special_command, @@ -58,6 +59,7 @@ __all__: list[str] = [ 'CommandNotFound', 'FinishIteration', + 'SpecialCommandAlias', 'clip_command', 'close_tee', 'copy_query_to_clipboard', diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index e5043ee5..0965efd3 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -7,7 +7,7 @@ from mycli import __version__ from mycli.packages.special import iocommands -from mycli.packages.special.main import ArgType, special_command +from mycli.packages.special.main import ArgType, SpecialCommandAlias, special_command from mycli.packages.special.utils import ( format_uptime, get_local_timezone, @@ -20,7 +20,13 @@ logger = logging.getLogger(__name__) -@special_command("\\dt", "\\dt[+] [table]", "List or describe tables.", arg_type=ArgType.PARSED_QUERY, case_sensitive=True) +@special_command( + "\\dt", + "\\dt[+] [table]", + "List or describe tables.", + arg_type=ArgType.PARSED_QUERY, + case_sensitive=True, +) def list_tables( cur: Cursor, arg: str | None = None, @@ -53,7 +59,13 @@ def list_tables( return [SQLResult(header=header, rows=results, postamble=postamble)] -@special_command("\\l", "\\l", "List databases.", arg_type=ArgType.RAW_QUERY, case_sensitive=True) +@special_command( + "\\l", + "\\l", + "List databases.", + arg_type=ArgType.RAW_QUERY, + case_sensitive=True, +) def list_databases(cur: Cursor, **_) -> list[SQLResult]: query = "SHOW DATABASES" logger.debug(query) @@ -67,7 +79,12 @@ def list_databases(cur: Cursor, **_) -> list[SQLResult]: @special_command( - "status", "status", "Get status information from the server.", arg_type=ArgType.RAW_QUERY, aliases=["\\s"], case_sensitive=True + "status", + "status", + "Get status information from the server.", + arg_type=ArgType.RAW_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\s", case_sensitive=True)], ) def status(cur: Cursor, **_) -> list[SQLResult]: query = "SHOW GLOBAL STATUS;" diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 2678a511..2a29c7cf 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -21,7 +21,7 @@ from mycli.packages.special.delimitercommand import DelimiterCommand from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS -from mycli.packages.special.main import ArgType, special_command +from mycli.packages.special.main import ArgType, SpecialCommandAlias, special_command from mycli.packages.special.main import execute as special_execute from mycli.packages.special.utils import handle_cd_command from mycli.packages.sqlresult import SQLResult @@ -96,8 +96,8 @@ def is_show_warnings_enabled() -> bool: 'warnings', 'Enable automatic warnings display.', arg_type=ArgType.NO_QUERY, - aliases=['\\W'], case_sensitive=True, + aliases=[SpecialCommandAlias('\\W', case_sensitive=True)], ) def enable_show_warnings() -> Generator[SQLResult, None, None]: global SHOW_WARNINGS_ENABLED @@ -111,8 +111,8 @@ def enable_show_warnings() -> Generator[SQLResult, None, None]: 'nowarnings', 'Disable automatic warnings display.', arg_type=ArgType.NO_QUERY, - aliases=['\\w'], case_sensitive=True, + aliases=[SpecialCommandAlias('\\w', case_sensitive=True)], ) def disable_show_warnings() -> Generator[SQLResult, None, None]: global SHOW_WARNINGS_ENABLED @@ -126,8 +126,8 @@ def disable_show_warnings() -> Generator[SQLResult, None, None]: "pager [command]", "Set pager to [command]. Print query results via pager.", arg_type=ArgType.PARSED_QUERY, - aliases=["\\P"], case_sensitive=True, + aliases=[SpecialCommandAlias("\\P", case_sensitive=True)], ) def set_pager(arg: str, **_) -> list[SQLResult]: if arg: @@ -145,13 +145,27 @@ def set_pager(arg: str, **_) -> list[SQLResult]: return [SQLResult(status=msg)] -@special_command("nopager", "nopager", "Disable pager; print to stdout.", arg_type=ArgType.NO_QUERY, aliases=["\\n"], case_sensitive=True) +@special_command( + "nopager", + "nopager", + "Disable pager; print to stdout.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\n", case_sensitive=True)], +) def disable_pager() -> list[SQLResult]: set_pager_enabled(False) return [SQLResult(status="Pager disabled.")] -@special_command("\\timing", "\\timing", "Toggle timing of queries.", arg_type=ArgType.NO_QUERY, aliases=["\\t"], case_sensitive=True) +@special_command( + "\\timing", + "\\timing", + "Toggle timing of queries.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\t", case_sensitive=True)], +) def toggle_timing() -> list[SQLResult]: global TIMING_ENABLED TIMING_ENABLED = not TIMING_ENABLED @@ -309,7 +323,13 @@ def set_redirect(command_part: str | None, file_operator_part: str | None, file_ return set_once(file_part) -@special_command("\\f", "\\f [name [args..]]", "List or execute favorite queries.", arg_type=ArgType.PARSED_QUERY, case_sensitive=True) +@special_command( + "\\f", + "\\f [name [args..]]", + "List or execute favorite queries.", + arg_type=ArgType.PARSED_QUERY, + case_sensitive=True, +) def execute_favorite_query(cur: Cursor, arg: str, **_) -> Generator[SQLResult, None, None]: if arg == "": yield from list_favorite_queries() @@ -379,7 +399,11 @@ def subst_favorite_query_args(query: str, args: list[str]) -> list[str | None]: return [query, None] -@special_command("\\fs", "\\fs ", "Save a favorite query.") +@special_command( + "\\fs", + "\\fs ", + "Save a favorite query.", +) def save_favorite_query(arg: str, **_) -> list[SQLResult]: """Save a new favorite query.""" @@ -397,7 +421,11 @@ def save_favorite_query(arg: str, **_) -> list[SQLResult]: return [SQLResult(status="Saved.")] -@special_command("\\fd", "\\fd ", "Delete a favorite query.") +@special_command( + "\\fd", + "\\fd ", + "Delete a favorite query.", +) def delete_favorite_query(arg: str, **_) -> list[SQLResult]: """Delete an existing favorite query.""" usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage @@ -409,7 +437,11 @@ def delete_favorite_query(arg: str, **_) -> list[SQLResult]: return [SQLResult(status=status)] -@special_command("system", "system [-r] ", "Execute a system shell command (raw mode with -r).") +@special_command( + "system", + "system [-r] ", + "Execute a system shell command (raw mode with -r).", +) def execute_system_command(arg: str, **_) -> list[SQLResult]: """Execute a system shell command.""" usage = "Syntax: system [-r] [command].\n-r denotes \"raw\" mode, in which output is passed through without formatting." @@ -486,7 +518,11 @@ def parseargfile(arg: str) -> tuple[str, str]: return (os.path.expanduser(filename), mode) -@special_command("tee", "tee [-o] ", "Append all results to an output file (overwrite using -o).") +@special_command( + "tee", + "tee [-o] ", + "Append all results to an output file (overwrite using -o).", +) def set_tee(arg: str, **_) -> list[SQLResult]: global tee_file @@ -505,7 +541,11 @@ def close_tee() -> None: tee_file = None -@special_command("notee", "notee", "Stop writing results to an output file.") +@special_command( + "notee", + "notee", + "Stop writing results to an output file.", +) def no_tee(arg: str, **_) -> list[SQLResult]: close_tee() return [SQLResult(status="")] @@ -521,7 +561,12 @@ def write_tee(output: str | ANSI | FormattedText, nl: bool = True) -> None: tee_file.flush() -@special_command("\\once", "\\once [-o] ", "Append next result to an output file (overwrite using -o).", aliases=["\\o"]) +@special_command( + "\\once", + "\\once [-o] ", + "Append next result to an output file (overwrite using -o).", + aliases=[SpecialCommandAlias("\\o", case_sensitive=False)], +) def set_once(arg: str, **_) -> list[SQLResult]: global once_file, written_to_once_file @@ -574,7 +619,12 @@ def _run_post_redirect_hook(post_redirect_command: str, filename: str) -> None: raise OSError(f"Redirect post hook failed: {e}") from e -@special_command("\\pipe_once", "\\pipe_once ", "Send next result to a subprocess.", aliases=["\\|"]) +@special_command( + "\\pipe_once", + "\\pipe_once ", + "Send next result to a subprocess.", + aliases=[SpecialCommandAlias("\\|", case_sensitive=False)], +) def set_pipe_once(arg: str, **_) -> list[SQLResult]: if not arg: raise OSError("pipe_once requires a command") @@ -633,7 +683,11 @@ def flush_pipe_once_if_written(post_redirect_command: str) -> None: PIPE_ONCE['stdout_mode'] = None -@special_command("watch", "watch [seconds] [-c] ", "Execute query every [seconds] seconds (5 by default).") +@special_command( + "watch", + "watch [seconds] [-c] ", + "Execute query every [seconds] seconds (5 by default).", +) def watch_query(arg: str, **kwargs) -> Generator[SQLResult, None, None]: usage = """Syntax: watch [seconds] [-c] query. * seconds: The interval at the query will be repeated, in seconds. @@ -700,7 +754,11 @@ def watch_query(arg: str, **kwargs) -> Generator[SQLResult, None, None]: set_pager_enabled(old_pager_enabled) -@special_command("delimiter", "delimiter ", "Change end-of-statement delimiter.") +@special_command( + "delimiter", + "delimiter ", + "Change end-of-statement delimiter.", +) def set_delimiter(arg: str, **_) -> list[SQLResult]: return delimiter_command.set(arg) diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index 28782b75..1b03d1a6 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from dataclasses import dataclass from enum import Enum import logging import os @@ -25,20 +25,6 @@ CASE_SENSITIVE_COMMANDS = set() CASE_INSENSITIVE_COMMANDS = set() -SpecialCommand = namedtuple( - "SpecialCommand", - [ - "handler", - "command", - "usage", - "description", - "arg_type", - "hidden", - "case_sensitive", - "shortcut", - ], -) - class ArgType(Enum): NO_QUERY = 0 @@ -46,6 +32,24 @@ class ArgType(Enum): RAW_QUERY = 2 +@dataclass(frozen=True) +class SpecialCommandAlias: + command: str + case_sensitive: bool + + +@dataclass(frozen=True) +class SpecialCommand: + handler: Callable + command: str + usage: str + description: str + arg_type: ArgType + hidden: bool | None + case_sensitive: bool | None + aliases: list[SpecialCommandAlias] | None + + class CommandNotFound(Exception): pass @@ -69,12 +73,12 @@ def parse_special_command(sql: str) -> tuple[str, CommandVerbosity, str]: def special_command( command: str, - usage: str | None, + usage: str, description: str, arg_type: ArgType = ArgType.PARSED_QUERY, hidden: bool = False, case_sensitive: bool = False, - aliases: list[str] | None = None, + aliases: list[SpecialCommandAlias] | None = None, ) -> Callable: def wrapper(wrapped): register_special_command( @@ -95,12 +99,12 @@ def wrapper(wrapped): def register_special_command( handler: Callable, command: str, - usage: str | None, + usage: str, description: str, arg_type: ArgType = ArgType.PARSED_QUERY, hidden: bool = False, case_sensitive: bool = False, - aliases: list[str] | None = None, + aliases: list[SpecialCommandAlias] | None = None, ) -> None: cmd = command.lower() if not case_sensitive else command COMMANDS[cmd] = SpecialCommand( @@ -111,7 +115,7 @@ def register_special_command( arg_type=arg_type, hidden=hidden, case_sensitive=case_sensitive, - shortcut=aliases[0] if aliases else None, + aliases=aliases, ) if case_sensitive: CASE_SENSITIVE_COMMANDS.add(command) @@ -119,20 +123,20 @@ def register_special_command( CASE_INSENSITIVE_COMMANDS.add(command.lower()) aliases = [] if aliases is None else aliases for alias in aliases: - cmd = alias.lower() if not case_sensitive else alias - if case_sensitive: - CASE_SENSITIVE_COMMANDS.add(alias) + cmd = alias.command.lower() if not alias.case_sensitive else alias.command + if alias.case_sensitive: + CASE_SENSITIVE_COMMANDS.add(alias.command) else: - CASE_INSENSITIVE_COMMANDS.add(alias.lower()) + CASE_INSENSITIVE_COMMANDS.add(alias.command.lower()) COMMANDS[cmd] = SpecialCommand( handler, command, usage, description, arg_type=arg_type, - case_sensitive=case_sensitive, + case_sensitive=alias.case_sensitive, hidden=True, - shortcut=None, + aliases=None, ) @@ -167,20 +171,32 @@ def execute(cur: Cursor, sql: str) -> list[SQLResult]: raise CommandNotFound(f"Command type not found: {command}") -@special_command("help", "help [term]", "Show this table, or search for help on a term.", arg_type=ArgType.NO_QUERY, aliases=["\\?", "?"]) +@special_command( + "help", + "help [term]", + "Show this table, or search for help on a term.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\?", case_sensitive=False), SpecialCommandAlias("?", case_sensitive=False)], +) def show_help(*_args) -> list[SQLResult]: header = ["Command", "Shortcut", "Usage", "Description"] result = [] for _, value in sorted(COMMANDS.items()): - if not value.hidden: - result.append((value.command, value.shortcut, value.usage, value.description)) + if value.hidden: + continue + if value.aliases: + shortcut = value.aliases[0].command + else: + shortcut = None + result.append((value.command, shortcut, value.usage, value.description)) return [SQLResult(header=header, rows=result, postamble=f'Docs index — {DOCS_URL}')] def _show_special_help(keyword: str) -> list[SQLResult]: header = ['name', 'description', 'example'] - description = '\n'.join(COMMANDS[keyword][2:4]) + command = COMMANDS[keyword] + description = '\n'.join([command.usage or '', command.description]) rows = [(keyword, description, '')] return [SQLResult(header=header, rows=rows)] @@ -224,8 +240,20 @@ def file_bug(*_args) -> list[SQLResult]: return [SQLResult(status=f'{ISSUES_URL} — press "New Issue"')] -@special_command("exit", "exit", "Exit.", arg_type=ArgType.NO_QUERY, aliases=["\\q"]) -@special_command("quit", "quit", "Quit.", arg_type=ArgType.NO_QUERY, aliases=["\\q"]) +@special_command( + "exit", + "exit", + "Exit.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\q", case_sensitive=False)], +) +@special_command( + "quit", + "quit", + "Quit.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\q", case_sensitive=False)], +) def quit_(*_args): raise EOFError @@ -236,10 +264,22 @@ def quit_(*_args): "Edit query with editor (uses $VISUAL or $EDITOR).", arg_type=ArgType.NO_QUERY, case_sensitive=True, - aliases=['\\e'], + aliases=[SpecialCommandAlias("\\e", case_sensitive=True)], +) +@special_command( + "\\clip", + "\\clip", + "Copy query to the system clipboard.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, +) +@special_command( + "\\G", + "\\G", + "Display query results vertically.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, ) -@special_command("\\clip", "\\clip", "Copy query to the system clipboard.", arg_type=ArgType.NO_QUERY, case_sensitive=True) -@special_command("\\G", "\\G", "Display query results vertically.", arg_type=ArgType.NO_QUERY, case_sensitive=True) def stub(): raise NotImplementedError @@ -252,7 +292,7 @@ def stub(): "Interrogate an LLM. See \"\\llm help\".", arg_type=ArgType.RAW_QUERY, case_sensitive=True, - aliases=["\\ai"], + aliases=[SpecialCommandAlias("\\ai", case_sensitive=True)], ) def llm_stub(): raise NotImplementedError diff --git a/test/pytests/test_completion_engine.py b/test/pytests/test_completion_engine.py index b17b218b..e6b4bc89 100644 --- a/test/pytests/test_completion_engine.py +++ b/test/pytests/test_completion_engine.py @@ -1671,7 +1671,13 @@ def test_after_as(expression): ) def test_source_is_file(expression): # "source" has to be registered by hand because that usually happens inside MyCLI in mycli/main.py - special.register_special_command(..., 'source', '\\. ', 'Execute commands from file.', aliases=['\\.']) + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) suggestions = suggest_type(expression, expression) assert suggestions == [{"type": "file_name"}] diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index ccf8a858..0cec5752 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -717,7 +717,7 @@ def test_command_descriptions_end_with_periods(): """Make sure that mycli commands' descriptions end with a period.""" MyCli() for _, command in SPECIAL_COMMANDS.items(): - assert command[3].endswith(".") + assert command.description.endswith(".") def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): diff --git a/test/pytests/test_smart_completion_public_schema_only.py b/test/pytests/test_smart_completion_public_schema_only.py index 4b1b5a0d..72b64b87 100644 --- a/test/pytests/test_smart_completion_public_schema_only.py +++ b/test/pytests/test_smart_completion_public_schema_only.py @@ -80,7 +80,13 @@ def complete_event(): def test_use_database_completion(completer, complete_event): text = "USE " position = len(text) - special.register_special_command(..., 'use', '\\u [database]', 'Change to a new database.', aliases=['\\u']) + special.register_special_command( + ..., + 'use', + '\\u [database]', + 'Change to a new database.', + aliases=[special.SpecialCommandAlias('\\u', case_sensitive=False)], + ) result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == [ Completion(text="test", start_position=0), @@ -652,7 +658,13 @@ def dummy_list_path(dir_name): ) def test_file_name_completion(completer, complete_event, text, expected): position = len(text) - special.register_special_command(..., 'source', '\\. ', 'Execute commands from file.', aliases=['\\.']) + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) expected = [Completion(txt, pos) for txt, pos in expected] assert result == expected @@ -689,7 +701,13 @@ def test_source_eager_completion(completer, complete_event, tmp_path, monkeypatc script_filename = 'do_these_statements.sql' f = open(script_filename, 'w') f.close() - special.register_special_command(..., 'source', '\\. ', 'Execute commands from file.', aliases=['\\.']) + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) success = True error = 'unknown' @@ -715,7 +733,13 @@ def test_source_leading_dot_suggestions_completion(completer, complete_event, tm script_filename = 'do_these_statements.sql' f = open(script_filename, 'w') f.close() - special.register_special_command(..., 'source', '\\. ', 'Execute commands from file.', aliases=['\\.']) + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) success = True error = 'unknown' diff --git a/test/pytests/test_special_main.py b/test/pytests/test_special_main.py index 42fcf4b7..3c1b2e77 100644 --- a/test/pytests/test_special_main.py +++ b/test/pytests/test_special_main.py @@ -81,7 +81,7 @@ def handler() -> None: 'Demo', 'demo', 'Description', - aliases=['\\d'], + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], ) assert special_main.COMMANDS['demo'] == special_main.SpecialCommand( @@ -92,7 +92,7 @@ def handler() -> None: arg_type=special_main.ArgType.PARSED_QUERY, hidden=False, case_sensitive=False, - shortcut='\\d', + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], ) assert special_main.COMMANDS['\\d'] == special_main.SpecialCommand( handler, @@ -102,7 +102,7 @@ def handler() -> None: arg_type=special_main.ArgType.PARSED_QUERY, hidden=True, case_sensitive=False, - shortcut=None, + aliases=None, ) @@ -116,7 +116,7 @@ def test_register_special_command_tracks_case_insensitive_commands(restore_comma 'Demo', 'demo', 'Description', - aliases=['\\d'], + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], ) assert special_main.CASE_SENSITIVE_COMMANDS == set() @@ -159,7 +159,7 @@ def test_execute_raises_for_case_sensitive_alias_lookup(restore_commands: None) 'Demo', 'Description', case_sensitive=True, - aliases=['demo'], + aliases=[special_main.SpecialCommandAlias('demo', case_sensitive=True)], ) with pytest.raises(special_main.CommandNotFound, match='Command not found: DEMO'): @@ -178,7 +178,7 @@ def test_execute_raises_when_case_sensitive_exact_lookup_falls_back_to_lowercase arg_type=special_main.ArgType.NO_QUERY, hidden=False, case_sensitive=True, - shortcut=None, + aliases=None, ) special_main.CASE_SENSITIVE_COMMANDS.add('Camel') @@ -309,7 +309,7 @@ def test_execute_raises_for_unknown_arg_type(restore_commands: None) -> None: arg_type=cast(Any, object()), hidden=False, case_sensitive=False, - shortcut=None, + aliases=None, ) special_main.CASE_INSENSITIVE_COMMANDS.add('demo') @@ -319,7 +319,13 @@ def test_execute_raises_for_unknown_arg_type(restore_commands: None) -> None: def test_show_help_lists_only_visible_commands(restore_commands: None) -> None: special_main.COMMANDS.clear() - special_main.register_special_command(lambda: None, 'visible', 'visible', 'Visible command', aliases=['\\v']) + special_main.register_special_command( + lambda: None, + 'visible', + 'visible', + 'Visible command', + aliases=[special_main.SpecialCommandAlias('\\v', case_sensitive=False)], + ) special_main.register_special_command(lambda: None, 'hidden', 'hidden', 'Hidden command', hidden=True) result = special_main.show_help()[0]