Skip to content

Commit e820e11

Browse files
authored
Merge pull request #1853 from codeflash-ai/codeflash/optimize-pr1852-2026-03-17T06.54.17
⚡️ Speed up function `re_extract_from_cache` by 18,223% in PR #1852 (`cf-1846-port-perf-improvements`)
2 parents 9f20cdf + 5b52fba commit e820e11

2 files changed

Lines changed: 60 additions & 13 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,3 +1770,22 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
17701770
return False
17711771
except ValueError:
17721772
return False
1773+
1774+
1775+
def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef:
1776+
"""Strip docstring from function or class if configured to do so."""
1777+
if not cfg.remove_docstrings or not isinstance(node.body, cst.IndentedBlock):
1778+
return node
1779+
1780+
body_stmts = node.body.body
1781+
if not body_stmts:
1782+
return node
1783+
1784+
first_stmt = body_stmts[0]
1785+
if isinstance(first_stmt, cst.SimpleStatementLine) and len(first_stmt.body) == 1:
1786+
expr_stmt = first_stmt.body[0]
1787+
if isinstance(expr_stmt, cst.Expr) and isinstance(expr_stmt.value, cst.SimpleString | cst.ConcatenatedString):
1788+
new_body = body_stmts[1:] or [cst.SimpleStatementLine(body=[cst.Pass()])]
1789+
return node.with_changes(body=node.body.with_changes(body=new_body))
1790+
1791+
return node

codeflash/languages/python/static_analysis/code_extractor.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,13 @@ def add_needed_imports_from_module(
553553
if not helper_functions_fqn:
554554
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
555555

556-
dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code
556+
# Cache the fallback early to avoid repeated isinstance checks
557+
if isinstance(dst_module_code, str):
558+
dst_code_fallback = dst_module_code
559+
parsed_dst_module = None
560+
else:
561+
dst_code_fallback = dst_module_code.code
562+
parsed_dst_module = dst_module_code
557563

558564
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
559565
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
@@ -563,41 +569,63 @@ def add_needed_imports_from_module(
563569
full_module_name=dst_module_and_package.name,
564570
full_package_name=dst_module_and_package.package,
565571
)
566-
gatherer: GatherImportsVisitor = GatherImportsVisitor(
567-
CodemodContext(
568-
filename=src_path.name,
569-
full_module_name=src_module_and_package.name,
570-
full_package_name=src_module_and_package.package,
571-
)
572-
)
573572
try:
574573
if isinstance(src_module_code, cst.Module):
575574
src_module = src_module_code.visit(FutureAliasedImportTransformer())
576575
else:
577576
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())
577+
578+
# Early exit: check if source has any imports at module level
579+
has_module_level_imports = any(
580+
isinstance(stmt, (cst.Import, cst.ImportFrom))
581+
for stmt in src_module.body
582+
if isinstance(stmt, cst.SimpleStatementLine)
583+
for s in stmt.body
584+
if isinstance(s, (cst.Import, cst.ImportFrom))
585+
)
586+
587+
if not has_module_level_imports:
588+
return dst_code_fallback
589+
590+
gatherer: GatherImportsVisitor = GatherImportsVisitor(
591+
CodemodContext(
592+
filename=src_path.name,
593+
full_module_name=src_module_and_package.name,
594+
full_package_name=src_module_and_package.package,
595+
)
596+
)
597+
578598
# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
579599
# Nested imports (inside functions) are part of function logic and must not be
580600
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".
581601
module_level_only = src_module.with_changes(
582602
body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))]
583603
)
584604
module_level_only.visit(gatherer)
605+
606+
# Early exit: if no imports were gathered, return destination as-is
607+
if (
608+
not gatherer.module_imports
609+
and not gatherer.object_mapping
610+
and not gatherer.module_aliases
611+
and not gatherer.alias_mapping
612+
):
613+
return dst_code_fallback
614+
585615
except Exception as e:
586616
logger.error(f"Error parsing source module code: {e}")
587617
return dst_code_fallback
588618

589619
dotted_import_collector = DottedImportCollector()
590-
if isinstance(dst_module_code, cst.Module):
591-
parsed_dst_module = dst_module_code
592-
parsed_dst_module.visit(dotted_import_collector)
593-
else:
620+
if parsed_dst_module is None:
594621
try:
595622
parsed_dst_module = cst.parse_module(dst_module_code)
596-
parsed_dst_module.visit(dotted_import_collector)
597623
except cst.ParserSyntaxError as e:
598624
logger.exception(f"Syntax error in destination module code: {e}")
599625
return dst_code_fallback
600626

627+
parsed_dst_module.visit(dotted_import_collector)
628+
601629
try:
602630
for mod in gatherer.module_imports:
603631
# Skip __future__ imports as they cannot be imported directly

0 commit comments

Comments
 (0)