Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 104 additions & 95 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,159 +374,168 @@ def _handle_reset_config(confirm: bool = True) -> None:

@lru_cache(maxsize=1)
def _build_parser() -> ArgumentParser:
    """Build the top-level Codeflash argument parser (memoized).

    The optimization flags (``--file``, ``--verbose``, ``--no-pr``, ...) are
    accepted both before a sub-command and after ``optimize``, so
    ``codeflash --verbose`` and ``codeflash optimize --verbose`` behave the same.

    Returns:
        The configured parser with the ``init``, ``vscode-install``,
        ``init-actions``, ``optimize``, ``auth`` and ``compare`` sub-commands.
    """
    parser = ArgumentParser()
    _add_shared_flags(parser)

    subparsers = parser.add_subparsers(dest="command", help="Sub-commands")

    subparsers.add_parser("init", help="Initialize Codeflash for your project.")
    subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
    subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")

    trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False)
    # argparse parses a sub-command into a fresh namespace and then copies every
    # attribute of it over the main namespace. If the sub-parser carried real
    # defaults for the shared flags, "codeflash --verbose optimize ..." would be
    # reset to verbose=False. Suppressed defaults mean only flags actually typed
    # after "optimize" are copied back.
    _add_shared_flags(trace_optimize, suppress_defaults=True)

    auth_parser = subparsers.add_parser("auth", help="Authentication commands")
    auth_subparsers = auth_parser.add_subparsers(dest="auth_command", help="Auth sub-commands")
    auth_subparsers.add_parser("login", help="Log in to Codeflash via OAuth")
    auth_subparsers.add_parser("status", help="Check authentication status")

    compare_parser = subparsers.add_parser("compare", help="Compare benchmark performance between two git refs.")
    compare_parser.add_argument(
        "base_ref", nargs="?", default=None, help="Base git ref (default: auto-detect from PR or default branch)"
    )
    compare_parser.add_argument("head_ref", nargs="?", default=None, help="Head git ref (default: current branch)")
    compare_parser.add_argument("--pr", type=int, help="Resolve head ref from a PR number (requires gh CLI)")
    compare_parser.add_argument(
        "--functions", type=str, help="Explicit functions to instrument: 'file.py::func1,func2;other.py::func3'"
    )
    compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
    compare_parser.add_argument("--output", "-o", type=str, help="Write markdown report to file")
    compare_parser.add_argument(
        "--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
    )
    compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree")
    compare_parser.add_argument(
        "--script-output",
        type=str,
        dest="script_output",
        help="Relative path to JSON results file produced by --script (required with --script)",
    )
    compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
    compare_parser.add_argument(
        "--inject",
        nargs="+",
        default=None,
        help="Files or directories to copy into both worktrees before benchmarking. Paths are relative to repo root.",
    )

    # Flags that only exist on the "optimize" sub-command.
    trace_optimize.add_argument(
        "--max-function-count",
        type=int,
        default=256,
        help="The maximum number of times to trace a single function. Default is 256.",
    )
    trace_optimize.add_argument(
        "--timeout", type=int, help="The maximum time in seconds to trace the entire workflow. Default is indefinite."
    )
    trace_optimize.add_argument(
        "--output",
        type=str,
        default="codeflash.trace",
        help="The file to save the trace to. Default is codeflash.trace.",
    )
    trace_optimize.add_argument(
        "--config-file-path",
        type=str,
        help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
    )

    return parser


def _add_shared_flags(target: ArgumentParser, *, suppress_defaults: bool = False) -> None:
    """Register the flags shared by the main parser and the ``optimize`` sub-parser.

    Args:
        target: Parser that receives the flags.
        suppress_defaults: When true, every flag is registered with
            ``default=SUPPRESS`` so that an omitted flag leaves no attribute on the
            parsed namespace. Use this for sub-parsers, whose results are merged
            over the main parser's namespace.
    """

    def add_flag(*names: str, **options) -> None:
        # Override (not merely fill in) the default: flags declared with an
        # explicit default=False must be suppressed as well.
        if suppress_defaults:
            options["default"] = SUPPRESS
        target.add_argument(*names, **options)

    add_flag("--file", help="Try to optimize only this file")
    add_flag("--function", help="Try to optimize only this function within the given file path")
    add_flag(
        "--all",
        help="Try to optimize all functions. Can take a really long time. Can pass an optional starting directory to"
        " optimize code from. If no args specified (just --all), will optimize all code in the project.",
        nargs="?",
        const="",
        default=SUPPRESS,
    )
    add_flag(
        "--module-root",
        type=str,
        help="Path to the project's module that you want to optimize."
        " This is the top-level root directory where all the source code is located.",
    )
    add_flag(
        "--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
    )
    add_flag("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
    add_flag("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
    add_flag(
        "--rerun",
        type=str,
        help="Rerun a previous optimization by trace ID, using stored LLM results",
        metavar="TRACE_ID",
    )
    add_flag(
        "--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
    )
    add_flag(
        "--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
    )
    add_flag(
        "--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
    )
    add_flag("--staging-review", action="store_true", help="Upload optimizations to staging for review")
    add_flag(
        "--verify-setup",
        action="store_true",
        help="Verify that codeflash is set up correctly by optimizing bubble sort as a test.",
    )
    add_flag("-v", "--verbose", action="store_true", help="Print verbose debug logs")
    add_flag("--version", action="store_true", help="Print the version of codeflash")
    add_flag(
        "--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks"
    )
    add_flag(
        "--benchmarks-root",
        type=str,
        help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
    )
    add_flag("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
    add_flag("--worktree", default=False, action="store_true", help="Use worktree for optimization")
    add_flag(
        "--testgen-review", default=False, action="store_true", help="Enable AI review and repair of generated tests"
    )
    add_flag("--testgen-review-turns", type=int, default=None, help="Number of review/repair cycles (default: 2)")
    add_flag(
        "--async",
        default=False,
        action="store_true",
        help="(Deprecated) Async function optimization is now enabled by default. This flag is ignored.",
    )
    add_flag(
        "--server",
        type=str,
        choices=["local", "prod"],
        help="AI service server to use: 'local' for localhost:8000, 'prod' for app.codeflash.ai",
    )
    add_flag(
        "--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
    )

    # Config management flags
    add_flag("--show-config", action="store_true", help="Show current or auto-detected configuration and exit.")
    add_flag("--reset-config", action="store_true", help="Remove codeflash configuration from project config file.")
    add_flag("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
    add_flag(
        "--subagent",
        action="store_true",
        help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
    )
    add_flag("--trace-only", action="store_true", help="Only trace, do not optimize")
95 changes: 95 additions & 0 deletions tests/test_cli_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests for CLI argument parsing, specifically the optimize subparser flag isolation."""

import sys
from unittest import mock

import pytest


@pytest.fixture(autouse=True)
def _clear_parser_cache():
    """Empty the ``_build_parser`` cache before and after every test.

    ``_build_parser`` is wrapped in ``lru_cache``; clearing it on both sides of
    the test means each test builds its own parser.
    """
    from codeflash.cli_cmds.cli import _build_parser as cached_builder

    reset_cache = cached_builder.cache_clear
    reset_cache()
    yield
    reset_cache()


class TestOptimizeSubparserFlags:
    """Flags declared for the main parser must also be understood after ``optimize``."""

    def _parse(self, argv: list[str]) -> tuple:
        """Parse ``argv`` (program name included) and return ``(namespace, unknown_args)``."""
        from codeflash.cli_cmds.cli import _build_parser

        # parse_known_args() is called without arguments, so it reads sys.argv;
        # swap sys.argv for the duration of the build and the parse.
        with mock.patch.object(sys, "argv", argv):
            return _build_parser().parse_known_args()

    def test_verbose_flag_recognized_by_optimize_subparser(self) -> None:
        """``--verbose`` after ``optimize`` sets the flag and is consumed."""
        argv = ["codeflash", "optimize", "--verbose", "mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.verbose is True, f"--verbose should be True, got {args.verbose}"
        assert "--verbose" not in unknown, f"--verbose leaked into unknown_args: {unknown}"

    def test_no_pr_flag_recognized_by_optimize_subparser(self) -> None:
        """``--no-pr`` after ``optimize`` sets the flag and is consumed."""
        argv = ["codeflash", "optimize", "--no-pr", "mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.no_pr is True, f"--no-pr should be True, got {args.no_pr}"
        assert "--no-pr" not in unknown, f"--no-pr leaked into unknown_args: {unknown}"

    def test_file_flag_recognized_by_optimize_subparser(self) -> None:
        """``--file`` and its value are both consumed by the ``optimize`` sub-parser."""
        argv = ["codeflash", "optimize", "--file", "Foo.java", "mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.file == "Foo.java", f"--file should be 'Foo.java', got {args.file}"
        assert "--file" not in unknown, f"--file leaked into unknown_args: {unknown}"
        assert "Foo.java" not in unknown, f"file value leaked into unknown_args: {unknown}"

    def test_function_flag_recognized_by_optimize_subparser(self) -> None:
        """``--function`` after ``optimize`` stores its value and is consumed."""
        argv = ["codeflash", "optimize", "--function", "bar", "mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.function == "bar", f"--function should be 'bar', got {args.function}"
        assert "--function" not in unknown, f"--function leaked into unknown_args: {unknown}"

    def test_multiple_flags_recognized_by_optimize_subparser(self) -> None:
        """Several shared flags combined leave only the wrapped command unparsed."""
        argv = ["codeflash", "optimize", "--verbose", "--no-pr"]
        argv += ["--file", "X.java", "--function", "foo"]
        argv += ["mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.verbose is True
        assert args.no_pr is True
        assert args.file == "X.java"
        assert args.function == "foo"
        # Only the Java command should remain in unknown
        assert unknown == ["mvn", "test"], f"Expected ['mvn', 'test'], got {unknown}"

    def test_flags_after_java_command_not_leaked(self) -> None:
        """A bare command after ``optimize`` selects the sub-command and is passed through as-is."""
        argv = ["codeflash", "optimize", "mvn", "test"]
        args, unknown = self._parse(argv)
        assert args.command == "optimize"
        assert unknown == ["mvn", "test"]

    def test_max_function_count_default_consistency(self) -> None:
        """``--max-function-count`` defaults to 256 when not given."""
        argv = ["codeflash", "optimize", "mvn", "test"]
        args, _ = self._parse(argv)
        assert args.max_function_count == 256, (
            f"max_function_count default should be 256 (matching tracer), got {args.max_function_count}"
        )


class TestMainParserFlags:
    """Shared flags keep working when no sub-command is given (non-optimize path)."""

    def _parse(self, argv: list[str]) -> tuple:
        """Parse ``argv`` (program name included) and return ``(namespace, unknown_args)``."""
        from codeflash.cli_cmds.cli import _build_parser

        # parse_known_args() is called without arguments, so it reads sys.argv;
        # swap sys.argv for the duration of the build and the parse.
        with mock.patch.object(sys, "argv", argv):
            return _build_parser().parse_known_args()

    def test_verbose_on_main_parser(self) -> None:
        """``codeflash --verbose`` turns verbose logging on."""
        argv = ["codeflash", "--verbose"]
        args, _ = self._parse(argv)
        assert args.verbose is True

    def test_file_on_main_parser(self) -> None:
        """``codeflash --file <path>`` stores the given path."""
        argv = ["codeflash", "--file", "test.java"]
        args, _ = self._parse(argv)
        assert args.file == "test.java"

    def test_no_pr_on_main_parser(self) -> None:
        """``codeflash --no-pr`` sets the no-PR flag."""
        argv = ["codeflash", "--no-pr"]
        args, _ = self._parse(argv)
        assert args.no_pr is True
Loading