Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions codeflash/languages/python/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,3 +1770,22 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
return False
except ValueError:
return False


def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef:
    """Strip the leading docstring from a function or class if configured to do so.

    The docstring is recognised as a first body statement that is a single
    expression whose value is a ``cst.SimpleString`` or ``cst.ConcatenatedString``.

    Args:
        node: The function or class definition to process.
        cfg: Pruning configuration; only ``cfg.remove_docstrings`` is read here.

    Returns:
        ``node`` itself when nothing is stripped (stripping disabled, the body is
        not a ``cst.IndentedBlock``, the body is empty, or the first statement is
        not a docstring). Otherwise a copy of ``node`` whose body omits the
        docstring; if the docstring was the only statement, the body becomes a
        single ``pass`` so the definition stays syntactically valid.

    """
    if not cfg.remove_docstrings or not isinstance(node.body, cst.IndentedBlock):
        return node

    body_stmts = node.body.body
    if not body_stmts:
        return node

    first_stmt = body_stmts[0]
    # A docstring is the only small statement on its line.
    if not isinstance(first_stmt, cst.SimpleStatementLine) or len(first_stmt.body) != 1:
        return node

    expr_stmt = first_stmt.body[0]
    # Use the tuple form of isinstance: passing ``A | B`` as classinfo raises
    # TypeError on Python versions before 3.10.
    docstring_types = (cst.SimpleString, cst.ConcatenatedString)
    if not isinstance(expr_stmt, cst.Expr) or not isinstance(expr_stmt.value, docstring_types):
        return node

    # An emptied body would be invalid Python, so fall back to a lone ``pass``.
    new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])]
    return node.with_changes(body=node.body.with_changes(body=new_body))
54 changes: 41 additions & 13 deletions codeflash/languages/python/static_analysis/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,13 @@ def add_needed_imports_from_module(
if not helper_functions_fqn:
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}

dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code
# Cache the fallback early to avoid repeated isinstance checks
if isinstance(dst_module_code, str):
dst_code_fallback = dst_module_code
parsed_dst_module = None
else:
dst_code_fallback = dst_module_code.code
parsed_dst_module = dst_module_code

src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
Expand All @@ -563,41 +569,63 @@ def add_needed_imports_from_module(
full_module_name=dst_module_and_package.name,
full_package_name=dst_module_and_package.package,
)
gatherer: GatherImportsVisitor = GatherImportsVisitor(
CodemodContext(
filename=src_path.name,
full_module_name=src_module_and_package.name,
full_package_name=src_module_and_package.package,
)
)
try:
if isinstance(src_module_code, cst.Module):
src_module = src_module_code.visit(FutureAliasedImportTransformer())
else:
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())

# Early exit: check if source has any imports at module level
has_module_level_imports = any(
isinstance(stmt, (cst.Import, cst.ImportFrom))
for stmt in src_module.body
if isinstance(stmt, cst.SimpleStatementLine)
for s in stmt.body
if isinstance(s, (cst.Import, cst.ImportFrom))
)

if not has_module_level_imports:
return dst_code_fallback

gatherer: GatherImportsVisitor = GatherImportsVisitor(
CodemodContext(
filename=src_path.name,
full_module_name=src_module_and_package.name,
full_package_name=src_module_and_package.package,
)
)

# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
# Nested imports (inside functions) are part of function logic and must not be
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".
module_level_only = src_module.with_changes(
body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))]
)
module_level_only.visit(gatherer)

# Early exit: if no imports were gathered, return destination as-is
if (
not gatherer.module_imports
and not gatherer.object_mapping
and not gatherer.module_aliases
and not gatherer.alias_mapping
):
return dst_code_fallback

except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_code_fallback

dotted_import_collector = DottedImportCollector()
if isinstance(dst_module_code, cst.Module):
parsed_dst_module = dst_module_code
parsed_dst_module.visit(dotted_import_collector)
else:
if parsed_dst_module is None:
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(dotted_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_code_fallback

parsed_dst_module.visit(dotted_import_collector)

try:
for mod in gatherer.module_imports:
# Skip __future__ imports as they cannot be imported directly
Expand Down
Loading