diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index d9c3d4e3c..ba3371a90 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -136,6 +136,8 @@ def __init__(self, project_root_path: Path) -> None: ) self.memory_cache = {} + self.pending_rows: list[tuple[str, str, str, str, str, str, int | TestType, int, int]] = [] + self.writes_enabled = True def insert_test( self, @@ -150,10 +152,8 @@ def insert_test( col_number: int, ) -> None: test_type_value = test_type.value if hasattr(test_type, "value") else test_type - self.cur.execute( - "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + self.pending_rows.append( ( - self.project_root_path, file_path, file_hash, qualified_name_with_modules_from_root, @@ -163,9 +163,26 @@ def insert_test( test_type_value, line_number, col_number, - ), + ) ) - self.connection.commit() + + def flush(self) -> None: + if not self.pending_rows: + return + if not self.writes_enabled: + self.pending_rows.clear() + return + try: + self.cur.executemany( + "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [(self.project_root_path, *row) for row in self.pending_rows], + ) + self.connection.commit() + except sqlite3.OperationalError as e: + logger.debug(f"Failed to persist discovered test cache, disabling cache writes: {e}") + self.writes_enabled = False + finally: + self.pending_rows.clear() def get_function_to_test_map_for_file( self, file_path: str, file_hash: str @@ -212,6 +229,7 @@ def compute_file_hash(path: Path) -> str: return h.hexdigest() def close(self) -> None: + self.flush() self.cur.close() self.connection.close() @@ -849,6 +867,10 @@ def process_test_files( function_to_test_map = defaultdict(set) num_discovered_tests = 0 num_discovered_replay_tests = 0 + functions_to_optimize_by_name: dict[str, list[FunctionToOptimize]] = defaultdict(list) + if functions_to_optimize: + for function_to_optimize in functions_to_optimize: + functions_to_optimize_by_name[function_to_optimize.function_name].append(function_to_optimize) # Set up sys_path for Jedi to resolve imports correctly import sys @@ -891,8 +913,8 @@ def process_test_files( test_functions = set() all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) all_names_top = script.get_names(all_scopes=True) + all_defs = [name for name in all_names if name.is_definition()] top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} @@ -967,10 +989,9 @@ def process_test_files( test_function_names_set = set(test_functions_by_name.keys()) relevant_names = [] - - names_with_full_name = [name for name in all_names if name.full_name is not None] - - for name in names_with_full_name: + for name in all_names: + if name.full_name is None: + continue match = FUNCTION_NAME_REGEX.search(name.full_name) if match and match.group(1) in test_function_names_set: relevant_names.append((name, match.group(1))) @@ -985,56 +1006,49 @@ def process_test_files( if not definition or definition[0].type != "function": # Fallback: Try to match against functions_to_optimize when Jedi can't resolve # This handles cases where Jedi fails with pytest fixtures - if functions_to_optimize and name.name: - for func_to_opt in functions_to_optimize: - # Check if this unresolved name matches a function we're looking for - if func_to_opt.function_name == name.name: - # Check if the test file imports the class/module containing this function - qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root( - project_root_path - ) + if functions_to_optimize_by_name and name.name: + for func_to_opt in functions_to_optimize_by_name.get(name.name, []): + qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root( + project_root_path + ) - # Only add if this test actually tests the function we're optimizing - for test_func in test_functions_by_name[scope]: - if test_func.parameters is not None: - if test_framework == "pytest": - scope_test_function = ( - f"{test_func.function_name}[{test_func.parameters}]" - ) - else: # unittest - scope_test_function = ( - f"{test_func.function_name}_{test_func.parameters}" - ) - else: - scope_test_function = test_func.function_name - - function_to_test_map[qualified_name_with_modules].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=test_func.test_class, - test_function=scope_test_function, - test_type=test_func.test_type, - ), - position=CodePosition(line_no=name.line, col_no=name.column), - ) - ) - tests_cache.insert_test( - file_path=str(test_file), - file_hash=file_hash, - qualified_name_with_modules_from_root=qualified_name_with_modules, - function_name=scope, - test_class=test_func.test_class or "", - test_function=scope_test_function, - test_type=test_func.test_type, - line_number=name.line, - col_number=name.column, + # Only add if this test actually tests the function we're optimizing + for test_func in test_functions_by_name[scope]: + if test_func.parameters is not None: + if test_framework == "pytest": + scope_test_function = f"{test_func.function_name}[{test_func.parameters}]" + else: # unittest + scope_test_function = f"{test_func.function_name}_{test_func.parameters}" + else: + scope_test_function = test_func.function_name + + function_to_test_map[qualified_name_with_modules].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=test_func.test_class, + test_function=scope_test_function, + test_type=test_func.test_type, + ), + position=CodePosition(line_no=name.line, col_no=name.column), ) + ) + tests_cache.insert_test( + file_path=str(test_file), + file_hash=file_hash, + qualified_name_with_modules_from_root=qualified_name_with_modules, + function_name=scope, + test_class=test_func.test_class or "", + test_function=scope_test_function, + test_type=test_func.test_type, + line_number=name.line, + col_number=name.column, + ) - if test_func.test_type == TestType.REPLAY_TEST: - num_discovered_replay_tests += 1 + if test_func.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 - num_discovered_tests += 1 + num_discovered_tests += 1 continue definition_obj = definition[0] definition_path = str(definition_obj.module_path) @@ -1090,6 +1104,7 @@ def process_test_files( logger.debug(str(e)) continue + tests_cache.flush() progress.advance(task_id) tests_cache.close() diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index af54a56fb..bfbf02fc4 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -5,8 +5,9 @@ import os from collections import defaultdict from dataclasses import dataclass, field +from functools import cache from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import libcst as cst @@ -47,6 +48,13 @@ from codeflash.languages.python.context.unused_definition_remover import UsageInfo +@cache +def get_jedi_project(project_root_path: str) -> Any: + import jedi + + return jedi.Project(path=project_root_path) + + @dataclass class FileContextCache: original_module: cst.Module @@ -102,15 +110,21 @@ def get_code_optimization_context( testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, call_graph: DependencyResolver | None = None, ) -> CodeOptimizationContext: + jedi_project = get_jedi_project(str(project_root_path)) + # Get FunctionSource representation of helpers of FTO fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}} if call_graph is not None: helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input) else: - helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path) + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( + fto_input, project_root_path, jedi_project=jedi_project + ) # Add function to optimize into helpers of FTO dict, as they'll be processed together - fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) + fto_as_function_source = get_function_to_optimize_as_function_source( + function_to_optimize, project_root_path, jedi_project=jedi_project + ) helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source) # Format data to search for helpers of helpers using get_function_sources_from_jedi @@ -124,7 +138,7 @@ def get_code_optimization_context( qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn}) helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi( - helpers_of_fto_qualified_names_dict, project_root_path + helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project ) # Extract all code contexts in a single pass (one CST parse per file) @@ -133,11 +147,14 @@ def get_code_optimization_context( final_read_writable_code = all_ctx.read_writable # Ensure the target file is first in the code blocks so the LLM knows which file to optimize - target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) - target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative] - other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative] - if target_blocks: - final_read_writable_code.code_strings = target_blocks + other_blocks + try: + target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) + target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative] + other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative] + if target_blocks: + final_read_writable_code.code_strings = target_blocks + other_blocks + except ValueError: + pass read_only_code_markdown = all_ctx.read_only @@ -434,13 +451,13 @@ def re_extract_from_cache( ) -> CodeStringsMarkdown: """Re-extract context from cached modules without file I/O or CST parsing.""" result = CodeStringsMarkdown() - for cache in file_caches: + for file_cache in file_caches: try: pruned = parse_code_and_prune_cst( - cache.cleaned_module, + file_cache.cleaned_module, code_context_type, - cache.fto_names, - cache.hoh_names, + file_cache.fto_names, + file_cache.hoh_names, remove_docstrings=remove_docstrings, ) except ValueError: @@ -450,24 +467,25 @@ def re_extract_from_cache( code = ast.unparse(ast.parse(pruned.code)) else: code = add_needed_imports_from_module( - src_module_code=cache.original_module, + src_module_code=file_cache.original_module, dst_module_code=pruned, - src_path=cache.file_path, - dst_path=cache.file_path, + src_path=file_cache.file_path, + dst_path=file_cache.file_path, project_root=project_root_path, - helper_functions=cache.helper_functions, + helper_functions=file_cache.helper_functions, ) - result.code_strings.append(CodeString(code=code, file_path=cache.relative_path)) + result.code_strings.append(CodeString(code=code, file_path=file_cache.relative_path)) return result def get_function_to_optimize_as_function_source( - function_to_optimize: FunctionToOptimize, project_root_path: Path + function_to_optimize: FunctionToOptimize, project_root_path: Path, *, jedi_project: object | None = None ) -> FunctionSource: import jedi # Use jedi to find function to optimize - script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) + project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path)) + script = jedi.Script(path=function_to_optimize.file_path, project=project) # Get all names in the file names = script.get_names(all_scopes=True, definitions=True, references=False) @@ -498,22 +516,40 @@ def get_function_to_optimize_as_function_source( def get_function_sources_from_jedi( - file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path + file_path_to_qualified_function_names: dict[Path, set[str]], + project_root_path: Path, + *, + jedi_project: object | None = None, ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: import jedi + project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path)) file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): - script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) + script = jedi.Script(path=file_path, project=project) file_refs = script.get_names(all_scopes=True, definitions=False, references=True) + # Pre-group references by their parent function's qualified name for O(1) lookup + refs_by_parent: dict[str, list[Name]] = defaultdict(list) + for ref in file_refs: + if not ref.full_name: + continue + try: + parent = ref.parent() + if parent is None or parent.type != "function": + continue + parent_qn = get_qualified_name(parent.module_name, parent.full_name) + # Exclude self-references (recursive calls) — the ref's own qualified name matches the parent + ref_qn = get_qualified_name(ref.module_name, ref.full_name) + if ref_qn == parent_qn: + continue + refs_by_parent[parent_qn].append(ref) + except (AttributeError, ValueError): + continue + for qualified_function_name in qualified_function_names: - names = [ - ref - for ref in file_refs - if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name) - ] + names = refs_by_parent.get(qualified_function_name, []) for name in names: try: definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) @@ -1103,7 +1139,7 @@ def _resolve_imported_class_reference( module_name, class_name = resolved_name.rsplit(".", 1) try: script_code = f"from {module_name} import {class_name}" - script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) + script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path))) definitions = script.goto(1, len(f"from {module_name} import ") + len(class_name), follow_imports=True) except Exception: return None @@ -1263,7 +1299,7 @@ def extract_parameter_type_constructors( def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None: try: script_code = f"from {module_name} import {type_name}" - script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) + script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path))) definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True) if not definitions: return @@ -1429,7 +1465,7 @@ def extract_class_and_bases( continue try: test_code = f"import {module_name}" - script = jedi.Script(test_code, project=jedi.Project(path=project_root_path)) + script = jedi.Script(test_code, project=get_jedi_project(str(project_root_path))) completions = script.goto(1, len(test_code)) if not completions: @@ -1547,21 +1583,6 @@ def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: names.add(node.value.id) -def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: - if not isinstance(indented_block.body[0], cst.SimpleStatementLine): - return indented_block - first_stmt = indented_block.body[0].body[0] - if isinstance(first_stmt, cst.Expr) and isinstance(first_stmt.value, cst.SimpleString): - return indented_block.with_changes(body=indented_block.body[1:]) - return indented_block - - -def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: - if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock): - return node.with_changes(body=remove_docstring_from_body(node.body)) - return node - - class ImportCollector(ast.NodeVisitor): def __init__(self) -> None: self.imported_names: dict[str, str] = {} @@ -1734,3 +1755,22 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return False except ValueError: return False + + +def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: + """Strip docstring from function or class if configured to do so.""" + if not cfg.remove_docstrings or not isinstance(node.body, cst.IndentedBlock): + return node + + body_stmts = node.body.body + if not body_stmts: + return node + + first_stmt = body_stmts[0] + if isinstance(first_stmt, cst.SimpleStatementLine) and len(first_stmt.body) == 1: + expr_stmt = first_stmt.body[0] + if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, (cst.SimpleString, cst.ConcatenatedString)): + new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])] + return node.with_changes(body=node.body.with_changes(body=new_body)) + + return node diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 1b315d629..ee37d2146 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -553,7 +553,12 @@ def add_needed_imports_from_module( if not helper_functions_fqn: helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} - dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code + # Cache the fallback early to avoid repeated isinstance checks + if isinstance(dst_module_code, str): + dst_code_fallback = dst_module_code + else: + # Keep Module-input fallback formatting aligned with transformed_module.code.lstrip("\n"). + dst_code_fallback = dst_module_code.code.lstrip("\n") src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) @@ -563,18 +568,31 @@ def add_needed_imports_from_module( full_module_name=dst_module_and_package.name, full_package_name=dst_module_and_package.package, ) - gatherer: GatherImportsVisitor = GatherImportsVisitor( - CodemodContext( - filename=src_path.name, - full_module_name=src_module_and_package.name, - full_package_name=src_module_and_package.package, - ) - ) try: if isinstance(src_module_code, cst.Module): src_module = src_module_code.visit(FutureAliasedImportTransformer()) else: src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer()) + + # Early exit: check if source has any imports at module level + has_module_level_imports = any( + isinstance(s, (cst.Import, cst.ImportFrom)) + for stmt in src_module.body + if isinstance(stmt, cst.SimpleStatementLine) + for s in stmt.body + ) + + if not has_module_level_imports: + return dst_code_fallback + + gatherer: GatherImportsVisitor = GatherImportsVisitor( + CodemodContext( + filename=src_path.name, + full_module_name=src_module_and_package.name, + full_package_name=src_module_and_package.package, + ) + ) + # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. # Nested imports (inside functions) are part of function logic and must not be # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". @@ -582,21 +600,31 @@ def add_needed_imports_from_module( body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))] ) module_level_only.visit(gatherer) + + # Early exit: if no imports were gathered, return destination as-is + if ( + not gatherer.module_imports + and not gatherer.object_mapping + and not gatherer.module_aliases + and not gatherer.alias_mapping + ): + return dst_code_fallback + except Exception as e: logger.error(f"Error parsing source module code: {e}") return dst_code_fallback dotted_import_collector = DottedImportCollector() - if isinstance(dst_module_code, cst.Module): - parsed_dst_module = dst_module_code - parsed_dst_module.visit(dotted_import_collector) - else: + if isinstance(dst_module_code, str): try: parsed_dst_module = cst.parse_module(dst_module_code) - parsed_dst_module.visit(dotted_import_collector) except cst.ParserSyntaxError as e: logger.exception(f"Syntax error in destination module code: {e}") return dst_code_fallback + else: + parsed_dst_module = dst_module_code + + parsed_dst_module.visit(dotted_import_collector) try: for mod in gatherer.module_imports: diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 345b966dc..198058b28 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -527,3 +527,23 @@ def parse(): return PATTERN.findall("") """ assert result == expected + + +def test_module_input_fallback_strips_leading_newlines() -> None: + src_code = """ +def parse(): + return helper() + +def helper(): + return 1 +""" + parsed_module = cst.parse_module(src_code) + + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + file_path = project_root / "mod.py" + file_path.write_text(src_code) + + result = add_needed_imports_from_module(src_code, parsed_module, file_path, file_path, project_root) + + assert result == src_code.lstrip("\n")