diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index ad8bb331d..d2b63cfda 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1770,3 +1770,22 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return False except ValueError: return False + + +def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: + """Strip docstring from function or class if configured to do so.""" + if not cfg.remove_docstrings or not isinstance(node.body, cst.IndentedBlock): + return node + + body_stmts = node.body.body + if not body_stmts: + return node + + first_stmt = body_stmts[0] + if isinstance(first_stmt, cst.SimpleStatementLine) and len(first_stmt.body) == 1: + expr_stmt = first_stmt.body[0] + if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, cst.SimpleString | cst.ConcatenatedString): + new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])] + return node.with_changes(body=node.body.with_changes(body=new_body)) + + return node diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 1b315d629..80a274770 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -553,7 +553,13 @@ def add_needed_imports_from_module( if not helper_functions_fqn: helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} - dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code + # Cache the fallback early to avoid repeated isinstance checks + if isinstance(dst_module_code, str): + dst_code_fallback = dst_module_code + parsed_dst_module = None + else: + dst_code_fallback = dst_module_code.code + parsed_dst_module = dst_module_code src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) @@ -563,18 +569,32 @@ def add_needed_imports_from_module( full_module_name=dst_module_and_package.name, full_package_name=dst_module_and_package.package, ) - gatherer: GatherImportsVisitor = GatherImportsVisitor( - CodemodContext( - filename=src_path.name, - full_module_name=src_module_and_package.name, - full_package_name=src_module_and_package.package, - ) - ) try: if isinstance(src_module_code, cst.Module): src_module = src_module_code.visit(FutureAliasedImportTransformer()) else: src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer()) + + # Early exit: check if source has any imports at module level + has_module_level_imports = any( + isinstance(stmt, (cst.Import, cst.ImportFrom)) + for stmt in src_module.body + if isinstance(stmt, cst.SimpleStatementLine) + for s in stmt.body + if isinstance(s, (cst.Import, cst.ImportFrom)) + ) + + if not has_module_level_imports: + return dst_code_fallback + + gatherer: GatherImportsVisitor = GatherImportsVisitor( + CodemodContext( + filename=src_path.name, + full_module_name=src_module_and_package.name, + full_package_name=src_module_and_package.package, + ) + ) + # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. # Nested imports (inside functions) are part of function logic and must not be # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". @@ -582,22 +602,30 @@ def add_needed_imports_from_module( body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))] ) module_level_only.visit(gatherer) + + # Early exit: if no imports were gathered, return destination as-is + if ( + not gatherer.module_imports + and not gatherer.object_mapping + and not gatherer.module_aliases + and not gatherer.alias_mapping + ): + return dst_code_fallback + except Exception as e: logger.error(f"Error parsing source module code: {e}") return dst_code_fallback dotted_import_collector = DottedImportCollector() - if isinstance(dst_module_code, cst.Module): - parsed_dst_module = dst_module_code - parsed_dst_module.visit(dotted_import_collector) - else: + if parsed_dst_module is None: try: parsed_dst_module = cst.parse_module(dst_module_code) - parsed_dst_module.visit(dotted_import_collector) except cst.ParserSyntaxError as e: logger.exception(f"Syntax error in destination module code: {e}") return dst_code_fallback + parsed_dst_module.visit(dotted_import_collector) + try: for mod in gatherer.module_imports: # Skip __future__ imports as they cannot be imported directly