From 5060c0811a7a73c4ad49a669d278b82b8886fa96 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 17 Mar 2026 00:25:52 -0600 Subject: [PATCH 01/10] perf: cache jedi project, batch test cache writes, fix Windows relative_to bug Port valuable improvements from #1846 that remain applicable after #1660: - Cache jedi.Project instances via @cache to avoid recreating across 5 call sites - Fix unguarded relative_to() in get_code_optimization_context (Windows 8.3 paths) - Pre-group references by parent function in get_function_sources_from_jedi for O(1) lookup - Batch TestsCache writes with flush() + executemany instead of per-row commit - Gracefully disable cache writes on sqlite3.OperationalError - Build functions_to_optimize_by_name dict for O(1) fallback lookup in process_test_files - Derive all_defs from all_names via is_definition() to save a redundant Jedi call --- codeflash/discovery/discover_unit_tests.py | 127 ++++++++++-------- .../python/context/code_context_extractor.py | 94 +++++++++---- 2 files changed, 136 insertions(+), 85 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index d9c3d4e3c..b3fec97cb 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -136,6 +136,8 @@ def __init__(self, project_root_path: Path) -> None: ) self.memory_cache = {} + self.pending_rows: list[tuple[str, str, str, str, str, str, int | str, int, int]] = [] + self.writes_enabled = True def insert_test( self, @@ -150,10 +152,8 @@ def insert_test( col_number: int, ) -> None: test_type_value = test_type.value if hasattr(test_type, "value") else test_type - self.cur.execute( - "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + self.pending_rows.append( ( - self.project_root_path, file_path, file_hash, qualified_name_with_modules_from_root, @@ -163,9 +163,26 @@ def insert_test( test_type_value, line_number, col_number, - ), + ) ) - 
self.connection.commit() + + def flush(self) -> None: + if not self.pending_rows: + return + if not self.writes_enabled: + self.pending_rows.clear() + return + try: + self.cur.executemany( + "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [(self.project_root_path, *row) for row in self.pending_rows], + ) + self.connection.commit() + except sqlite3.OperationalError as e: + logger.debug(f"Failed to persist discovered test cache, disabling cache writes: {e}") + self.writes_enabled = False + finally: + self.pending_rows.clear() def get_function_to_test_map_for_file( self, file_path: str, file_hash: str @@ -212,6 +229,7 @@ def compute_file_hash(path: Path) -> str: return h.hexdigest() def close(self) -> None: + self.flush() self.cur.close() self.connection.close() @@ -849,6 +867,10 @@ def process_test_files( function_to_test_map = defaultdict(set) num_discovered_tests = 0 num_discovered_replay_tests = 0 + functions_to_optimize_by_name: dict[str, list[FunctionToOptimize]] = defaultdict(list) + if functions_to_optimize: + for function_to_optimize in functions_to_optimize: + functions_to_optimize_by_name[function_to_optimize.function_name].append(function_to_optimize) # Set up sys_path for Jedi to resolve imports correctly import sys @@ -891,8 +913,8 @@ def process_test_files( test_functions = set() all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) all_names_top = script.get_names(all_scopes=True) + all_defs = [name for name in all_names if name.is_definition()] top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} @@ -967,10 +989,9 @@ def process_test_files( test_function_names_set = set(test_functions_by_name.keys()) relevant_names = [] - - names_with_full_name = [name for name in all_names if name.full_name is not None] - - for name in 
names_with_full_name: + for name in all_names: + if name.full_name is None: + continue match = FUNCTION_NAME_REGEX.search(name.full_name) if match and match.group(1) in test_function_names_set: relevant_names.append((name, match.group(1))) @@ -985,56 +1006,49 @@ def process_test_files( if not definition or definition[0].type != "function": # Fallback: Try to match against functions_to_optimize when Jedi can't resolve # This handles cases where Jedi fails with pytest fixtures - if functions_to_optimize and name.name: - for func_to_opt in functions_to_optimize: - # Check if this unresolved name matches a function we're looking for - if func_to_opt.function_name == name.name: - # Check if the test file imports the class/module containing this function - qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root( - project_root_path - ) + if functions_to_optimize_by_name and name.name: + for func_to_opt in functions_to_optimize_by_name.get(name.name, []): + qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root( + project_root_path + ) - # Only add if this test actually tests the function we're optimizing - for test_func in test_functions_by_name[scope]: - if test_func.parameters is not None: - if test_framework == "pytest": - scope_test_function = ( - f"{test_func.function_name}[{test_func.parameters}]" - ) - else: # unittest - scope_test_function = ( - f"{test_func.function_name}_{test_func.parameters}" - ) - else: - scope_test_function = test_func.function_name - - function_to_test_map[qualified_name_with_modules].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=test_func.test_class, - test_function=scope_test_function, - test_type=test_func.test_type, - ), - position=CodePosition(line_no=name.line, col_no=name.column), - ) - ) - tests_cache.insert_test( - file_path=str(test_file), - file_hash=file_hash, - qualified_name_with_modules_from_root=qualified_name_with_modules, - 
function_name=scope, - test_class=test_func.test_class or "", - test_function=scope_test_function, - test_type=test_func.test_type, - line_number=name.line, - col_number=name.column, + # Only add if this test actually tests the function we're optimizing + for test_func in test_functions_by_name[scope]: + if test_func.parameters is not None: + if test_framework == "pytest": + scope_test_function = f"{test_func.function_name}[{test_func.parameters}]" + else: # unittest + scope_test_function = f"{test_func.function_name}_{test_func.parameters}" + else: + scope_test_function = test_func.function_name + + function_to_test_map[qualified_name_with_modules].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=test_func.test_class, + test_function=scope_test_function, + test_type=test_func.test_type, + ), + position=CodePosition(line_no=name.line, col_no=name.column), ) + ) + tests_cache.insert_test( + file_path=str(test_file), + file_hash=file_hash, + qualified_name_with_modules_from_root=qualified_name_with_modules, + function_name=scope, + test_class=test_func.test_class or "", + test_function=scope_test_function, + test_type=test_func.test_type, + line_number=name.line, + col_number=name.column, + ) - if test_func.test_type == TestType.REPLAY_TEST: - num_discovered_replay_tests += 1 + if test_func.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 - num_discovered_tests += 1 + num_discovered_tests += 1 continue definition_obj = definition[0] definition_path = str(definition_obj.module_path) @@ -1090,6 +1104,7 @@ def process_test_files( logger.debug(str(e)) continue + tests_cache.flush() progress.advance(task_id) tests_cache.close() diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index af54a56fb..ad8bb331d 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ 
b/codeflash/languages/python/context/code_context_extractor.py @@ -5,6 +5,7 @@ import os from collections import defaultdict from dataclasses import dataclass, field +from functools import cache from itertools import chain from typing import TYPE_CHECKING @@ -47,6 +48,13 @@ from codeflash.languages.python.context.unused_definition_remover import UsageInfo +@cache +def get_jedi_project(project_root_path: str) -> object: + import jedi + + return jedi.Project(path=project_root_path) + + @dataclass class FileContextCache: original_module: cst.Module @@ -102,15 +110,21 @@ def get_code_optimization_context( testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, call_graph: DependencyResolver | None = None, ) -> CodeOptimizationContext: + jedi_project = get_jedi_project(str(project_root_path)) + # Get FunctionSource representation of helpers of FTO fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}} if call_graph is not None: helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input) else: - helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path) + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( + fto_input, project_root_path, jedi_project=jedi_project + ) # Add function to optimize into helpers of FTO dict, as they'll be processed together - fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) + fto_as_function_source = get_function_to_optimize_as_function_source( + function_to_optimize, project_root_path, jedi_project=jedi_project + ) helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source) # Format data to search for helpers of helpers using get_function_sources_from_jedi @@ -124,7 +138,7 @@ def get_code_optimization_context( qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." 
in qn}) helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi( - helpers_of_fto_qualified_names_dict, project_root_path + helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project ) # Extract all code contexts in a single pass (one CST parse per file) @@ -133,11 +147,14 @@ def get_code_optimization_context( final_read_writable_code = all_ctx.read_writable # Ensure the target file is first in the code blocks so the LLM knows which file to optimize - target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) - target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative] - other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative] - if target_blocks: - final_read_writable_code.code_strings = target_blocks + other_blocks + try: + target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) + target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative] + other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative] + if target_blocks: + final_read_writable_code.code_strings = target_blocks + other_blocks + except ValueError: + pass read_only_code_markdown = all_ctx.read_only @@ -434,13 +451,13 @@ def re_extract_from_cache( ) -> CodeStringsMarkdown: """Re-extract context from cached modules without file I/O or CST parsing.""" result = CodeStringsMarkdown() - for cache in file_caches: + for file_cache in file_caches: try: pruned = parse_code_and_prune_cst( - cache.cleaned_module, + file_cache.cleaned_module, code_context_type, - cache.fto_names, - cache.hoh_names, + file_cache.fto_names, + file_cache.hoh_names, remove_docstrings=remove_docstrings, ) except ValueError: @@ -450,24 +467,25 @@ def re_extract_from_cache( code = ast.unparse(ast.parse(pruned.code)) else: code = 
add_needed_imports_from_module( - src_module_code=cache.original_module, + src_module_code=file_cache.original_module, dst_module_code=pruned, - src_path=cache.file_path, - dst_path=cache.file_path, + src_path=file_cache.file_path, + dst_path=file_cache.file_path, project_root=project_root_path, - helper_functions=cache.helper_functions, + helper_functions=file_cache.helper_functions, ) - result.code_strings.append(CodeString(code=code, file_path=cache.relative_path)) + result.code_strings.append(CodeString(code=code, file_path=file_cache.relative_path)) return result def get_function_to_optimize_as_function_source( - function_to_optimize: FunctionToOptimize, project_root_path: Path + function_to_optimize: FunctionToOptimize, project_root_path: Path, *, jedi_project: object | None = None ) -> FunctionSource: import jedi # Use jedi to find function to optimize - script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) + project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path)) + script = jedi.Script(path=function_to_optimize.file_path, project=project) # Get all names in the file names = script.get_names(all_scopes=True, definitions=True, references=False) @@ -498,22 +516,40 @@ def get_function_to_optimize_as_function_source( def get_function_sources_from_jedi( - file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path + file_path_to_qualified_function_names: dict[Path, set[str]], + project_root_path: Path, + *, + jedi_project: object | None = None, ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: import jedi + project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path)) file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): - script = jedi.Script(path=file_path, 
project=jedi.Project(path=project_root_path)) + script = jedi.Script(path=file_path, project=project) file_refs = script.get_names(all_scopes=True, definitions=False, references=True) + # Pre-group references by their parent function's qualified name for O(1) lookup + refs_by_parent: dict[str, list[Name]] = defaultdict(list) + for ref in file_refs: + if not ref.full_name: + continue + try: + parent = ref.parent() + if parent is None or parent.type != "function": + continue + parent_qn = get_qualified_name(parent.module_name, parent.full_name) + # Exclude self-references (recursive calls) — the ref's own qualified name matches the parent + ref_qn = get_qualified_name(ref.module_name, ref.full_name) + if ref_qn == parent_qn: + continue + refs_by_parent[parent_qn].append(ref) + except (AttributeError, ValueError): + continue + for qualified_function_name in qualified_function_names: - names = [ - ref - for ref in file_refs - if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name) - ] + names = refs_by_parent.get(qualified_function_name, []) for name in names: try: definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) @@ -1103,7 +1139,7 @@ def _resolve_imported_class_reference( module_name, class_name = resolved_name.rsplit(".", 1) try: script_code = f"from {module_name} import {class_name}" - script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) + script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path))) definitions = script.goto(1, len(f"from {module_name} import ") + len(class_name), follow_imports=True) except Exception: return None @@ -1263,7 +1299,7 @@ def extract_parameter_type_constructors( def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None: try: script_code = f"from {module_name} import {type_name}" - script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) + script = 
jedi.Script(script_code, project=get_jedi_project(str(project_root_path))) definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True) if not definitions: return @@ -1429,7 +1465,7 @@ def extract_class_and_bases( continue try: test_code = f"import {module_name}" - script = jedi.Script(test_code, project=jedi.Project(path=project_root_path)) + script = jedi.Script(test_code, project=get_jedi_project(str(project_root_path))) completions = script.goto(1, len(test_code)) if not completions: From 9f20cdf33f6c5140ff13aeff3b3727eb0ef0b4c1 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:33:18 +0000 Subject: [PATCH 02/10] fix: correct pending_rows type annotation to int | TestType --- codeflash/discovery/discover_unit_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index b3fec97cb..ba3371a90 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -136,7 +136,7 @@ def __init__(self, project_root_path: Path) -> None: ) self.memory_cache = {} - self.pending_rows: list[tuple[str, str, str, str, str, str, int | str, int, int]] = [] + self.pending_rows: list[tuple[str, str, str, str, str, str, int | TestType, int, int]] = [] self.writes_enabled = True def insert_test( From 5b52fba96a55aa99a7602e3013b46d1ce3f2801a Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:54:22 +0000 Subject: [PATCH 03/10] Optimize re_extract_from_cache The optimization added early-exit logic to `add_needed_imports_from_module` that checks whether the source module contains any module-level imports before invoking the heavyweight `GatherImportsVisitor` and downstream import-merging machinery. 
In the common case where a pruned module has no imports (or only nested ones inside functions), line profiling showed the gatherer and two `AddImportsVisitor`/`RemoveImportsVisitor` transforms consumed 36% of original runtime; the early exit skips all three, falling back to the destination code immediately. A second early exit after gathering verifies the visitor actually collected imports, avoiding redundant CST transformations when the source is import-free. Combined, these checks eliminate ~99% of the work when imports are absent, yielding an 18222% speedup with no semantic change because the fallback path always returned the correct destination code. --- .../python/context/code_context_extractor.py | 19 +++++++ .../python/static_analysis/code_extractor.py | 54 ++++++++++++++----- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index ad8bb331d..d2b63cfda 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1770,3 +1770,22 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return False except ValueError: return False + + +def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: + """Strip docstring from function or class if configured to do so.""" + if not cfg.remove_docstrings or not isinstance(node.body, cst.IndentedBlock): + return node + + body_stmts = node.body.body + if not body_stmts: + return node + + first_stmt = body_stmts[0] + if isinstance(first_stmt, cst.SimpleStatementLine) and len(first_stmt.body) == 1: + expr_stmt = first_stmt.body[0] + if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, cst.SimpleString | cst.ConcatenatedString): + new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])] + 
return node.with_changes(body=node.body.with_changes(body=new_body)) + + return node diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 1b315d629..80a274770 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -553,7 +553,13 @@ def add_needed_imports_from_module( if not helper_functions_fqn: helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} - dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code + # Cache the fallback early to avoid repeated isinstance checks + if isinstance(dst_module_code, str): + dst_code_fallback = dst_module_code + parsed_dst_module = None + else: + dst_code_fallback = dst_module_code.code + parsed_dst_module = dst_module_code src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) @@ -563,18 +569,32 @@ def add_needed_imports_from_module( full_module_name=dst_module_and_package.name, full_package_name=dst_module_and_package.package, ) - gatherer: GatherImportsVisitor = GatherImportsVisitor( - CodemodContext( - filename=src_path.name, - full_module_name=src_module_and_package.name, - full_package_name=src_module_and_package.package, - ) - ) try: if isinstance(src_module_code, cst.Module): src_module = src_module_code.visit(FutureAliasedImportTransformer()) else: src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer()) + + # Early exit: check if source has any imports at module level + has_module_level_imports = any( + isinstance(stmt, (cst.Import, cst.ImportFrom)) + for stmt in src_module.body + if isinstance(stmt, cst.SimpleStatementLine) + for s in stmt.body + if isinstance(s, (cst.Import, cst.ImportFrom)) + ) + + if 
not has_module_level_imports: + return dst_code_fallback + + gatherer: GatherImportsVisitor = GatherImportsVisitor( + CodemodContext( + filename=src_path.name, + full_module_name=src_module_and_package.name, + full_package_name=src_module_and_package.package, + ) + ) + # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. # Nested imports (inside functions) are part of function logic and must not be # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". @@ -582,22 +602,30 @@ def add_needed_imports_from_module( body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))] ) module_level_only.visit(gatherer) + + # Early exit: if no imports were gathered, return destination as-is + if ( + not gatherer.module_imports + and not gatherer.object_mapping + and not gatherer.module_aliases + and not gatherer.alias_mapping + ): + return dst_code_fallback + except Exception as e: logger.error(f"Error parsing source module code: {e}") return dst_code_fallback dotted_import_collector = DottedImportCollector() - if isinstance(dst_module_code, cst.Module): - parsed_dst_module = dst_module_code - parsed_dst_module.visit(dotted_import_collector) - else: + if parsed_dst_module is None: try: parsed_dst_module = cst.parse_module(dst_module_code) - parsed_dst_module.visit(dotted_import_collector) except cst.ParserSyntaxError as e: logger.exception(f"Syntax error in destination module code: {e}") return dst_code_fallback + parsed_dst_module.visit(dotted_import_collector) + try: for mod in gatherer.module_imports: # Skip __future__ imports as they cannot be imported directly From dc8548f4098f242a3474f5e6d859e2ccc87e4c43 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:14:32 +0000 Subject: [PATCH 04/10] fix: remove duplicate _maybe_strip_docstring definition causing mypy no-redef error Co-authored-by: Kevin Turcios --- 
.../python/context/code_context_extractor.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index d2b63cfda..f30251dc6 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1583,20 +1583,6 @@ def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: names.add(node.value.id) -def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: - if not isinstance(indented_block.body[0], cst.SimpleStatementLine): - return indented_block - first_stmt = indented_block.body[0].body[0] - if isinstance(first_stmt, cst.Expr) and isinstance(first_stmt.value, cst.SimpleString): - return indented_block.with_changes(body=indented_block.body[1:]) - return indented_block - - -def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: - if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock): - return node.with_changes(body=remove_docstring_from_body(node.body)) - return node - class ImportCollector(ast.NodeVisitor): def __init__(self) -> None: From 7e675e46bce85763297fd0ac3be24021bbb4fabf Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 17 Mar 2026 02:18:56 -0600 Subject: [PATCH 05/10] fix: remove extra blank line to pass ruff format --- codeflash/languages/python/context/code_context_extractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index f30251dc6..0a7152a0e 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1583,7 +1583,6 @@ def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: 
names.add(node.value.id) - class ImportCollector(ast.NodeVisitor): def __init__(self) -> None: self.imported_names: dict[str, str] = {} From 96d20d376cd0b8e738591a24ba67f09b5704c725 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:28:54 +0000 Subject: [PATCH 06/10] fix: correct has_module_level_imports early-exit check in add_needed_imports_from_module The generator expression yielded `isinstance(stmt, ...)` where `stmt` was already filtered to be `cst.SimpleStatementLine`, so it always returned False. This caused add_needed_imports_from_module to always skip adding imports, breaking 11 tests in test_code_context_extractor.py. Co-authored-by: Kevin Turcios --- codeflash/languages/python/static_analysis/code_extractor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 80a274770..cb269f69f 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -577,11 +577,10 @@ def add_needed_imports_from_module( # Early exit: check if source has any imports at module level has_module_level_imports = any( - isinstance(stmt, (cst.Import, cst.ImportFrom)) + isinstance(s, (cst.Import, cst.ImportFrom)) for stmt in src_module.body if isinstance(stmt, cst.SimpleStatementLine) for s in stmt.body - if isinstance(s, (cst.Import, cst.ImportFrom)) ) if not has_module_level_imports: From fbd0007f13166b342c4bd4ff0d43899fd9e0762f Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 17 Mar 2026 03:18:04 -0600 Subject: [PATCH 07/10] fix: resolve Python 3.9 isinstance union type and extra blank lines in flat output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace `cst.SimpleString | cst.ConcatenatedString` with `(cst.SimpleString, cst.ConcatenatedString)` in _maybe_strip_docstring — the | syntax for runtime isinstance is Python
3.10+ only and caused TypeError on 3.9 - Strip leading newlines from cst.Module.code in add_needed_imports_from_module fallback path so early-exit returns are consistent with the normal transformed_module.code.lstrip("\n") path, fixing extra blank lines after file headers in read_writable_code.flat output Co-Authored-By: Oz --- codeflash/languages/python/context/code_context_extractor.py | 2 +- codeflash/languages/python/static_analysis/code_extractor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 0a7152a0e..a13965ed0 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1769,7 +1769,7 @@ def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfi first_stmt = body_stmts[0] if isinstance(first_stmt, cst.SimpleStatementLine) and len(first_stmt.body) == 1: expr_stmt = first_stmt.body[0] - if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, cst.SimpleString | cst.ConcatenatedString): + if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, (cst.SimpleString, cst.ConcatenatedString)): new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])] return node.with_changes(body=node.body.with_changes(body=new_body)) diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index cb269f69f..c71b54d00 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -558,7 +558,7 @@ def add_needed_imports_from_module( dst_code_fallback = dst_module_code parsed_dst_module = None else: - dst_code_fallback = dst_module_code.code + dst_code_fallback = dst_module_code.code.lstrip("\n") parsed_dst_module = dst_module_code 
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) From ba7feb079575558e04194a61c9b3424ccfbd8427 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:38:11 +0000 Subject: [PATCH 08/10] fix: restore fallback behavior and fix type narrowing in add_needed_imports_from_module - Remove silent .lstrip("\n") on fallback for cst.Module input (restores original behavior) - Replace `parsed_dst_module is None` with `isinstance(dst_module_code, str)` for correct mypy narrowing - Change get_jedi_project return type from object to Any for accuracy Co-authored-by: Kevin Turcios Co-Authored-By: Claude Sonnet 4.6 --- codeflash/languages/python/context/code_context_extractor.py | 4 ++-- codeflash/languages/python/static_analysis/code_extractor.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index a13965ed0..bfbf02fc4 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from functools import cache from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import libcst as cst @@ -49,7 +49,7 @@ @cache -def get_jedi_project(project_root_path: str) -> object: +def get_jedi_project(project_root_path: str) -> Any: import jedi return jedi.Project(path=project_root_path) diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index c71b54d00..2570c1b90 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -558,7 +558,7 @@ def add_needed_imports_from_module( dst_code_fallback = 
dst_module_code parsed_dst_module = None else: - dst_code_fallback = dst_module_code.code.lstrip("\n") + dst_code_fallback = dst_module_code.code parsed_dst_module = dst_module_code src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) @@ -616,7 +616,7 @@ def add_needed_imports_from_module( return dst_code_fallback dotted_import_collector = DottedImportCollector() - if parsed_dst_module is None: + if isinstance(dst_module_code, str): try: parsed_dst_module = cst.parse_module(dst_module_code) except cst.ParserSyntaxError as e: From bf9adf2673029058746745e659d2f4f86523366d Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 17 Mar 2026 18:16:54 -0600 Subject: [PATCH 09/10] fix: normalize module fallback formatting for import merge --- .../python/static_analysis/code_extractor.py | 3 ++- tests/test_add_needed_imports_from_module.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 2570c1b90..6bd81bb74 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -558,7 +558,8 @@ def add_needed_imports_from_module( dst_code_fallback = dst_module_code parsed_dst_module = None else: - dst_code_fallback = dst_module_code.code + # Keep Module-input fallback formatting aligned with transformed_module.code.lstrip("\n"). 
+ dst_code_fallback = dst_module_code.code.lstrip("\n") parsed_dst_module = dst_module_code src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 345b966dc..198058b28 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -527,3 +527,23 @@ def parse(): return PATTERN.findall("") """ assert result == expected + + +def test_module_input_fallback_strips_leading_newlines() -> None: + src_code = """ +def parse(): + return helper() + +def helper(): + return 1 +""" + parsed_module = cst.parse_module(src_code) + + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + file_path = project_root / "mod.py" + file_path.write_text(src_code) + + result = add_needed_imports_from_module(src_code, parsed_module, file_path, file_path, project_root) + + assert result == src_code.lstrip("\n") From 7cf183e78777e43fac9a2dbf93bbfdd8aa6e4fea Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:21:47 +0000 Subject: [PATCH 10/10] fix: add else branch for parsed_dst_module to resolve mypy None narrowing The early `parsed_dst_module = None` assignment broke mypy's type narrowing. Add explicit `else: parsed_dst_module = dst_module_code` so mypy sees the variable is always a `cst.Module` before the `.visit()` call. 
Co-authored-by: Kevin Turcios --- codeflash/languages/python/static_analysis/code_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 6bd81bb74..ee37d2146 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -556,11 +556,9 @@ def add_needed_imports_from_module( # Cache the fallback early to avoid repeated isinstance checks if isinstance(dst_module_code, str): dst_code_fallback = dst_module_code - parsed_dst_module = None else: # Keep Module-input fallback formatting aligned with transformed_module.code.lstrip("\n"). dst_code_fallback = dst_module_code.code.lstrip("\n") - parsed_dst_module = dst_module_code src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) @@ -623,6 +621,8 @@ def add_needed_imports_from_module( except cst.ParserSyntaxError as e: logger.exception(f"Syntax error in destination module code: {e}") return dst_code_fallback + else: + parsed_dst_module = dst_module_code parsed_dst_module.visit(dotted_import_collector)