diff --git a/codeflash/cli_cmds/init_java.py b/codeflash/cli_cmds/init_java.py index 735e60e97..c73d39001 100644 --- a/codeflash/cli_cmds/init_java.py +++ b/codeflash/cli_cmds/init_java.py @@ -392,40 +392,37 @@ def _prompt_custom_directory(dir_type: str) -> str: def _get_git_remote_for_setup() -> str: """Get git remote for project setup.""" - try: - repo = Repo(Path.cwd(), search_parent_directories=True) - git_remotes = get_git_remotes(repo) - if not git_remotes: - return "" - - if len(git_remotes) == 1: - return git_remotes[0] - - git_panel = Panel( - Text( - "Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.", - style="blue", - ), - title="Git Remote Setup", - border_style="bright_blue", - ) - console.print(git_panel) - console.print() + cwd = Path.cwd().as_posix() + git_remotes = _cached_git_remotes_for_cwd(cwd) + if not git_remotes: + return "" - git_questions = [ - inquirer.List( - "git_remote", - message="Which git remote should Codeflash use?", - choices=git_remotes, - default="origin", - carousel=True, - ) - ] + if len(git_remotes) == 1: + return git_remotes[0] - git_answers = inquirer.prompt(git_questions, theme=_get_theme()) - return git_answers["git_remote"] if git_answers else git_remotes[0] - except InvalidGitRepositoryError: - return "" + git_panel = Panel( + Text( + "Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.", + style="blue", + ), + title="Git Remote Setup", + border_style="bright_blue", + ) + console.print(git_panel) + console.print() + + git_questions = [ + inquirer.List( + "git_remote", + message="Which git remote should Codeflash use?", + choices=git_remotes, + default="origin", + carousel=True, + ) + ] + + git_answers = inquirer.prompt(git_questions, theme=_get_theme()) + return git_answers["git_remote"] if git_answers else git_remotes[0] def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[str]: @@ -547,6 +544,22 @@ def get_java_test_command(build_tool: JavaBuildTool) -> str: return "mvn test" +@lru_cache(maxsize=32) +def _cached_repo_for_cwd(cwd: str) -> Repo | None: + try: + return Repo(Path(cwd), search_parent_directories=True) + except InvalidGitRepositoryError: + return None + + +@lru_cache(maxsize=32) +def _cached_git_remotes_for_cwd(cwd: str) -> list[str]: + repo = _cached_repo_for_cwd(cwd) + if not repo: + return [] + return get_git_remotes(repo) + + formatter_warning_shown = False _SPOTLESS_COMMANDS = {