Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 31 additions & 107 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import random
import warnings
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from collections import defaultdict
from functools import cache
from pathlib import Path
Expand All @@ -16,7 +15,7 @@
from rich.tree import Tree

from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import (
exit_with_message,
is_class_defined_in_file,
Expand Down Expand Up @@ -47,10 +46,6 @@

from rich.text import Text

_property_id = "property"

_ast_name = ast.Name


@dataclass(frozen=True)
class FunctionProperties:
Expand All @@ -73,9 +68,9 @@ def visit_Return(self, node: cst.Return) -> None:
class FunctionVisitor(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)

def __init__(self, file_path: str) -> None:
def __init__(self, file_path: Path) -> None:
super().__init__()
self.file_path: str = file_path
self.file_path: Path = file_path
self.functions: list[FunctionToOptimize] = []

@staticmethod
Expand All @@ -91,15 +86,26 @@ def is_pytest_fixture(node: cst.FunctionDef) -> bool:
return True
return False

@staticmethod
def is_property(node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
dec = decorator.decorator
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
return True
return False

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
node.visit(return_visitor)
if return_visitor.has_return_statement and not self.is_pytest_fixture(node):
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
ast_parents: list[FunctionParent] = []
while parents is not None:
if isinstance(parents, (cst.FunctionDef, cst.ClassDef)):
if isinstance(parents, cst.FunctionDef):
# Skip nested functions — only discover top-level and class-level functions
return
if isinstance(parents, cst.ClassDef):
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
self.functions.append(
Expand All @@ -114,32 +120,6 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
)


def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]:
results: list[FunctionToOptimize] = []
# (node, parent_path) — iterative DFS avoids RecursionError on deeply nested ASTs
stack: list[tuple[ast.AST, list[FunctionParent]]] = [(ast_module, [])]
while stack:
node, ast_path = stack.pop()
if isinstance(node, (FunctionDef, AsyncFunctionDef)):
if function_has_return_statement(node) and not function_is_a_property(node):
results.append(
FunctionToOptimize(
function_name=node.name,
file_path=file_path,
parents=ast_path[:],
is_async=isinstance(node, AsyncFunctionDef),
)
)
# Don't recurse into function bodies (matches original visitor behaviour)
continue
child_path = (
[*ast_path, FunctionParent(node.name, node.__class__.__name__)] if isinstance(node, ClassDef) else ast_path
)
for child in reversed(list(ast.iter_child_nodes(node))):
stack.append((child, child_path))
return results


# =============================================================================
# Multi-language support helpers
# =============================================================================
Expand Down Expand Up @@ -250,23 +230,6 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
return True, None


def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
"""Find all optimizable functions in a Python file using AST parsing.

This is the original Python implementation preserved for backward compatibility.
"""
functions: dict[Path, list[FunctionToOptimize]] = {}
with file_path.open(encoding="utf8") as f:
try:
ast_module = ast.parse(f.read())
except Exception as e:
if DEBUG_MODE:
logger.exception(e)
return functions
functions[file_path] = find_functions_with_return_statement(ast_module, file_path)
return functions


def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
"""Find all optimizable functions using the language support abstraction.

Expand All @@ -280,7 +243,6 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
try:
lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True)
# discover_functions already returns FunctionToOptimize objects
functions[file_path] = lang_support.discover_functions(file_path, criteria)
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")
Expand All @@ -302,7 +264,7 @@ def get_functions_to_optimize(
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
)
functions: dict[str, list[FunctionToOptimize]]
functions: dict[Path, list[FunctionToOptimize]]
trace_file_path: Path | None = None
is_lsp = is_LSP_enabled()
with warnings.catch_warnings():
Expand All @@ -319,7 +281,7 @@ def get_functions_to_optimize(
logger.info("!lsp|Finding all functions in the file '%s'…", file)
console.rule()
file = Path(file) if isinstance(file, str) else file
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
functions = find_all_functions_in_file(file)
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
Expand Down Expand Up @@ -354,6 +316,7 @@ def get_functions_to_optimize(
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
)

assert found_function is not None
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
# Non-exported functions cannot be imported by tests
if found_function.language in ("javascript", "typescript"):
Expand Down Expand Up @@ -397,7 +360,7 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path


def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
return get_functions_within_lines(modified_lines)

Expand Down Expand Up @@ -438,7 +401,7 @@ def closest_matching_file_function_name(
closest_match = function
closest_file = file_path

if closest_match is not None:
if closest_match is not None and closest_file is not None:
return closest_file, closest_match
return None

Expand Down Expand Up @@ -472,13 +435,13 @@ def levenshtein_distance(s1: str, s2: str) -> int:
return previous[len1]


def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
return get_functions_within_lines(modified_lines)


def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
functions: dict[str, list[FunctionToOptimize]] = {}
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
for path_str, lines_in_file in modified_lines.items():
path = Path(path_str)
if not path.exists():
Expand All @@ -490,9 +453,9 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
except Exception as e:
logger.exception(e)
continue
function_lines = FunctionVisitor(file_path=str(path))
function_lines = FunctionVisitor(file_path=path)
wrapper.visit(function_lines)
functions[str(path)] = [
functions[path] = [
function_to_optimize
for function_to_optimize in function_lines.functions
if (start_line := function_to_optimize.starting_line) is not None
Expand All @@ -504,7 +467,7 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str

def get_all_files_and_functions(
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
) -> dict[str, list[FunctionToOptimize]]:
) -> dict[Path, list[FunctionToOptimize]]:
"""Get all optimizable functions from files in the module root.

Args:
Expand All @@ -516,9 +479,8 @@ def get_all_files_and_functions(
Dictionary mapping file paths to lists of FunctionToOptimize.

"""
functions: dict[str, list[FunctionToOptimize]] = {}
functions: dict[Path, list[FunctionToOptimize]] = {}
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
# Find all the functions in the file
functions.update(find_all_functions_in_file(file_path).items())
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
# Helpful if an optimize-all run is stuck and we restart it.
Expand All @@ -545,16 +507,6 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
if not is_language_supported(file_path):
return {}

try:
lang_support = get_language_support(file_path)
except Exception:
return {}

# Route to Python-specific implementation for backward compatibility
if lang_support.language == Language.PYTHON:
return _find_all_functions_in_python_file(file_path)

# Use language support abstraction for other languages
return _find_all_functions_via_language_support(file_path)


Expand Down Expand Up @@ -833,7 +785,7 @@ def filter_functions(
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
resolved_project_root = project_root.resolve()
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
filtered_modified_functions: dict[Path, list[FunctionToOptimize]] = {}
blocklist_funcs = get_blocklisted_functions()
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
# Remove any function that we don't want to optimize
Expand Down Expand Up @@ -940,7 +892,7 @@ def is_test_file(file_path_normalized: str) -> bool:
functions_tmp.append(function)
_functions = functions_tmp

filtered_modified_functions[file_path] = _functions
filtered_modified_functions[file_path_path] = _functions
functions_count += len(_functions)

if not disable_logs:
Expand All @@ -961,7 +913,7 @@ def is_test_file(file_path_normalized: str) -> bool:
if len(tree.children) > 0:
console.print(tree)
console.rule()
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
return {k: v for k, v in filtered_modified_functions.items() if v}, functions_count


def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
Expand All @@ -984,31 +936,3 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
)


def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
# Custom DFS, return True as soon as a Return node is found
stack: list[ast.AST] = list(function_node.body)
while stack:
node = stack.pop()
if isinstance(node, ast.Return):
return True
# Only push child nodes that are statements; Return nodes are statements,
# so this preserves correctness while avoiding unnecessary traversal into expr/Name/etc.
for field in getattr(node, "_fields", ()):
child = getattr(node, field, None)
if isinstance(child, list):
for item in child:
if isinstance(item, ast.stmt):
stack.append(item)
elif isinstance(child, ast.stmt):
stack.append(child)
return False


def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
for node in function_node.decorator_list: # noqa: SIM110
# Use isinstance rather than type(...) is ... for better performance with single inheritance trees like ast
if isinstance(node, _ast_name) and node.id == _property_id:
return True
return False
81 changes: 32 additions & 49 deletions codeflash/languages/python/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,56 +130,39 @@ def discover_functions(

criteria = filter_criteria or FunctionFilterCriteria()

try:
# Read and parse the file using libcst with metadata
source = file_path.read_text(encoding="utf-8")
try:
tree = cst.parse_module(source)
except Exception:
return []

# Use the libcst-based FunctionVisitor for accurate line numbers
wrapper = cst.metadata.MetadataWrapper(tree)
function_visitor = FunctionVisitor(file_path=str(file_path))
wrapper.visit(function_visitor)

functions: list[FunctionToOptimize] = []
for func in function_visitor.functions:
if not isinstance(func, FunctionToOptimize):
continue

# Apply filter criteria
if not criteria.include_async and func.is_async:
continue

if not criteria.include_methods and func.parents:
continue

# Check for return statement requirement (FunctionVisitor already filters this)
# but we double-check here for consistency
if criteria.require_return and func.starting_line is None:
continue

# Add is_method field based on parents
func_with_is_method = FunctionToOptimize(
function_name=func.function_name,
file_path=file_path,
parents=func.parents,
starting_line=func.starting_line,
ending_line=func.ending_line,
starting_col=func.starting_col,
ending_col=func.ending_col,
is_async=func.is_async,
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
language="python",
)
functions.append(func_with_is_method)

return functions
source = file_path.read_text(encoding="utf-8")
tree = cst.parse_module(source)

wrapper = cst.metadata.MetadataWrapper(tree)
function_visitor = FunctionVisitor(file_path=file_path)
wrapper.visit(function_visitor)

functions: list[FunctionToOptimize] = []
for func in function_visitor.functions:
if not criteria.include_async and func.is_async:
continue

if not criteria.include_methods and func.parents:
continue

if criteria.require_return and func.starting_line is None:
continue

func_with_is_method = FunctionToOptimize(
function_name=func.function_name,
file_path=file_path,
parents=func.parents,
starting_line=func.starting_line,
ending_line=func.ending_line,
starting_col=func.starting_col,
ending_col=func.ending_col,
is_async=func.is_async,
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
language="python",
)
functions.append(func_with_is_method)

except Exception as e:
logger.warning("Failed to discover functions in %s: %s", file_path, e)
return []
return functions

def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]
Expand Down
Loading
Loading