diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 64de4c54b..b64c8869a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -14,7 +14,7 @@ from codeflash.code_utils.env_utils import get_codeflash_api_key from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.languages import Language, current_language +from codeflash.languages import Language, current_language, current_language_support from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( AIServiceRefinerRequest, @@ -58,6 +58,8 @@ def add_language_metadata( payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None ) -> None: """Add language version and module system metadata to an API payload.""" + if language_version is None: + language_version = current_language_support().language_version payload["language_version"] = language_version payload["python_version"] = language_version if current_language() == Language.PYTHON else None diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 7797ca3e4..ff6494d73 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -50,6 +50,8 @@ MAX_CONTEXT_LEN_REVIEW = 1000 +HIGH_EFFORT_TOP_N = 15 + class EffortLevel(str, Enum): LOW = "low" diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index fa1ebb16e..d9c3d4e3c 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -11,6 +11,7 @@ import subprocess import unittest from collections import defaultdict +from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional, final @@ -35,6 +36,21 @@ from codeflash.verification.verification_utils import 
TestConfig +def existing_unit_test_count( + func: FunctionToOptimize, project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]] +) -> int: + key = f"{module_name_from_file_path_cached(func.file_path, project_root)}.{func.qualified_name}" + tests = function_to_tests.get(key, set()) + seen: set[tuple[Path, str | None, str]] = set() + for t in tests: + if t.tests_in_file.test_type != TestType.EXISTING_UNIT_TEST: + continue + tif = t.tests_in_file + base_name = tif.test_function.split("[", 1)[0] + seen.add((tif.test_file, tif.test_class, base_name)) + return len(seen) + + @final class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this #: Tests passed. @@ -1079,3 +1095,9 @@ def process_test_files( tests_cache.close() return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests + + +# Cache module name resolution to avoid repeated Path.resolve()/relative_to() calls +@lru_cache(maxsize=128) +def module_name_from_file_path_cached(file_path: Path, project_root: Path) -> str: + return module_name_from_file_path(file_path, project_root) diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 686aa7fd8..d2b2a3184 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -18,6 +18,7 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.call_graph import CallGraph from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode from codeflash.verification.verification_utils import TestConfig @@ -250,6 +251,12 @@ def count_callees_per_function( """Return the number of callees for each (file_path, qualified_name) pair.""" ... + def get_call_graph( + self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False + ) -> CallGraph: + """Return a CallGraph with full caller→callee edges for the given functions.""" + ... 
+ def close(self) -> None: """Release resources (e.g. database connections).""" ... diff --git a/codeflash/languages/code_replacer.py b/codeflash/languages/code_replacer.py index 140690882..e05138d51 100644 --- a/codeflash/languages/code_replacer.py +++ b/codeflash/languages/code_replacer.py @@ -6,13 +6,15 @@ from __future__ import annotations -from pathlib import Path +import os from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger from codeflash.languages.base import FunctionFilterCriteria, Language if TYPE_CHECKING: + from pathlib import Path + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import LanguageSupport from codeflash.models.models import CodeStringsMarkdown @@ -38,7 +40,7 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin # directory prefix but the correct filename target_name = relative_path.name basename_matches = [ - code for path, code in file_to_code_context.items() if path != "None" and Path(path).name == target_name + code for path, code in file_to_code_context.items() if path != "None" and os.path.basename(path) == target_name ] if len(basename_matches) == 1: logger.debug(f"Using basename-matched code block for {relative_path}") diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 1fec830d1..31f8a92a4 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -425,6 +425,7 @@ def __init__( args: Namespace | None = None, replay_tests_dir: Path | None = None, call_graph: DependencyResolver | None = None, + effort_override: str | None = None, ) -> None: self.project_root = test_cfg.project_root_path.resolve() self.test_cfg = test_cfg @@ -451,7 +452,8 @@ def __init__( self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.test_files = TestFiles(test_files=[]) - self.effort = getattr(args, "effort", 
EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value + default_effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value + self.effort = effort_override or default_effort self.args = args # Check defaults for these self.function_trace_id: str = str(uuid.uuid4()) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 83eb49bea..af54a56fb 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -3,9 +3,9 @@ import ast import hashlib import os -from collections import defaultdict, deque +from collections import defaultdict +from dataclasses import dataclass, field from itertools import chain -from pathlib import Path from typing import TYPE_CHECKING import libcst as cst @@ -39,28 +39,42 @@ ) if TYPE_CHECKING: + from pathlib import Path + from jedi.api.classes import Name from codeflash.languages.base import DependencyResolver from codeflash.languages.python.context.unused_definition_remover import UsageInfo +@dataclass +class FileContextCache: + original_module: cst.Module + cleaned_module: cst.Module + fto_names: set[str] + hoh_names: set[str] + helper_functions: list[FunctionSource] + file_path: Path + relative_path: Path + + +@dataclass +class AllContextResults: + read_writable: CodeStringsMarkdown + read_only: CodeStringsMarkdown + hashing: CodeStringsMarkdown + testgen: CodeStringsMarkdown + file_caches: list[FileContextCache] = field(default_factory=list, repr=False) + + def build_testgen_context( - helpers_of_fto_dict: dict[Path, set[FunctionSource]], - helpers_of_helpers_dict: dict[Path, set[FunctionSource]], + testgen_base: CodeStringsMarkdown, project_root_path: Path, *, - remove_docstrings: bool = False, include_enrichment: bool = True, function_to_optimize: FunctionToOptimize | None = None, ) -> CodeStringsMarkdown: - testgen_context = 
extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=remove_docstrings, - code_context_type=CodeContextType.TESTGEN, - ) + testgen_context = testgen_base if include_enrichment: enrichment = enrich_testgen_context(testgen_context, project_root_path) @@ -113,14 +127,10 @@ def get_code_optimization_context( helpers_of_fto_qualified_names_dict, project_root_path ) - # Extract code context for optimization - final_read_writable_code = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - {}, - project_root_path, - remove_docstrings=False, - code_context_type=CodeContextType.READ_WRITABLE, - ) + # Extract all code contexts in a single pass (one CST parse per file) + all_ctx = extract_all_contexts_from_files(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path) + + final_read_writable_code = all_ctx.read_writable # Ensure the target file is first in the code blocks so the LLM knows which file to optimize target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) @@ -129,20 +139,7 @@ def get_code_optimization_context( if target_blocks: final_read_writable_code.code_strings = target_blocks + other_blocks - read_only_code_markdown = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=False, - code_context_type=CodeContextType.READ_ONLY, - ) - hashing_code_context = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - code_context_type=CodeContextType.HASHING, - ) + read_only_code_markdown = all_ctx.read_only # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) @@ -162,8 +159,8 @@ def get_code_optimization_context( read_only_tokens = encoded_tokens_len(read_only_context_code) if final_read_writable_tokens + 
read_only_tokens > optim_token_limit: logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") - read_only_code_no_docstrings = extract_code_markdown_context_from_files( - helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True + read_only_code_no_docstrings = re_extract_from_cache( + all_ctx.file_caches, CodeContextType.READ_ONLY, project_root_path ) read_only_context_code = read_only_code_no_docstrings.markdown if final_read_writable_tokens + encoded_tokens_len(read_only_context_code) > optim_token_limit: @@ -172,32 +169,23 @@ def get_code_optimization_context( # Progressive fallback for testgen context token limits testgen_context = build_testgen_context( - helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize + all_ctx.testgen, project_root_path, function_to_optimize=function_to_optimize ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: logger.debug("Testgen context exceeded token limit, removing docstrings") + testgen_base_no_docs = re_extract_from_cache(all_ctx.file_caches, CodeContextType.TESTGEN, project_root_path) testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - function_to_optimize=function_to_optimize, + testgen_base_no_docs, project_root_path, function_to_optimize=function_to_optimize ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: logger.debug("Testgen context still exceeded token limit, removing enrichment") - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - include_enrichment=False, - ) + testgen_context = build_testgen_context(testgen_base_no_docs, project_root_path, include_enrichment=False) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: raise ValueError(TESTGEN_LIMIT_ERROR) - 
code_hash_context = hashing_code_context.markdown + code_hash_context = all_ctx.hashing.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list}) @@ -214,130 +202,263 @@ def get_code_optimization_context( ) -def process_file_context( - file_path: Path, - primary_qualified_names: set[str], - secondary_qualified_names: set[str], - code_context_type: CodeContextType, - remove_docstrings: bool, +def extract_all_contexts_from_files( + helpers_of_fto: dict[Path, set[FunctionSource]], + helpers_of_helpers: dict[Path, set[FunctionSource]], project_root_path: Path, - helper_functions: list[FunctionSource], -) -> CodeString | None: - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - return None +) -> AllContextResults: + """Extract all 4 code context types from files in a single pass, parsing each file only once.""" + # Deduplicate: remove HoH entries that overlap with FTO (without mutating the caller's dict) + hoh_deduped: dict[Path, set[FunctionSource]] = {} + helpers_of_helpers_no_overlap: dict[Path, set[FunctionSource]] = {} + for file_path, function_sources in helpers_of_helpers.items(): + if file_path in helpers_of_fto: + hoh_deduped[file_path] = function_sources - helpers_of_fto[file_path] + else: + helpers_of_helpers_no_overlap[file_path] = function_sources - try: - all_names = primary_qualified_names | secondary_qualified_names - code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names) - pruned_module = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - primary_qualified_names, - secondary_qualified_names, - remove_docstrings, - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - return None + rw = CodeStringsMarkdown() + ro = CodeStringsMarkdown() + 
hashing = CodeStringsMarkdown() + testgen = CodeStringsMarkdown() + file_caches: list[FileContextCache] = [] + + # Process files containing FTO helpers (all 4 context types) + for file_path, function_sources in helpers_of_fto.items(): + fto_names = {func.qualified_name for func in function_sources} + hoh_funcs = hoh_deduped.get(file_path, set()) + hoh_names = {func.qualified_name for func in hoh_funcs} + rw_helper_functions = list(function_sources) + all_helper_functions = list(function_sources | hoh_funcs) + + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue + + try: + original_module = cst.parse_module(original_code) + except Exception as e: + logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}") + continue - if pruned_module.code.strip(): - if code_context_type == CodeContextType.HASHING: - code_context = ast.unparse(ast.parse(pruned_module.code)) - else: - code_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=pruned_module, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=helper_functions, - ) try: relative_path = file_path.resolve().relative_to(project_root_path.resolve()) except ValueError: relative_path = file_path - return CodeString(code=code_context, file_path=relative_path) - return None + # Compute defs once for fto_names and reuse across remove + prune + fto_defs = collect_top_level_defs_with_usages(original_module, fto_names) + # Clean by fto_names only (for RW) + rw_cleaned = remove_unused_definitions_by_function_names(original_module, fto_names, defs_with_usages=fto_defs) + # Clean by all names (for RO/HASH/TESTGEN) — reuse rw_cleaned if no extra HoH names + all_names = fto_names | hoh_names + all_cleaned = ( + remove_unused_definitions_by_function_names(original_module, all_names) if hoh_names else rw_cleaned + ) -def 
extract_code_markdown_context_from_files( - helpers_of_fto: dict[Path, set[FunctionSource]], - helpers_of_helpers: dict[Path, set[FunctionSource]], - project_root_path: Path, - remove_docstrings: bool = False, - code_context_type: CodeContextType = CodeContextType.READ_ONLY, -) -> CodeStringsMarkdown: - """Extract code context from files containing target functions and their helpers, formatting them as markdown. + # READ_WRITABLE + try: + rw_pruned = parse_code_and_prune_cst( + rw_cleaned, + CodeContextType.READ_WRITABLE, + fto_names, + set(), + remove_docstrings=False, + defs_with_usages=fto_defs, + ) + if rw_pruned.code.strip(): + rw_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=rw_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=rw_helper_functions, + ) + rw.code_strings.append(CodeString(code=rw_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-writable code: {e}") - This function processes two sets of files: - 1. Files containing the function to optimize (fto) and their first-degree helpers - 2. Files containing only helpers of helpers (with no overlap with the first set) + # READ_ONLY + try: + ro_pruned = parse_code_and_prune_cst( + all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False + ) + if ro_pruned.code.strip(): + ro_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=ro_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=all_helper_functions, + ) + ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") - For each file, it extracts relevant code based on the specified context type, adds necessary - imports, and combines them into a structured markdown format. 
+ # HASHING + try: + hash_pruned = parse_code_and_prune_cst( + all_cleaned, CodeContextType.HASHING, fto_names, hoh_names, remove_docstrings=True + ) + if hash_pruned.code.strip(): + hash_code = ast.unparse(ast.parse(hash_pruned.code)) + hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting hashing code: {e}") - Args: - ---- - helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers - helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions - project_root_path: Root path of the project - remove_docstrings: Whether to remove docstrings from the extracted code - code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) + # TESTGEN + try: + testgen_pruned = parse_code_and_prune_cst( + all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False + ) + if testgen_pruned.code.strip(): + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=all_helper_functions, + ) + testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting testgen code: {e}") + + file_caches.append( + FileContextCache( + original_module=original_module, + cleaned_module=all_cleaned, + fto_names=fto_names, + hoh_names=hoh_names, + helper_functions=all_helper_functions, + file_path=file_path, + relative_path=relative_path, + ) + ) - Returns: - ------- - CodeStringsMarkdown containing the extracted code context with necessary imports, - formatted for inclusion in markdown + # Process files containing only helpers of helpers (RO/HASH/TESTGEN only) + for file_path, function_sources in 
helpers_of_helpers_no_overlap.items(): + hoh_names = {func.qualified_name for func in function_sources} + helper_functions = list(function_sources) - """ - # Rearrange to remove overlaps, so we only access each file path once - helpers_of_helpers_no_overlap = defaultdict(set) - for file_path, function_sources in helpers_of_helpers.items(): - if file_path in helpers_of_fto: - # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto - helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - else: - helpers_of_helpers_no_overlap[file_path] = function_sources - code_context_markdown = CodeStringsMarkdown() - # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files - for file_path, function_sources in helpers_of_fto.items(): - qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} - helper_functions = list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) - - result = process_file_context( - file_path=file_path, - primary_qualified_names=qualified_function_names, - secondary_qualified_names=helpers_of_helpers_qualified_names, - code_context_type=code_context_type, - remove_docstrings=remove_docstrings, - project_root_path=project_root_path, - helper_functions=helper_functions, - ) + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue - if result is not None: - code_context_markdown.code_strings.append(result) - # Extract code from file paths containing helpers of helpers - for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): - qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} - helper_functions = 
list(helpers_of_helpers_no_overlap.get(file_path, set())) - - result = process_file_context( - file_path=file_path, - primary_qualified_names=set(), - secondary_qualified_names=qualified_helper_function_names, - code_context_type=code_context_type, - remove_docstrings=remove_docstrings, - project_root_path=project_root_path, - helper_functions=helper_functions, + try: + original_module = cst.parse_module(original_code) + except Exception as e: + logger.debug(f"Failed to parse {file_path} with libcst: {type(e).__name__}: {e}") + continue + + try: + relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + except ValueError: + relative_path = file_path + + cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names) + + # READ_ONLY + try: + ro_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False + ) + if ro_pruned.code.strip(): + ro_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=ro_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + + # HASHING + try: + hash_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.HASHING, set(), hoh_names, remove_docstrings=True + ) + if hash_pruned.code.strip(): + hash_code = ast.unparse(ast.parse(hash_pruned.code)) + hashing.code_strings.append(CodeString(code=hash_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting hashing code: {e}") + + # TESTGEN + try: + testgen_pruned = parse_code_and_prune_cst( + cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False + ) + if testgen_pruned.code.strip(): + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + 
dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) + except ValueError as e: + logger.debug(f"Error while getting testgen code: {e}") + + file_caches.append( + FileContextCache( + original_module=original_module, + cleaned_module=cleaned, + fto_names=set(), + hoh_names=hoh_names, + helper_functions=helper_functions, + file_path=file_path, + relative_path=relative_path, + ) ) - if result is not None: - code_context_markdown.code_strings.append(result) - return code_context_markdown + return AllContextResults(read_writable=rw, read_only=ro, hashing=hashing, testgen=testgen, file_caches=file_caches) + + +def re_extract_from_cache( + file_caches: list[FileContextCache], + code_context_type: CodeContextType, + project_root_path: Path, + remove_docstrings: bool = True, +) -> CodeStringsMarkdown: + """Re-extract context from cached modules without file I/O or CST parsing.""" + result = CodeStringsMarkdown() + for cache in file_caches: + try: + pruned = parse_code_and_prune_cst( + cache.cleaned_module, + code_context_type, + cache.fto_names, + cache.hoh_names, + remove_docstrings=remove_docstrings, + ) + except ValueError: + continue + if pruned.code.strip(): + if code_context_type == CodeContextType.HASHING: + code = ast.unparse(ast.parse(pruned.code)) + else: + code = add_needed_imports_from_module( + src_module_code=cache.original_module, + dst_module_code=pruned, + src_path=cache.file_path, + dst_path=cache.file_path, + project_root=project_root_path, + helper_functions=cache.helper_functions, + ) + result.code_strings.append(CodeString(code=code, file_path=cache.relative_path)) + return result def get_function_to_optimize_as_function_source( @@ -420,10 +541,7 @@ def get_function_sources_from_jedi( and definition.full_name.startswith(definition.module_name) ) if is_valid_definition and 
definition.type in ("function", "class", "statement"): - if definition.type == "function": - fqn = definition.full_name - func_name = definition.name - elif definition.type == "class": + if definition.type == "class": fqn = f"{definition.full_name}.__init__" func_name = "__init__" else: @@ -452,73 +570,30 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M tree = ast.parse(all_code) except SyntaxError: return None - imported_names: dict[str, str] = {} - - # Directly iterate over the module body and nested structures instead of ast.walk - # This avoids traversing every single node in the tree - def collect_imports(nodes: list[ast.stmt]) -> None: - for node in nodes: - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module - # Recursively check nested structures (function defs, class defs, if statements, etc.) 
- elif isinstance( - node, - ( - ast.FunctionDef, - ast.AsyncFunctionDef, - ast.ClassDef, - ast.If, - ast.For, - ast.AsyncFor, - ast.While, - ast.With, - ast.AsyncWith, - ast.Try, - ast.ExceptHandler, - ), - ): - if hasattr(node, "body"): - collect_imports(node.body) - if hasattr(node, "orelse"): - collect_imports(node.orelse) - if hasattr(node, "finalbody"): - collect_imports(node.finalbody) - if hasattr(node, "handlers"): - for handler in node.handlers: - collect_imports(handler.body) - # Handle match/case statements (Python 3.10+) - elif hasattr(ast, "Match") and isinstance(node, ast.Match): - for case in node.cases: - collect_imports(case.body) - - collect_imports(tree.body) - return tree, imported_names + collector = ImportCollector() + collector.visit(tree) + return tree, collector.imported_names def collect_existing_class_names(tree: ast.Module) -> set[str]: class_names = set() - stack = list(tree.body) + stack: list[ast.AST] = [tree] while stack: node = stack.pop() + if isinstance(node, ast.ClassDef): class_names.add(node.name) - stack.extend(node.body) - elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - stack.extend(node.body) - elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): + + # Only traverse nodes that can contain ClassDef nodes + if hasattr(node, "body"): stack.extend(node.body) if hasattr(node, "orelse"): stack.extend(node.orelse) - elif isinstance(node, ast.Try): - stack.extend(node.body) - stack.extend(node.orelse) - stack.extend(node.finalbody) - for handler in node.handlers: - stack.extend(handler.body) + if hasattr(node, "finalbody"): + stack.extend(node.finalbody) + if hasattr(node, "handlers"): + stack.extend(node.handlers) return class_names @@ -621,79 +696,484 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]: return set() +MAX_RAW_PROJECT_CLASS_BODY_ITEMS = 8 +MAX_RAW_PROJECT_CLASS_LINES = 40 + + +def _get_expr_name(node: ast.AST | None) -> str | None: + if node is None: + return 
None + + # Iteratively collect attribute parts and skip Call nodes to avoid recursion. + parts: list[str] = [] + current = node + # Walk down attribute/call chain collecting attribute names. + while True: + if isinstance(current, ast.Attribute): + # collect attrs in reverse (will join later) + parts.append(current.attr) + current = current.value + continue + if isinstance(current, ast.Call): + current = current.func + continue + if isinstance(current, ast.Name): + # If we reached a base name, include it at the front. + base_name = current.id + else: + base_name = None + break + + if not parts: + # No attribute parts collected: return base name or None (matches original). + return base_name + + # parts were collected from outermost to innermost attr (append order), + # but we want base-first order. Reverse to get innermost-first, then prepend base if present. + parts.reverse() + if base_name is not None: + parts.insert(0, base_name) + # Join parts with dots. If base_name is None, this still returns the joined attrs, + # which matches the original behavior where an Attribute with non-name base returns attr(s). 
+ return ".".join(parts) + + +def _collect_import_aliases(module_tree: ast.Module) -> dict[str, str]: + aliases: dict[str, str] = {} + for node in module_tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + bound_name = alias.asname if alias.asname else alias.name.split(".")[0] + aliases[bound_name] = alias.name + elif isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + bound_name = alias.asname if alias.asname else alias.name + aliases[bound_name] = f"{node.module}.{alias.name}" + return aliases + + +def _find_class_node_by_name(class_name: str, module_tree: ast.Module) -> ast.ClassDef | None: + stack: list[ast.AST] = [module_tree] + while stack: + node = stack.pop() + body = getattr(node, "body", None) + if body: + for item in body: + if isinstance(item, ast.ClassDef): + if item.name == class_name: + return item + stack.append(item) + elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + stack.append(item) + return None + + +def _expr_matches_name(node: ast.AST | None, import_aliases: dict[str, str], suffix: str) -> bool: + expr_name = _get_expr_name(node) + if expr_name is None: + return False + + # Precompute ".suffix" to avoid repeated f-string allocations. + suffix_dot = "." 
+ suffix + if expr_name == suffix or expr_name.endswith(suffix_dot): + return True + resolved_name = import_aliases.get(expr_name) + return resolved_name is not None and (resolved_name == suffix or resolved_name.endswith(suffix_dot)) + + +def _get_node_source(node: ast.AST | None, module_source: str, fallback: str = "...") -> str: + if node is None: + return fallback + source_segment = ast.get_source_segment(module_source, node) + if source_segment is not None: + return source_segment + try: + return ast.unparse(node) + except Exception: + return fallback + + +def _bool_literal(node: ast.AST) -> bool | None: + if isinstance(node, ast.Constant) and isinstance(node.value, bool): + return node.value + return None + + +def _is_namedtuple_class(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool: + for base in class_node.bases: # noqa: SIM110 + if _expr_matches_name(base, import_aliases, "NamedTuple"): + return True + return False + + +def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]: + for decorator in class_node.decorator_list: + if not _expr_matches_name(decorator, import_aliases, "dataclass"): + continue + init_enabled = True + kw_only = False + if isinstance(decorator, ast.Call): + for keyword in decorator.keywords: + literal_value = _bool_literal(keyword.value) + if literal_value is None: + continue + if keyword.arg == "init": + init_enabled = literal_value + elif keyword.arg == "kw_only": + kw_only = literal_value + return True, init_enabled, kw_only + return False, False, False + + +def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool: + annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation + return _expr_matches_name(annotation_root, import_aliases, "ClassVar") + + +def _is_project_path(module_path: Path, project_root_path: Path) -> bool: + return 
str(module_path.resolve()).startswith(str(project_root_path.resolve()) + os.sep) + + +def _get_class_start_line(class_node: ast.ClassDef) -> int: + if class_node.decorator_list: + return min(d.lineno for d in class_node.decorator_list) + return class_node.lineno + + +def _class_has_explicit_init(class_node: ast.ClassDef) -> bool: + for item in class_node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__": + return True + return False + + +def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]: + is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases) + if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass: + return set() + if is_dataclass and not dataclass_init_enabled: + return set() + + names = set[str]() + for item in class_node.body: + if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name) or item.annotation is None: + continue + if _is_classvar_annotation(item.annotation, import_aliases): + continue + + include_in_init = True + if isinstance(item.value, ast.Call) and _expr_matches_name(item.value.func, import_aliases, "field"): + for keyword in item.value.keywords: + if keyword.arg != "init": + continue + literal_value = _bool_literal(keyword.value) + if literal_value is not None: + include_in_init = literal_value + break + + if include_in_init: + names |= collect_type_names_from_annotation(item.annotation) + + return names + + +def _extract_synthetic_init_parameters( + class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str], *, kw_only_by_default: bool +) -> list[tuple[str, str, str | None, bool]]: + parameters: list[tuple[str, str, str | None, bool]] = [] + for item in class_node.body: + if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name): + continue + if _is_classvar_annotation(item.annotation, import_aliases): + continue 
+ + include_in_init = True + kw_only = kw_only_by_default + default_value: str | None = None + if item.value is not None: + if isinstance(item.value, ast.Call) and _expr_matches_name(item.value.func, import_aliases, "field"): + for keyword in item.value.keywords: + if keyword.arg == "init": + literal_value = _bool_literal(keyword.value) + if literal_value is not None: + include_in_init = literal_value + elif keyword.arg == "kw_only": + literal_value = _bool_literal(keyword.value) + if literal_value is not None: + kw_only = literal_value + elif keyword.arg == "default": + default_value = _get_node_source(keyword.value, module_source) + elif keyword.arg == "default_factory": + # Default factories still imply an optional constructor parameter, but + # the generated __init__ does not use the field() call directly. + default_value = "..." + else: + default_value = _get_node_source(item.value, module_source) + + if not include_in_init: + continue + + parameters.append( + (item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only) + ) + return parameters + + +def _build_synthetic_init_stub( + class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str] +) -> str | None: + is_namedtuple = _is_namedtuple_class(class_node, import_aliases) + is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases) + if not is_namedtuple and not is_dataclass: + return None + if is_dataclass and not dataclass_init_enabled: + return None + + parameters = _extract_synthetic_init_parameters( + class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only + ) + if not parameters: + return None + + signature_parts = ["self"] + inserted_kw_only_marker = False + for param_name, annotation_source, default_value, kw_only in parameters: + if kw_only and not inserted_kw_only_marker: + signature_parts.append("*") + inserted_kw_only_marker = True + part = f"{param_name}: 
{annotation_source}" + if default_value is not None: + part += f" = {default_value}" + signature_parts.append(part) + + signature = ", ".join(signature_parts) + return f" def __init__({signature}):\n ..." + + +def _extract_function_stub_snippet(fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]) -> str: + start_line = min(d.lineno for d in fn_node.decorator_list) if fn_node.decorator_list else fn_node.lineno + return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno]) + + +def _extract_raw_class_context(class_node: ast.ClassDef, module_source: str, module_tree: ast.Module) -> str: + class_source = "\n".join(module_source.splitlines()[_get_class_start_line(class_node) - 1 : class_node.end_lineno]) + needed_imports = extract_imports_for_class(module_tree, class_node, module_source) + if needed_imports: + return f"{needed_imports}\n\n{class_source}" + return class_source + + +def _has_non_property_method_decorator( + fn_node: ast.FunctionDef | ast.AsyncFunctionDef, import_aliases: dict[str, str] +) -> bool: + for decorator in fn_node.decorator_list: + if _expr_matches_name(decorator, import_aliases, "property"): + continue + decorator_name = _get_expr_name(decorator) + if decorator_name and decorator_name.endswith((".setter", ".deleter")): + continue + return True + return False + + +def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool: + for item in class_node.body: + if isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call): + return True + return False + + +def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool: + # Check decorator presence first - cheapest check that can short-circuit + if class_node.decorator_list: + return True + + # Check for namedtuple/dataclass early - these are common patterns that avoid body scanning + if _is_namedtuple_class(class_node, import_aliases): + return True + is_dataclass, _, _ = 
_get_dataclass_config(class_node, import_aliases) + if is_dataclass: + return True + + # Calculate size metrics once + start_line = _get_class_start_line(class_node) + assert class_node.end_lineno is not None + class_line_count = class_node.end_lineno - start_line + 1 + is_small = ( + class_line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS + ) + + # Single-pass body scan with early returns + has_explicit_init = False + + for item in class_node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + if item.name == "__init__": + has_explicit_init = True + if is_small: + return True + if _has_non_property_method_decorator(item, import_aliases): + return True + elif isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call): + return True + + return False + + def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None: - class_node = None - # Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order) - q: deque[ast.AST] = deque([module_tree]) - while q: - candidate = q.popleft() - if isinstance(candidate, ast.ClassDef) and candidate.name == class_name: - class_node = candidate - break - q.extend(ast.iter_child_nodes(candidate)) + class_node = _find_class_node_by_name(class_name, module_tree) if class_node is None: return None lines = module_source.splitlines() - relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] + import_aliases = _collect_import_aliases(module_tree) + explicit_init_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] + support_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] for item in class_node.body: if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): - is_relevant = False - if item.name in ("__init__", "__post_init__"): - is_relevant = True - else: - # Check decorators explicitly to avoid generator overhead - for d in item.decorator_list: - if 
(isinstance(d, ast.Name) and d.id == "property") or ( - isinstance(d, ast.Attribute) and d.attr == "property" - ): - is_relevant = True - break - if is_relevant: - relevant_nodes.append(item) - - if not relevant_nodes: - return None + if item.name == "__init__": + explicit_init_nodes.append(item) + support_nodes.append(item) + continue + if item.name == "__post_init__": + support_nodes.append(item) + continue + # Check decorators explicitly to avoid generator overhead + for d in item.decorator_list: + if (isinstance(d, ast.Name) and d.id == "property") or ( + isinstance(d, ast.Attribute) and d.attr == "property" + ): + support_nodes.append(item) + break snippets: list[str] = [] - for fn_node in relevant_nodes: - start = fn_node.lineno - if fn_node.decorator_list: - # Compute minimum decorator lineno with an explicit loop (avoids generator/min overhead) - m = start - for d in fn_node.decorator_list: - m = min(m, d.lineno) - start = m - snippets.append("\n".join(lines[start - 1 : fn_node.end_lineno])) + if not explicit_init_nodes: + synthetic_init = _build_synthetic_init_stub(class_node, module_source, import_aliases) + if synthetic_init is not None: + snippets.append(synthetic_init) + for fn_node in support_nodes: + snippets.append(_extract_function_stub_snippet(fn_node, lines)) + + if not snippets: + return None return f"class {class_name}:\n" + "\n".join(snippets) -def extract_parameter_type_constructors( - function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str] -) -> CodeStringsMarkdown: +def _get_module_source_and_tree( + module_path: Path, module_cache: dict[Path, tuple[str, ast.Module]] +) -> tuple[str, ast.Module] | None: + if module_path in module_cache: + return module_cache[module_path] + try: + module_source = module_path.read_text(encoding="utf-8") + module_tree = ast.parse(module_source) + except Exception: + return None + module_cache[module_path] = (module_source, module_tree) + return module_source, 
module_tree + + +def _resolve_imported_class_reference( + base_expr_name: str, + current_module_tree: ast.Module, + current_module_path: Path, + project_root_path: Path, + module_cache: dict[Path, tuple[str, ast.Module]], +) -> tuple[str, Path] | None: import jedi + import_aliases = _collect_import_aliases(current_module_tree) + class_name = base_expr_name.rsplit(".", 1)[-1] + if "." not in base_expr_name and _find_class_node_by_name(class_name, current_module_tree) is not None: + return class_name, current_module_path + + resolved_name = base_expr_name + if base_expr_name in import_aliases: + resolved_name = import_aliases[base_expr_name] + elif "." in base_expr_name: + head, tail = base_expr_name.split(".", 1) + if head in import_aliases: + resolved_name = f"{import_aliases[head]}.{tail}" + + if "." not in resolved_name: + return None + + module_name, class_name = resolved_name.rsplit(".", 1) try: - source = function_to_optimize.file_path.read_text(encoding="utf-8") - tree = ast.parse(source) + script_code = f"from {module_name} import {class_name}" + script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) + definitions = script.goto(1, len(f"from {module_name} import ") + len(class_name), follow_imports=True) except Exception: - return CodeStringsMarkdown(code_strings=[]) + return None - func_node = None - for node in ast.walk(tree): - if ( - isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) - and node.name == function_to_optimize.function_name - ): - if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line: - continue - func_node = node - break - if func_node is None: - return CodeStringsMarkdown(code_strings=[]) + if not definitions or definitions[0].module_path is None: + return None + module_path = definitions[0].module_path + if not _is_project_path(module_path, project_root_path): + return None + if _get_module_source_and_tree(module_path, module_cache) is None: + return None + 
return class_name, module_path + + +def _append_project_class_context( + class_name: str, + module_path: Path, + project_root_path: Path, + module_cache: dict[Path, tuple[str, ast.Module]], + existing_class_names: set[str], + emitted_classes: set[tuple[Path, str]], + emitted_class_names: set[str], + code_strings: list[CodeString], +) -> bool: + module_result = _get_module_source_and_tree(module_path, module_cache) + if module_result is None: + return False + module_source, module_tree = module_result + class_node = _find_class_node_by_name(class_name, module_tree) + if class_node is None: + return False + + class_key = (module_path, class_name) + if class_key in emitted_classes or class_name in existing_class_names: + return True + + for base in class_node.bases: + base_expr_name = _get_expr_name(base) + if base_expr_name is None: + continue + resolved = _resolve_imported_class_reference( + base_expr_name, module_tree, module_path, project_root_path, module_cache + ) + if resolved is None: + continue + base_name, base_module_path = resolved + if base_name in existing_class_names: + continue + _append_project_class_context( + base_name, + base_module_path, + project_root_path, + module_cache, + existing_class_names, + emitted_classes, + emitted_class_names, + code_strings, + ) + + code_strings.append( + CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path) + ) + emitted_classes.add(class_key) + emitted_class_names.add(class_name) + return True + +def _collect_type_names_from_function( + func_node: ast.FunctionDef | ast.AsyncFunctionDef, tree: ast.Module, class_name: str | None +) -> set[str]: type_names: set[str] = set() for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs: type_names |= collect_type_names_from_annotation(arg.annotation) @@ -701,8 +1181,6 @@ def extract_parameter_type_constructors( type_names |= collect_type_names_from_annotation(func_node.args.vararg.annotation) 
if func_node.args.kwarg: type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation) - - # Scan function body for isinstance(x, SomeType) and type(x) is/== SomeType patterns for body_node in ast.walk(func_node): if ( isinstance(body_node, ast.Call) @@ -718,7 +1196,6 @@ def extract_parameter_type_constructors( if isinstance(elt, ast.Name): type_names.add(elt.id) elif isinstance(body_node, ast.Compare): - # type(x) is/== SomeType if ( isinstance(body_node.left, ast.Call) and isinstance(body_node.left.func, ast.Name) @@ -727,78 +1204,136 @@ def extract_parameter_type_constructors( for comparator in body_node.comparators: if isinstance(comparator, ast.Name): type_names.add(comparator.id) - - # Collect base class names from enclosing class (if this is a method) - if function_to_optimize.class_name is not None: + if class_name is not None: for top_node in ast.walk(tree): - if isinstance(top_node, ast.ClassDef) and top_node.name == function_to_optimize.class_name: + if isinstance(top_node, ast.ClassDef) and top_node.name == class_name: for base in top_node.bases: if isinstance(base, ast.Name): type_names.add(base.id) break + return type_names - type_names -= BUILTIN_AND_TYPING_NAMES - type_names -= existing_class_names - if not type_names: - return CodeStringsMarkdown(code_strings=[]) +def _build_import_from_map(tree: ast.Module) -> dict[str, str]: import_map: dict[str, str] = {} for node in ast.walk(tree): if isinstance(node, ast.ImportFrom) and node.module: for alias in node.names: - name = alias.asname if alias.asname else alias.name - import_map[name] = node.module + import_map[alias.asname if alias.asname else alias.name] = node.module + return import_map + + +def extract_parameter_type_constructors( + function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str] +) -> CodeStringsMarkdown: + import jedi + + try: + source = function_to_optimize.file_path.read_text(encoding="utf-8") + tree = ast.parse(source) + 
except Exception: + return CodeStringsMarkdown(code_strings=[]) + + func_node = None + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == function_to_optimize.function_name + ): + if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line: + continue + func_node = node + break + if func_node is None: + return CodeStringsMarkdown(code_strings=[]) + + type_names = _collect_type_names_from_function(func_node, tree, function_to_optimize.class_name) + type_names -= BUILTIN_AND_TYPING_NAMES + type_names -= existing_class_names + if not type_names: + return CodeStringsMarkdown(code_strings=[]) + + import_map = _build_import_from_map(tree) code_strings: list[CodeString] = [] module_cache: dict[Path, tuple[str, ast.Module]] = {} + emitted_classes: set[tuple[Path, str]] = set() + emitted_class_names: set[str] = set() - for type_name in sorted(type_names): - module_name = import_map.get(type_name) - if not module_name: - continue + def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None: try: script_code = f"from {module_name} import {type_name}" script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True) if not definitions: - continue + return module_path = definitions[0].module_path if not module_path: - continue - - if module_path in module_cache: - mod_source, mod_tree = module_cache[module_path] - else: - mod_source = module_path.read_text(encoding="utf-8") - mod_tree = ast.parse(mod_source) - module_cache[module_path] = (mod_source, mod_tree) + return + resolved_module = module_path.resolve() + module_str = str(resolved_module) + is_project = _is_project_path(module_path, project_root_path) + is_third_party = "site-packages" in module_str + if transitive and not is_project and not is_third_party: + return 
+ + module_result = _get_module_source_and_tree(module_path, module_cache) + if module_result is None: + return + mod_source, mod_tree = module_result + + class_key = (module_path, type_name) + if class_key in emitted_classes or type_name in existing_class_names: + return + + class_node = _find_class_node_by_name(type_name, mod_tree) + if class_node is not None and is_project: + import_aliases = _collect_import_aliases(mod_tree) + if _should_use_raw_project_class_context(class_node, import_aliases): + if _append_project_class_context( + type_name, + module_path, + project_root_path, + module_cache, + existing_class_names, + emitted_classes, + emitted_class_names, + code_strings, + ): + return stub = extract_init_stub_from_class(type_name, mod_source, mod_tree) if stub: code_strings.append(CodeString(code=stub, file_path=module_path)) + emitted_classes.add(class_key) + emitted_class_names.add(type_name) except Exception: - logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}") + if transitive: + logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}") + else: + logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}") + + for type_name in sorted(type_names): + module_name = import_map.get(type_name) + if not module_name: continue + append_type_context(type_name, module_name) # Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs - # Build an extended import map that includes imports from source modules of already-extracted stubs transitive_import_map = dict(import_map) for _, cached_tree in module_cache.values(): - for cache_node in ast.walk(cached_tree): - if isinstance(cache_node, ast.ImportFrom) and cache_node.module: - for alias in cache_node.names: - name = alias.asname if alias.asname else alias.name - if name not in transitive_import_map: - transitive_import_map[name] = cache_node.module - - emitted_names 
= type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES + for name, module in _build_import_from_map(cached_tree).items(): + transitive_import_map.setdefault(name, module) + + emitted_names = type_names | existing_class_names | emitted_class_names | BUILTIN_AND_TYPING_NAMES transitive_type_names: set[str] = set() for cs in code_strings: try: stub_tree = ast.parse(cs.code) except SyntaxError: continue + import_aliases = _collect_import_aliases(stub_tree) for stub_node in ast.walk(stub_tree): if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in ( "__init__", @@ -806,32 +1341,14 @@ def extract_parameter_type_constructors( ): for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs: transitive_type_names |= collect_type_names_from_annotation(arg.annotation) + elif isinstance(stub_node, ast.ClassDef): + transitive_type_names |= _collect_synthetic_constructor_type_names(stub_node, import_aliases) transitive_type_names -= emitted_names for type_name in sorted(transitive_type_names): module_name = transitive_import_map.get(type_name) if not module_name: continue - try: - script_code = f"from {module_name} import {type_name}" - script = jedi.Script(script_code, project=jedi.Project(path=project_root_path)) - definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True) - if not definitions: - continue - module_path = definitions[0].module_path - if not module_path: - continue - if module_path in module_cache: - mod_source, mod_tree = module_cache[module_path] - else: - mod_source = module_path.read_text(encoding="utf-8") - mod_tree = ast.parse(mod_source) - module_cache[module_path] = (mod_source, mod_tree) - stub = extract_init_stub_from_class(type_name, mod_source, mod_tree) - if stub: - code_strings.append(CodeString(code=stub, file_path=module_path)) - except Exception: - logger.debug(f"Error extracting transitive constructor stub for {type_name} from 
{module_name}") - continue + append_type_context(type_name, module_name, transitive=True) return CodeStringsMarkdown(code_strings=code_strings) @@ -877,30 +1394,13 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: extracted_classes: set[tuple[Path, str]] = set() module_cache: dict[Path, tuple[str, ast.Module]] = {} - def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: - if module_path in module_cache: - return module_cache[module_path] - try: - module_source = module_path.read_text(encoding="utf-8") - module_tree = ast.parse(module_source) - except Exception: - return None - else: - module_cache[module_path] = (module_source, module_tree) - return module_source, module_tree - def extract_class_and_bases( class_name: str, module_path: Path, module_source: str, module_tree: ast.Module ) -> None: if (module_path, class_name) in extracted_classes: return - class_node = None - for node in ast.walk(module_tree): - if isinstance(node, ast.ClassDef) and node.name == class_name: - class_node = node - break - + class_node = _find_class_node_by_name(class_name, module_tree) if class_node is None: return @@ -918,14 +1418,9 @@ def extract_class_and_bases( return lines = module_source.split("\n") - start_line = class_node.lineno - if class_node.decorator_list: - start_line = min(d.lineno for d in class_node.decorator_list) - class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) - - full_source = class_source + class_source = "\n".join(lines[_get_class_start_line(class_node) - 1 : class_node.end_lineno]) - code_strings.append(CodeString(code=full_source, file_path=module_path)) + code_strings.append(CodeString(code=class_source, file_path=module_path)) extracted_classes.add((module_path, class_name)) emitted_class_names.add(class_name) @@ -951,7 +1446,7 @@ def extract_class_and_bases( if not is_project and not is_third_party: continue - mod_result = get_module_source_and_tree(module_path) + 
mod_result = _get_module_source_and_tree(module_path, module_cache) if mod_result is None: continue module_source, module_tree = mod_result @@ -964,7 +1459,7 @@ def extract_class_and_bases( extract_class_and_bases(resolved_class, module_path, module_source, module_tree) elif is_third_party: target_name = name - if not any(isinstance(n, ast.ClassDef) and n.name == name for n in ast.walk(module_tree)): + if _find_class_node_by_name(name, module_tree) is None: resolved_class = resolve_instance_class_name(name, module_tree) if resolved_class: target_name = resolved_class @@ -981,175 +1476,6 @@ def extract_class_and_bases( return CodeStringsMarkdown(code_strings=code_strings) -def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]: - """Import modules and resolve candidate (class_name, module_name) pairs to class objects.""" - import importlib - import inspect - - resolved: list[tuple[type, str]] = [] - module_cache: dict[str, object] = {} - - for class_name, module_name in candidates: - try: - module = module_cache.get(module_name) - if module is None: - module = importlib.import_module(module_name) - module_cache[module_name] = module - - cls = getattr(module, class_name, None) - if cls is not None and inspect.isclass(cls): - resolved.append((cls, class_name)) - except (ImportError, ModuleNotFoundError, AttributeError): - logger.debug(f"Failed to import {module_name}.{class_name}") - - return resolved - - -MAX_TRANSITIVE_DEPTH = 5 - - -def extract_classes_from_type_hint(hint: object) -> list[type]: - """Recursively extract concrete class objects from a type annotation. - - Unwraps Optional, Union, List, Dict, Callable, Annotated, etc. - Filters out builtins and typing module types. 
- """ - import typing - - classes: list[type] = [] - origin = getattr(hint, "__origin__", None) - args = getattr(hint, "__args__", None) - - if origin is not None and args: - for arg in args: - classes.extend(extract_classes_from_type_hint(arg)) - elif isinstance(hint, type): - module = getattr(hint, "__module__", "") - if module not in ("builtins", "typing", "typing_extensions", "types"): - classes.append(hint) - # Handle typing.Annotated on older Pythons where __origin__ may not be set - if hasattr(typing, "get_args") and origin is None and args is None: - try: - inner_args = typing.get_args(hint) - if inner_args: - for arg in inner_args: - classes.extend(extract_classes_from_type_hint(arg)) - except Exception: - pass - - return classes - - -def resolve_transitive_type_deps(cls: type) -> list[type]: - """Find external classes referenced in cls.__init__ type annotations. - - Returns classes from site-packages that have a custom __init__. - """ - import inspect - import typing - - try: - init_method = getattr(cls, "__init__") - hints = typing.get_type_hints(init_method) - except Exception: - return [] - - deps: list[type] = [] - for param_name, hint in hints.items(): - if param_name == "return": - continue - for dep_cls in extract_classes_from_type_hint(hint): - if dep_cls is cls: - continue - init_method = getattr(dep_cls, "__init__", None) - if init_method is None or init_method is object.__init__: - continue - try: - class_file = Path(inspect.getfile(dep_cls)) - except (OSError, TypeError): - continue - if not path_belongs_to_site_packages(class_file): - continue - deps.append(dep_cls) - - return deps - - -def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None: - """Extract a stub containing the class definition with only its __init__ method. 
- - Args: - cls: The class object to extract __init__ from - class_name: Name to use for the class in the stub - require_site_packages: If True, only extract from site-packages. If False, include stdlib too. - - """ - import inspect - import textwrap - - init_method = getattr(cls, "__init__", None) - if init_method is None or init_method is object.__init__: - return None - - try: - class_file = Path(inspect.getfile(cls)) - except (OSError, TypeError): - return None - - if require_site_packages and not path_belongs_to_site_packages(class_file): - return None - - try: - init_source = inspect.getsource(init_method) - init_source = textwrap.dedent(init_source) - except (OSError, TypeError): - return None - - parts = class_file.parts - if "site-packages" in parts: - idx = parts.index("site-packages") - class_file = Path(*parts[idx + 1 :]) - - class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ") - return CodeString(code=class_source, file_path=class_file) - - -def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool: - cached = cache.get(module_name) - if cached is not None: - return cached - is_project = _is_project_module(module_name, project_root_path) - cache[module_name] = is_project - return is_project - - -def is_project_path(module_path: Path | None, project_root_path: Path) -> bool: - if module_path is None: - return False - # site-packages must be checked first because .venv/site-packages is under project root - if path_belongs_to_site_packages(module_path): - return False - try: - module_path.resolve().relative_to(project_root_path.resolve()) - return True - except ValueError: - return False - - -def _is_project_module(module_name: str, project_root_path: Path) -> bool: - """Check if a module is part of the project (not external/stdlib).""" - import importlib.util - - try: - spec = importlib.util.find_spec(module_name) - except (ImportError, ModuleNotFoundError, ValueError): - return False 
- else: - if spec is None or spec.origin is None: - return False - return is_project_path(Path(spec.origin), project_root_path) - - def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: """Extract import statements needed for a class definition. @@ -1184,26 +1510,22 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, if isinstance(item.value.func, ast.Name): needed_names.add(item.value.func.id) - # Find imports that provide these names import_lines: list[str] = [] source_lines = module_source.split("\n") - added_imports: set[int] = set() # Track line numbers to avoid duplicates - + added_imports: set[int] = set() for node in module_tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - name = alias.asname if alias.asname else alias.name.split(".")[0] - if name in needed_names and node.lineno not in added_imports: - import_lines.append(source_lines[node.lineno - 1]) - added_imports.add(node.lineno) - break - elif isinstance(node, ast.ImportFrom): - for alias in node.names: - name = alias.asname if alias.asname else alias.name - if name in needed_names and node.lineno not in added_imports: - import_lines.append(source_lines[node.lineno - 1]) - added_imports.add(node.lineno) - break + if not isinstance(node, (ast.Import, ast.ImportFrom)) or node.lineno in added_imports: + continue + for alias in node.names: + name = ( + alias.asname + if alias.asname + else (alias.name.split(".")[0] if isinstance(node, ast.Import) else alias.name) + ) + if name in needed_names: + import_lines.append(source_lines[node.lineno - 1]) + added_imports.add(node.lineno) + break return "\n".join(import_lines) @@ -1225,12 +1547,7 @@ def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None: names.add(node.value.id) -def is_dunder_method(name: str) -> bool: - return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__") - - def 
remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: - """Removes the docstring from an indented block if it exists.""" if not isinstance(indented_block.body[0], cst.SimpleStatementLine): return indented_block first_stmt = indented_block.body[0].body[0] @@ -1239,46 +1556,73 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block +def _maybe_strip_docstring(node: cst.FunctionDef | cst.ClassDef, cfg: PruneConfig) -> cst.FunctionDef | cst.ClassDef: + if cfg.remove_docstrings and isinstance(node.body, cst.IndentedBlock): + return node.with_changes(body=remove_docstring_from_body(node.body)) + return node + + +class ImportCollector(ast.NodeVisitor): + def __init__(self) -> None: + self.imported_names: dict[str, str] = {} + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module: + for alias in node.names: + if alias.name != "*": + self.imported_names[alias.asname if alias.asname else alias.name] = node.module + + +@dataclass(frozen=True) +class PruneConfig: + defs_with_usages: dict[str, UsageInfo] | None = None + helpers: set[str] | None = None + remove_docstrings: bool = False + include_target_in_output: bool = True + exclude_init_from_targets: bool = False + keep_class_init: bool = False + include_dunder_methods: bool = False + include_init_dunder: bool = False + + def parse_code_and_prune_cst( - code: str, + code: str | cst.Module, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), # noqa: B006 remove_docstrings: bool = False, + defs_with_usages: dict[str, UsageInfo] | None = None, ) -> cst.Module: """Parse and filter the code CST, returning the pruned Module.""" - module = cst.parse_module(code) - defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) + module = code if isinstance(code, cst.Module) else cst.parse_module(code) if code_context_type == 
CodeContextType.READ_WRITABLE: - filtered_node, found_target = prune_cst( - module, target_functions, defs_with_usages=defs_with_usages, keep_class_init=True - ) + if defs_with_usages is None: + defs_with_usages = collect_top_level_defs_with_usages( + module, target_functions | helpers_of_helper_functions + ) + cfg = PruneConfig(defs_with_usages=defs_with_usages, keep_class_init=True) elif code_context_type == CodeContextType.READ_ONLY: - filtered_node, found_target = prune_cst( - module, - target_functions, + cfg = PruneConfig( helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, include_target_in_output=False, include_dunder_methods=True, ) elif code_context_type == CodeContextType.TESTGEN: - filtered_node, found_target = prune_cst( - module, - target_functions, + cfg = PruneConfig( helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, include_dunder_methods=True, include_init_dunder=True, ) elif code_context_type == CodeContextType.HASHING: - filtered_node, found_target = prune_cst( - module, target_functions, remove_docstrings=True, exclude_init_from_targets=True - ) + cfg = PruneConfig(remove_docstrings=True, exclude_init_from_targets=True) else: raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102 + filtered_node, found_target = prune_cst(module, target_functions, cfg) + if not found_target: raise ValueError("No target functions found in the provided code") if filtered_node and isinstance(filtered_node, cst.Module): @@ -1287,80 +1631,36 @@ def parse_code_and_prune_cst( def prune_cst( - node: cst.CSTNode, - target_functions: set[str], - prefix: str = "", - *, - defs_with_usages: dict[str, UsageInfo] | None = None, - helpers: set[str] | None = None, - remove_docstrings: bool = False, - include_target_in_output: bool = True, - exclude_init_from_targets: bool = False, - keep_class_init: bool = False, - include_dunder_methods: bool = False, - include_init_dunder: bool = False, + node: 
cst.CSTNode, target_functions: set[str], cfg: PruneConfig, prefix: str = "" ) -> tuple[cst.CSTNode | None, bool]: - """Unified function to prune CST nodes based on various filtering criteria. - - Args: - node: The CST node to filter - target_functions: Set of qualified function names that are targets - prefix: Current qualified name prefix (for class methods) - defs_with_usages: Dict of definitions with usage info (for READ_WRITABLE mode) - helpers: Set of helper function qualified names (for READ_ONLY/TESTGEN modes) - remove_docstrings: Whether to remove docstrings from output - include_target_in_output: Whether to include target functions in output - exclude_init_from_targets: Whether to exclude __init__ from targets (HASHING mode) - keep_class_init: Whether to keep __init__ methods in classes (READ_WRITABLE mode) - include_dunder_methods: Whether to include dunder methods (READ_ONLY/TESTGEN modes) - include_init_dunder: Whether to include __init__ in dunder methods - - Returns: - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. 
- - """ if isinstance(node, (cst.Import, cst.ImportFrom)): return None, False if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # Check if it's a helper function (higher priority than target) - if helpers and qualified_name in helpers: - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - return node.with_changes(body=remove_docstring_from_body(node.body)), True - return node, True + if cfg.helpers and qualified_name in cfg.helpers: + return _maybe_strip_docstring(node, cfg), True - # Check if it's a target function if qualified_name in target_functions: - # Handle exclude_init_from_targets for HASHING mode - if exclude_init_from_targets and node.name.value == "__init__": + if cfg.exclude_init_from_targets and node.name.value == "__init__": return None, False - - if include_target_in_output: - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - return node.with_changes(body=remove_docstring_from_body(node.body)), True - return node, True + if cfg.include_target_in_output: + return _maybe_strip_docstring(node, cfg), True return None, True - # Handle class __init__ for READ_WRITABLE mode - if keep_class_init and node.name.value == "__init__": + if cfg.keep_class_init and node.name.value == "__init__": return node, False - # Handle dunder methods for READ_ONLY/TESTGEN modes if ( - include_dunder_methods + cfg.include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__") ): - if not include_init_dunder and node.name.value == "__init__": + if not cfg.include_init_dunder and node.name.value == "__init__": return None, False - if remove_docstrings and isinstance(node.body, cst.IndentedBlock): - return node.with_changes(body=remove_docstring_from_body(node.body)), False - return node, False + return _maybe_strip_docstring(node, cfg), False return None, False @@ -1369,62 +1669,41 @@ def prune_cst( return None, 
False if not isinstance(node.body, cst.IndentedBlock): raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = node.name.value class_name = node.name.value # Handle dependency classes for READ_WRITABLE mode - if defs_with_usages: - # Check if this class contains any target functions + if cfg.defs_with_usages: has_target_functions = any( - isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions + isinstance(stmt, cst.FunctionDef) and f"{class_name}.{stmt.name.value}" in target_functions for stmt in node.body.body ) - - # If the class is used as a dependency (not containing target functions), keep it entirely if ( not has_target_functions - and class_name in defs_with_usages - and defs_with_usages[class_name].used_by_qualified_function + and class_name in cfg.defs_with_usages + and cfg.defs_with_usages[class_name].used_by_qualified_function ): return node, True - # Recursively filter each statement in the class body new_class_body: list[cst.CSTNode] = [] found_in_class = False for stmt in node.body.body: - filtered, found_target = prune_cst( - stmt, - target_functions, - class_prefix, - defs_with_usages=defs_with_usages, - helpers=helpers, - remove_docstrings=remove_docstrings, - include_target_in_output=include_target_in_output, - exclude_init_from_targets=exclude_init_from_targets, - keep_class_init=keep_class_init, - include_dunder_methods=include_dunder_methods, - include_init_dunder=include_init_dunder, - ) + filtered, found_target = prune_cst(stmt, target_functions, cfg, class_name) found_in_class |= found_target if filtered: new_class_body.append(filtered) if not found_in_class: return None, False - - # Apply docstring removal to class if needed - if remove_docstrings and new_class_body: - updated_body = node.body.with_changes(body=new_class_body) - assert isinstance(updated_body, cst.IndentedBlock) - return node.with_changes(body=remove_docstring_from_body(updated_body)), True - - return 
node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True + if not new_class_body: + return None, True + updated = node.with_changes(body=node.body.with_changes(body=new_class_body)) + return _maybe_strip_docstring(updated, cfg), True # Handle assignments for READ_WRITABLE mode - if defs_with_usages is not None: + if cfg.defs_with_usages is not None: if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): - if is_assignment_used(node, defs_with_usages): + if is_assignment_used(node, cfg.defs_with_usages): return node, True return None, False @@ -1433,66 +1712,14 @@ def prune_cst( if not section_names: return node, False - if helpers is not None: - return recurse_sections( - node, - section_names, - lambda child: prune_cst( - child, - target_functions, - prefix, - defs_with_usages=defs_with_usages, - helpers=helpers, - remove_docstrings=remove_docstrings, - include_target_in_output=include_target_in_output, - exclude_init_from_targets=exclude_init_from_targets, - keep_class_init=keep_class_init, - include_dunder_methods=include_dunder_methods, - include_init_dunder=include_init_dunder, - ), - keep_non_target_children=True, - ) return recurse_sections( node, section_names, - lambda child: prune_cst( - child, - target_functions, - prefix, - defs_with_usages=defs_with_usages, - helpers=helpers, - remove_docstrings=remove_docstrings, - include_target_in_output=include_target_in_output, - exclude_init_from_targets=exclude_init_from_targets, - keep_class_init=keep_class_init, - include_dunder_methods=include_dunder_methods, - include_init_dunder=include_init_dunder, - ), + lambda child: prune_cst(child, target_functions, cfg, prefix), + keep_non_target_children=cfg.helpers is not None, ) -def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool: - """Check if the given name belongs to the specified method.""" - return belongs_to_function(name, method_name) and belongs_to_class(name, class_name) - - 
-def belongs_to_function(name: Name, function_name: str) -> bool: - """Check if the given jedi Name is a direct child of the specified function.""" - if name.name == function_name: # Handles function definition and recursive function calls - return False - if (name := name.parent()) and name.type == "function": - return bool(name.name == function_name) - return False - - -def belongs_to_class(name: Name, class_name: str) -> bool: - """Check if given jedi Name is a direct child of the specified class.""" - while name := name.parent(): - if name.type == "class": - return bool(name.name == class_name) - return False - - def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool: """Check if the given jedi Name is a direct child of the specified function, matched by qualified function name.""" try: diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index fff255f67..2c19be8d3 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -54,20 +54,16 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]: def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool: if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - lookup = f"{name_prefix}{name}" if name_prefix else name - if lookup in definitions and definitions[lookup].used_by_qualified_function: - return True + targets = [target.target for target in node.targets] + elif isinstance(node, (cst.AnnAssign, cst.AugAssign)): + targets = [node.target] + else: return False - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: + for target in targets: + for name in extract_names_from_targets(target): lookup = 
f"{name_prefix}{name}" if name_prefix else name if lookup in definitions and definitions[lookup].used_by_qualified_function: return True - return False return False @@ -119,84 +115,43 @@ def collect_top_level_definitions( node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None ) -> dict[str, UsageInfo]: """Recursively collect all top-level variable, function, and class definitions.""" - # Locally bind types and helpers for faster lookup - FunctionDef = cst.FunctionDef # noqa: N806 - ClassDef = cst.ClassDef # noqa: N806 - Assign = cst.Assign # noqa: N806 - AnnAssign = cst.AnnAssign # noqa: N806 - AugAssign = cst.AugAssign # noqa: N806 - IndentedBlock = cst.IndentedBlock # noqa: N806 - if definitions is None: definitions = {} - # Speed: Single isinstance+local var instead of several type calls - node_type = type(node) - # Fast path: function def - if node_type is FunctionDef: + if isinstance(node, cst.FunctionDef): name = node.name.value - definitions[name] = UsageInfo( - name=name, - used_by_qualified_function=False, # Will be marked later if in qualified functions - ) + definitions[name] = UsageInfo(name=name) return definitions - # Fast path: class def - if node_type is ClassDef: + if isinstance(node, cst.ClassDef): name = node.name.value definitions[name] = UsageInfo(name=name) - - # Collect class methods - body = getattr(node, "body", None) - if body is not None and type(body) is IndentedBlock: - statements = body.body - # Precompute f-string template for efficiency + if isinstance(node.body, cst.IndentedBlock): prefix = name + "." 
- for statement in statements: - if type(statement) is FunctionDef: + for statement in node.body.body: + if isinstance(statement, cst.FunctionDef): method_name = prefix + statement.name.value definitions[method_name] = UsageInfo(name=method_name) - return definitions - # Fast path: assignment - if node_type is Assign: - # Inline extract_names_from_targets for single-target speed - targets = node.targets - append_def = definitions.__setitem__ - for target in targets: - names = extract_names_from_targets(target.target) - for name in names: - append_def(name, UsageInfo(name=name)) + if isinstance(node, cst.Assign): + for target in node.targets: + for name in extract_names_from_targets(target.target): + definitions[name] = UsageInfo(name=name) return definitions - if node_type is AnnAssign or node_type is AugAssign: - tgt = node.target - if type(tgt) is cst.Name: - name = tgt.value + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + for name in extract_names_from_targets(node.target): definitions[name] = UsageInfo(name=name) - else: - names = extract_names_from_targets(tgt) - for name in names: - definitions[name] = UsageInfo(name=name) return definitions - # Recursively process children. 
Takes care of top level assignments in if/else/while/for blocks - section_names = get_section_names(node) - - if section_names: - getattr_ = getattr - for section in section_names: - original_content = getattr_(node, section, None) - # Instead of isinstance check for list/tuple, rely on duck-type via iter - # If section contains a list of nodes - if isinstance(original_content, (list, tuple)): - defs = definitions # Move out for minor speed - for child in original_content: - collect_top_level_definitions(child, defs) - # If section contains a single node - elif original_content is not None: - collect_top_level_definitions(original_content, definitions) + for section in get_section_names(node): + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + for child in original_content: + collect_top_level_definitions(child, definitions) + elif original_content is not None: + collect_top_level_definitions(original_content, definitions) return definitions @@ -237,43 +192,24 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # Regular top-level function self.current_top_level_name = function_name - # Check parameter type annotations for dependencies - if hasattr(node, "params") and node.params: - for param in node.params.params: - if param.annotation: - # Visit the annotation to extract dependencies - self._collect_annotation_dependencies(param.annotation) + for param in node.params.params: + if param.annotation: + self._extract_names_from_annotation(param.annotation.annotation) self.function_depth += 1 - def _collect_annotation_dependencies(self, annotation: cst.Annotation) -> None: - """Extract dependencies from type annotations.""" - if hasattr(annotation, "annotation"): - # Extract names from annotation (could be Name, Attribute, Subscript, etc.) 
- self._extract_names_from_annotation(annotation.annotation) - def _extract_names_from_annotation(self, node: cst.CSTNode) -> None: - """Extract names from a type annotation node.""" - # Simple name reference like 'int', 'str', or custom type if isinstance(node, cst.Name): name = node.value if name in self.definitions and name != self.current_top_level_name and self.current_top_level_name: self.definitions[self.current_top_level_name].dependencies.add(name) - - # Handle compound annotations like List[int], Dict[str, CustomType], etc. elif isinstance(node, cst.Subscript): - if hasattr(node, "value"): - self._extract_names_from_annotation(node.value) - if hasattr(node, "slice"): - for slice_item in node.slice: - if hasattr(slice_item, "slice"): - self._extract_names_from_annotation(slice_item.slice) - - # Handle attribute access like module.Type + self._extract_names_from_annotation(node.value) + for slice_item in node.slice: + if hasattr(slice_item, "slice"): + self._extract_names_from_annotation(slice_item.slice) elif isinstance(node, cst.Attribute): - if hasattr(node, "value"): - self._extract_names_from_annotation(node.value) - # No need to check the attribute name itself as it's likely not a top-level definition + self._extract_names_from_annotation(node.value) def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: self.function_depth -= 1 @@ -334,21 +270,16 @@ def leave_Assign(self, original_node: cst.Assign) -> None: self.current_top_level_name = "" def visit_AnnAssign(self, node: cst.AnnAssign) -> None: - # Extract names from the variable annotations - if hasattr(node, "annotation") and node.annotation: - # First mark we're processing a variable to avoid recording it as a dependency of itself - self.processing_variable = True - if isinstance(node.target, cst.Name): - self.current_variable_names.add(node.target.value) - else: - self.current_variable_names.update(extract_names_from_targets(node.target)) + self.processing_variable = True + if 
isinstance(node.target, cst.Name): + self.current_variable_names.add(node.target.value) + else: + self.current_variable_names.update(extract_names_from_targets(node.target)) - # Process the annotation - self._collect_annotation_dependencies(node.annotation) + self._extract_names_from_annotation(node.annotation.annotation) - # Reset processing state - self.processing_variable = False - self.current_variable_names.clear() + self.processing_variable = False + self.current_variable_names.clear() def visit_Name(self, node: cst.Name) -> None: name = node.value @@ -406,18 +337,8 @@ def _expand_qualified_functions(self) -> set[str]: def mark_used_definitions(self) -> None: """Find all qualified functions and mark them and their dependencies as used.""" - # Avoid list comprehension for set intersection - expanded_names = self.expanded_qualified_functions defs = self.definitions - # Use set intersection but only if defs.keys is a set (Python 3.12 dict_keys supports it efficiently) - fnames = ( - expanded_names & defs.keys() - if isinstance(expanded_names, set) - else [name for name in expanded_names if name in defs] - ) - - # For each specified function, mark it and all its dependencies as used - for func_name in fnames: + for func_name in self.expanded_qualified_functions & defs.keys(): defs[func_name].used_by_qualified_function = True for dep in defs[func_name].dependencies: self.mark_as_used_recursively(dep) @@ -442,17 +363,7 @@ def remove_unused_definitions_recursively( ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node to remove unused definitions. - Args: - ---- - node: The CST node to process - definitions: Dictionary of definition info - - Returns: - ------- - (filtered_node, used_by_function): - filtered_node: The modified CST node or None if it should be removed - used_by_function: True if this node or any child is used by qualified functions - + Returns (filtered_node_or_None, used_by_function). 
""" # Skip import statements if isinstance(node, (cst.Import, cst.ImportFrom)): @@ -462,50 +373,25 @@ def remove_unused_definitions_recursively( if isinstance(node, cst.FunctionDef): return node, True - # Never remove class definitions if isinstance(node, cst.ClassDef): class_name = node.name.value + class_has_dependencies = class_name in definitions and definitions[class_name].used_by_qualified_function - # Check if any methods or variables in this class are used - method_or_var_used = False - class_has_dependencies = False - - # Check if class itself is marked as used - if class_name in definitions and definitions[class_name].used_by_qualified_function: - class_has_dependencies = True - - if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): - updates = {} + if isinstance(node.body, cst.IndentedBlock): new_statements = [] - for statement in node.body.body: - # Keep all function definitions if isinstance(statement, cst.FunctionDef): - method_name = f"{class_name}.{statement.name.value}" - if method_name in definitions and definitions[method_name].used_by_qualified_function: - method_or_var_used = True new_statements.append(statement) - # Only process variable assignments elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): - var_used = False - - if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."): - var_used = True - method_or_var_used = True - - if var_used or class_has_dependencies: + if class_has_dependencies or is_assignment_used( + statement, definitions, name_prefix=f"{class_name}." 
+ ): new_statements.append(statement) else: - # Keep all other statements in the class new_statements.append(statement) + return node.with_changes(body=node.body.with_changes(body=new_statements)), True - # Update the class body - new_body = node.body.with_changes(body=new_statements) - updates["body"] = new_body - - return node.with_changes(**updates), True - - return node, method_or_var_used or class_has_dependencies + return node, class_has_dependencies # Handle assignments (Assign, AnnAssign, AugAssign) if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): @@ -541,69 +427,52 @@ def collect_top_level_defs_with_usages( return definitions -def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str: - """Analyze a file and remove top level definitions not used by specified functions. - - Top level definitions, in this context, are only classes, variables or functions. - If a class is referenced by a qualified function, we keep the entire class. - - Args: - ---- - code: The code to process - qualified_function_names: Set of function names to keep. 
For methods, use format 'classname.methodname' - - """ +def remove_unused_definitions_by_function_names( + code: Union[str, cst.Module], + qualified_function_names: set[str], + defs_with_usages: dict[str, UsageInfo] | None = None, +) -> cst.Module: + """Remove top-level definitions (classes, variables, functions) not used by the specified qualified function names.""" try: - module = cst.parse_module(code) + module = code if isinstance(code, cst.Module) else cst.parse_module(code) except Exception as e: logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}") - return code + return code if isinstance(code, cst.Module) else cst.parse_module("") try: - defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names) + if defs_with_usages is None: + defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names) # Apply the recursive removal transformation modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages) - return modified_module.code if modified_module else "" + return modified_module if modified_module else cst.parse_module("") except Exception as e: # If any other error occurs during processing, return the original code logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}") - return code + return module def revert_unused_helper_functions( project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] ) -> None: - """Revert unused helper functions back to their original definitions. 
- - Args: - project_root: project_root - unused_helpers: List of unused helper functions to revert - original_helper_code: Dictionary mapping file paths to their original code - - """ + """Revert unused helper functions back to their original definitions.""" if not unused_helpers: return logger.debug(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions") - # Resolve all path keys for consistent comparison (Windows 8.3 short names may differ from Jedi-resolved paths) + # Resolve path keys for consistent comparison (Windows 8.3 short names may differ from Jedi-resolved paths) resolved_original_helper_code = {p.resolve(): code for p, code in original_helper_code.items()} - # Group unused helpers by file path unused_helpers_by_file = defaultdict(list) for helper in unused_helpers: unused_helpers_by_file[helper.file_path.resolve()].append(helper) - # For each file, revert the unused helper functions to their original definitions for file_path, helpers_in_file in unused_helpers_by_file.items(): if file_path in resolved_original_helper_code: try: - # Get original code for this file original_code = resolved_original_helper_code[file_path] - - # Use the code replacer to selectively revert only the unused helper functions helper_names = [helper.qualified_name for helper in helpers_in_file] reverted_code = replace_function_definitions_in_module( function_names=helper_names, @@ -611,11 +480,11 @@ def revert_unused_helper_functions( code_strings=[ CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root)) ] - ), # Use original code as the "optimized" code to revert + ), module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, - should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. 
+ should_add_global_assignments=False, # file already has global assignments from the optimization pass ) if reverted_code: @@ -628,16 +497,7 @@ def revert_unused_helper_functions( def _analyze_imports_in_optimized_code( optimized_ast: ast.AST, code_context: CodeOptimizationContext ) -> dict[str, set[str]]: - """Analyze import statements in optimized code to map imported names to qualified helper names. - - Args: - optimized_ast: The AST of the optimized code - code_context: The code optimization context containing helper functions - - Returns: - Dictionary mapping imported names to sets of possible qualified helper names - - """ + """Map imported names to qualified helper names based on import statements in optimized code.""" imported_names_map = defaultdict(set) # Precompute a two-level dict: module_name -> func_name -> [helpers] @@ -652,23 +512,7 @@ def _analyze_imports_in_optimized_code( helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper) helpers_by_file[module_name].append(helper) - # Collect only import nodes to avoid per-node isinstance checks across the whole AST - class _ImportCollector(ast.NodeVisitor): - def __init__(self) -> None: - self.nodes: list[ast.AST] = [] - - def visit_Import(self, node: ast.Import) -> None: - self.nodes.append(node) - # No need to recurse further for import nodes - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - self.nodes.append(node) - # No need to recurse further for import-from nodes - - collector = _ImportCollector() - collector.visit(optimized_ast) - - for node in collector.nodes: + for node in ast.walk(optimized_ast): if isinstance(node, ast.ImportFrom): # Handle "from module import function" statements module_name = node.module @@ -731,22 +575,53 @@ def find_target_node( return None +def _collect_attr_names( + value_id: str, attr_name: str, class_name: str | None, names: set[str], imported_names_map: dict[str, set[str]] +) -> None: + if value_id == "self": + names.add(attr_name) 
+ if class_name: + names.add(f"{class_name}.{attr_name}") + else: + names.add(attr_name) + full_ref = f"{value_id}.{attr_name}" + names.add(full_ref) + mapped_names = imported_names_map.get(full_ref) + if mapped_names: + names.update(mapped_names) + + +def _collect_called_names( + entrypoint_ast: ast.FunctionDef | ast.AsyncFunctionDef, + function_to_optimize: FunctionToOptimize, + imported_names_map: dict[str, set[str]], +) -> set[str]: + called = {function_to_optimize.function_name} + class_name = function_to_optimize.parents[0].name if function_to_optimize.parents else None + + for node in ast.walk(entrypoint_ast): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + called.add(node.func.id) + mapped_names = imported_names_map.get(node.func.id) + if mapped_names: + called.update(mapped_names) + elif isinstance(node.func, ast.Attribute): + if isinstance(node.func.value, ast.Name): + _collect_attr_names(node.func.value.id, node.func.attr, class_name, called, imported_names_map) + else: + called.add(node.func.attr) + elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + _collect_attr_names(node.value.id, node.attr, class_name, called, imported_names_map) + + return called + + def detect_unused_helper_functions( function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str | CodeStringsMarkdown, ) -> list[FunctionSource]: - """Detect helper functions that are no longer called by the optimized entrypoint function. 
- - Args: - function_to_optimize: The function to optimize - code_context: The code optimization context containing helper functions - optimized_code: The optimized code to analyze - - Returns: - List of FunctionSource objects representing unused helper functions - - """ # Skip this analysis for non-Python languages since we use Python's ast module if current_language() != Language.PYTHON: logger.debug("Skipping unused helper function detection for non-Python languages") @@ -761,107 +636,43 @@ def detect_unused_helper_functions( ) try: - # Parse the optimized code to analyze function calls and imports optimized_ast = ast.parse(optimized_code) - - # Find the optimized entrypoint function entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize) if not entrypoint_function_ast: logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code") return [] - # First, analyze imports to build a mapping of imported names to their original qualified names imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context) - - # Extract all function calls and attribute references in the entrypoint function - called_function_names = {function_to_optimize.function_name} - for node in ast.walk(entrypoint_function_ast): - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Name): - # Regular function call: function_name() - called_name = node.func.id - called_function_names.add(called_name) - # Also add the qualified name if this is an imported function - mapped_names = imported_names_map.get(called_name) - if mapped_names: - called_function_names.update(mapped_names) - elif isinstance(node.func, ast.Attribute): - # Method call: obj.method() or self.method() or module.function() - if isinstance(node.func.value, ast.Name): - attr_name = node.func.attr - value_id = node.func.value.id - if value_id == "self": - # self.method_name() -> add both method_name and ClassName.method_name - 
called_function_names.add(attr_name) - # For class methods, also add the qualified name - if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: - class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{attr_name}") - else: - called_function_names.add(attr_name) - full_call = f"{value_id}.{attr_name}" - called_function_names.add(full_call) - # Check if this is a module.function call that maps to a helper - mapped_names = imported_names_map.get(full_call) - if mapped_names: - called_function_names.update(mapped_names) - # Handle nested attribute access like obj.attr.method() - else: - called_function_names.add(node.func.attr) - elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): - # Attribute reference without call: e.g. self._parse1 = self._parse_literal - # This covers methods used as callbacks, stored in variables, passed as arguments, etc. - attr_name = node.attr - value_id = node.value.id - if value_id == "self": - called_function_names.add(attr_name) - if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: - class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{attr_name}") - else: - called_function_names.add(attr_name) - full_ref = f"{value_id}.{attr_name}" - called_function_names.add(full_ref) - mapped_names = imported_names_map.get(full_ref) - if mapped_names: - called_function_names.update(mapped_names) + called_function_names = _collect_called_names(entrypoint_function_ast, function_to_optimize, imported_names_map) logger.debug(f"Functions called in optimized entrypoint: {called_function_names}") logger.debug(f"Imported names mapping: {imported_names_map}") - # Find helper functions that are no longer called unused_helpers = [] entrypoint_file_path = function_to_optimize.file_path for helper_function in code_context.helper_functions: - jedi_type = helper_function.definition_type - if jedi_type != "class": # 
Include when definition_type is None (non-Python) - # Check if the helper function is called using multiple name variants - helper_qualified_name = helper_function.qualified_name - helper_simple_name = helper_function.only_function_name - helper_fully_qualified_name = helper_function.fully_qualified_name - - # Check membership efficiently - exit early on first match - if ( - helper_qualified_name in called_function_names - or helper_simple_name in called_function_names - or helper_fully_qualified_name in called_function_names - ): - is_called = True - # For cross-file helpers, also consider module-based calls - elif helper_function.file_path != entrypoint_file_path: - # Add potential module.function combinations - module_name = helper_function.file_path.stem - module_call = f"{module_name}.{helper_simple_name}" - is_called = module_call in called_function_names - else: - is_called = False + if helper_function.definition_type == "class": + continue + helper_qualified_name = helper_function.qualified_name + helper_simple_name = helper_function.only_function_name + helper_fully_qualified_name = helper_function.fully_qualified_name + + is_called = ( + helper_qualified_name in called_function_names + or helper_simple_name in called_function_names + or helper_fully_qualified_name in called_function_names + or ( + helper_function.file_path != entrypoint_file_path + and f"{helper_function.file_path.stem}.{helper_simple_name}" in called_function_names + ) + ) - if not is_called: - unused_helpers.append(helper_function) - logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") - else: - logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") + if not is_called: + unused_helpers.append(helper_function) + logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") + else: + logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") except Exception 
as e: logger.debug(f"Error detecting unused helper functions: {e}") diff --git a/codeflash/languages/python/reference_graph.py b/codeflash/languages/python/reference_graph.py index 46867a74d..949d1c10b 100644 --- a/codeflash/languages/python/reference_graph.py +++ b/codeflash/languages/python/reference_graph.py @@ -3,20 +3,21 @@ import hashlib import os import sqlite3 -from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages from codeflash.languages.base import IndexResult -from codeflash.models.models import FunctionSource if TYPE_CHECKING: from collections.abc import Callable, Iterable from jedi.api.classes import Name + from codeflash.models.call_graph import CallGraph + from codeflash.models.models import FunctionSource + # --------------------------------------------------------------------------- # Module-level helpers (must be top-level for ProcessPoolExecutor pickling) @@ -41,7 +42,7 @@ def _init_index_worker(project_root: str) -> None: def _resolve_definitions(ref: Name) -> list[Name]: try: inferred = ref.infer() - valid = [d for d in inferred if d.type in ("function", "class")] + valid = [d for d in inferred if d.type in ("function", "class", "statement")] if valid: return valid except Exception: @@ -68,7 +69,7 @@ def _is_valid_definition(definition: Name, caller_qualified_name: str, project_r if not definition.full_name or not definition.full_name.startswith(definition.module_name): return False - if definition.type not in ("function", "class"): + if definition.type not in ("function", "class", "statement"): return False try: @@ -163,6 +164,20 @@ def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str) definition.get_line_code(), ) ) + elif definition.type == "statement": + callee_qn = get_qualified_name(definition.module_name, definition.full_name) + if 
len(callee_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + callee_qn, + definition.full_name, + definition.name, + definition.type, + definition.get_line_code(), + ) + ) except Exception: continue @@ -199,6 +214,7 @@ def __init__(self, project_root: Path, language: str = "python", db_path: Path | self.conn = sqlite3.connect(str(db_path)) self.conn.execute("PRAGMA journal_mode=WAL") self.indexed_file_hashes: dict[str, str] = {} + self._resolved_paths: dict[Path, str] = {} self._init_schema() def _init_schema(self) -> None: @@ -259,59 +275,28 @@ def _init_schema(self) -> None: ) self.conn.commit() + def resolve_path(self, file_path: Path) -> str: + cached = self._resolved_paths.get(file_path) + if cached is not None: + return cached + resolved = str(file_path.resolve()) + self._resolved_paths[file_path] = resolved + return resolved + def get_callees( self, file_path_to_qualified_names: dict[Path, set[str]] ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: - file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set) - function_source_list: list[FunctionSource] = [] - - all_caller_keys: list[tuple[str, str]] = [] - for file_path, qualified_names in file_path_to_qualified_names.items(): - resolved = str(file_path.resolve()) - self.ensure_file_indexed(file_path, resolved) - all_caller_keys.extend((resolved, qn) for qn in qualified_names) - - if not all_caller_keys: - return file_path_to_function_source, function_source_list - - cur = self.conn.cursor() - cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)") - cur.execute("DELETE FROM _caller_keys") - cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys) - - rows = cur.execute( - """ - SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name, - ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line - FROM call_edges ce - INNER JOIN _caller_keys ck - 
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name - WHERE ce.project_root = ? AND ce.language = ? - """, - (self.project_root_str, self.language), - ).fetchall() - - for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows: - callee_path = Path(callee_file) - fs = FunctionSource( - file_path=callee_path, - qualified_name=callee_qn, - fully_qualified_name=callee_fqn, - only_function_name=callee_name, - source_code=callee_src, - definition_type=callee_type, - ) - file_path_to_function_source[callee_path].add(fs) - function_source_list.append(fs) + from codeflash.models.call_graph import callees_from_graph - return file_path_to_function_source, function_source_list + graph = self.get_call_graph(file_path_to_qualified_names, include_metadata=True) + return callees_from_graph(graph) def count_callees_per_function( self, file_path_to_qualified_names: dict[Path, set[str]] ) -> dict[tuple[Path, str], int]: all_caller_keys: list[tuple[Path, str, str]] = [] for file_path, qualified_names in file_path_to_qualified_names.items(): - resolved = str(file_path.resolve()) + resolved = self.resolve_path(file_path) self.ensure_file_indexed(file_path, resolved) all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) @@ -346,9 +331,10 @@ def count_callees_per_function( def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> IndexResult: if resolved is None: - resolved = str(file_path.resolve()) + resolved = self.resolve_path(file_path) - # Always read and hash the file before checking the cache so we detect on-disk changes + # Always read and hash to detect on-disk changes (no in-memory shortcut here; + # build_index has its own fast path for batch initialization) try: content = file_path.read_text(encoding="utf-8") except Exception: @@ -363,7 +349,7 @@ def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> I def index_file(self, file_path: Path, 
file_hash: str, resolved: str | None = None) -> IndexResult: if resolved is None: - resolved = str(file_path.resolve()) + resolved = self.resolve_path(file_path) edges, had_error = _analyze_file(file_path, self.jedi_project, self.project_root_str) if had_error: logger.debug(f"ReferenceGraph: failed to parse {file_path}") @@ -426,7 +412,17 @@ def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexRe to_index: list[tuple[Path, str, str]] = [] for file_path in file_paths: - resolved = str(file_path.resolve()) + resolved = self.resolve_path(file_path) + + # Fast path: already indexed this session + if resolved in self.indexed_file_hashes: + self._report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False + ), + ) + continue try: content = file_path.read_text(encoding="utf-8") @@ -441,7 +437,7 @@ def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexRe file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() - # Check if already cached (in-memory or DB) + # Check if cached in DB if self._is_file_cached(resolved, file_hash): self._report_progress( on_progress, @@ -540,5 +536,90 @@ def _fallback_sequential_index( result = self.index_file(file_path, file_hash, resolved) self._report_progress(on_progress, result) + def get_call_graph( + self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False + ) -> CallGraph: + from codeflash.models.call_graph import CallEdge, CalleeMetadata, CallGraph, FunctionNode + + all_caller_keys: list[tuple[Path, str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = self.resolve_path(file_path) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return CallGraph(edges=[]) + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE 
IF NOT EXISTS _graph_keys (caller_file TEXT, caller_qualified_name TEXT)") + cur.execute("DELETE FROM _graph_keys") + cur.executemany( + "INSERT INTO _graph_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys] + ) + + if include_metadata: + rows = cur.execute( + """ + SELECT ce.caller_file, ce.caller_qualified_name, + ce.callee_file, ce.callee_qualified_name, + ce.callee_fully_qualified_name, ce.callee_only_function_name, + ce.callee_definition_type, ce.callee_source_line + FROM call_edges ce + INNER JOIN _graph_keys gk + ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name + WHERE ce.project_root = ? AND ce.language = ? + """, + (self.project_root_str, self.language), + ).fetchall() + + edges: list[CallEdge] = [] + for ( + caller_file, + caller_qn, + callee_file, + callee_qn, + callee_fqn, + callee_name, + callee_type, + callee_src, + ) in rows: + edges.append( + CallEdge( + caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn), + callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn), + is_cross_file=caller_file != callee_file, + callee_metadata=CalleeMetadata( + fully_qualified_name=callee_fqn, + only_function_name=callee_name, + definition_type=callee_type, + source_line=callee_src, + ), + ) + ) + else: + rows = cur.execute( + """ + SELECT ce.caller_file, ce.caller_qualified_name, + ce.callee_file, ce.callee_qualified_name + FROM call_edges ce + INNER JOIN _graph_keys gk + ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name + WHERE ce.project_root = ? AND ce.language = ? 
+ """, + (self.project_root_str, self.language), + ).fetchall() + + edges = [] + for caller_file, caller_qn, callee_file, callee_qn in rows: + edges.append( + CallEdge( + caller=FunctionNode(file_path=Path(caller_file), qualified_name=caller_qn), + callee=FunctionNode(file_path=Path(callee_file), qualified_name=callee_qn), + is_cross_file=caller_file != callee_file, + ) + ) + + return CallGraph(edges=edges) + def close(self) -> None: self.conn.close() diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 225a483cb..1b315d629 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -2,13 +2,10 @@ import ast import time -from dataclasses import dataclass from importlib.util import find_spec from itertools import chain -from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional -import jedi import libcst as cst from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor @@ -17,9 +14,11 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW from codeflash.languages.base import Language -from codeflash.models.models import CodePosition, FunctionParent +from codeflash.models.models import FunctionParent if TYPE_CHECKING: + from pathlib import Path + from libcst.helpers import ModuleNameAndPackage from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -33,27 +32,16 @@ def __init__(self) -> None: super().__init__() self.functions: dict[str, cst.FunctionDef] = {} self.function_order: list[str] = [] - self.scope_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: - if self.scope_depth == 0: - # Module-level function - name = node.name.value - 
self.functions[name] = node - if name not in self.function_order: - self.function_order.append(name) - self.scope_depth += 1 - return True - - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: - self.scope_depth -= 1 + name = node.name.value + self.functions[name] = node + if name not in self.function_order: + self.function_order.append(name) + return False def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: - self.scope_depth += 1 - return True - - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: - self.scope_depth -= 1 + return False class GlobalFunctionTransformer(cst.CSTTransformer): @@ -64,29 +52,19 @@ def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order self.new_functions = new_functions self.new_function_order = new_function_order self.processed_functions: set[str] = set() - self.scope_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - self.scope_depth += 1 + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + return False def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - self.scope_depth -= 1 - if self.scope_depth > 0: - return updated_node - - # Check if this is a module-level function we need to replace name = original_node.name.value if name in self.new_functions: self.processed_functions.add(name) return self.new_functions[name] return updated_node - def visit_ClassDef(self, node: cst.ClassDef) -> None: - self.scope_depth += 1 - - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: - self.scope_depth -= 1 - return updated_node + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + return False def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # Add any new functions that weren't in the original file @@ -138,23 +116,13 @@ def __init__(self) -> None: super().__init__() self.assignments: dict[str, 
cst.Assign | cst.AnnAssign] = {} self.assignment_order: list[str] = [] - # Track scope depth to identify global assignments - self.scope_depth = 0 self.if_else_depth = 0 def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: - self.scope_depth += 1 - return True - - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: - self.scope_depth -= 1 + return False def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: - self.scope_depth += 1 - return True - - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: - self.scope_depth -= 1 + return False def visit_If(self, node: cst.If) -> Optional[bool]: self.if_else_depth += 1 @@ -163,13 +131,8 @@ def visit_If(self, node: cst.If) -> Optional[bool]: def leave_If(self, original_node: cst.If) -> None: self.if_else_depth -= 1 - def visit_Else(self, node: cst.Else) -> Optional[bool]: - # Else blocks are already counted as part of the if statement - return True - def visit_Assign(self, node: cst.Assign) -> Optional[bool]: - # Only process global assignments (not inside functions, classes, etc.) 
- if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level + if self.if_else_depth == 0: for target in node.targets: if isinstance(target.target, cst.Name): name = target.target.value @@ -179,14 +142,7 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]: return True def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]: - # Handle annotated assignments like: _CACHE: Dict[str, int] = {} - # Only process module-level annotated assignments with a value - if ( - self.scope_depth == 0 - and self.if_else_depth == 0 - and isinstance(node.target, cst.Name) - and node.value is not None - ): + if self.if_else_depth == 0 and isinstance(node.target, cst.Name) and node.value is not None: name = node.target.value self.assignments[name] = node if name not in self.assignment_order: @@ -229,22 +185,13 @@ def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_a self.new_assignments = new_assignments self.new_assignment_order = new_assignment_order self.processed_assignments: set[str] = set() - self.scope_depth = 0 self.if_else_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - self.scope_depth += 1 - - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - self.scope_depth -= 1 - return updated_node - - def visit_ClassDef(self, node: cst.ClassDef) -> None: - self.scope_depth += 1 + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + return False - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: - self.scope_depth -= 1 - return updated_node + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + return False def visit_If(self, node: cst.If) -> None: self.if_else_depth += 1 @@ -253,12 +200,8 @@ def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: self.if_else_depth -= 1 return updated_node - def visit_Else(self, node: cst.Else) -> None: - # Else blocks are 
already counted as part of the if statement - pass - def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode: - if self.scope_depth > 0 or self.if_else_depth > 0: + if self.if_else_depth > 0: return updated_node # Check if this is a global assignment we need to replace @@ -272,7 +215,7 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c return updated_node def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode: - if self.scope_depth > 0 or self.if_else_depth > 0: + if self.if_else_depth > 0: return updated_node # Check if this is a global annotated assignment we need to replace @@ -357,15 +300,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c class GlobalStatementTransformer(cst.CSTTransformer): - """Transformer that appends global statements at the end of the module. - - This ensures that global statements (like function calls at module level) are placed - after all functions, classes, and assignments they might reference, preventing NameError - at module load time. - - This transformer should be run LAST after GlobalFunctionTransformer and - GlobalAssignmentTransformer have already added their content. - """ + """Appends global statements at the end of the module. 
Run LAST after other transformers.""" def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None: super().__init__() @@ -390,70 +325,30 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c class GlobalStatementCollector(cst.CSTVisitor): - """Visitor that collects all global statements (excluding imports and functions/classes).""" + """Collects module-level statements (excluding imports, assignments, functions and classes).""" def __init__(self) -> None: super().__init__() - self.global_statements = [] - self.in_function_or_class = False + self.global_statements: list[cst.SimpleStatementLine] = [] def visit_ClassDef(self, node: cst.ClassDef) -> bool: - # Don't visit inside classes - self.in_function_or_class = True return False - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: - self.in_function_or_class = False - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - # Don't visit inside functions - self.in_function_or_class = True return False - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: - self.in_function_or_class = False - - def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: - if not self.in_function_or_class: - for statement in node.body: - # Skip imports and assignments (both regular and annotated) - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)): - self.global_statements.append(node) - break - - -class LastImportFinder(cst.CSTVisitor): - """Finds the position of the last import statement in the module.""" - - def __init__(self) -> None: - super().__init__() - self.last_import_line = 0 - self.current_line = 0 - def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: - self.current_line += 1 for statement in node.body: - if isinstance(statement, (cst.Import, cst.ImportFrom)): - self.last_import_line = self.current_line + if not isinstance(statement, (cst.Import, cst.ImportFrom, 
cst.Assign, cst.AnnAssign)): + self.global_statements.append(node) + break class DottedImportCollector(cst.CSTVisitor): - """Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`. - - Examples - -------- - import os ==> "os" - import dbt.adapters.factory ==> "dbt.adapters.factory" - from pathlib import Path ==> "pathlib.Path" - from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter" - from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional" - from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps" - - """ + """Collects top-level imports as normalized dotted strings (e.g. 'from pathlib import Path' -> 'pathlib.Path').""" def __init__(self) -> None: self.imports: set[str] = set() - self.depth = 0 # top-level def get_full_dotted_name(self, expr: cst.BaseExpression) -> str: if isinstance(expr, cst.Name): @@ -488,28 +383,19 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None: self.imports.add(f"{module}.{asname}") def visit_Module(self, node: cst.Module) -> None: - self.depth = 0 self._collect_imports_from_block(node) - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - self.depth += 1 - - def leave_FunctionDef(self, node: cst.FunctionDef) -> None: - self.depth -= 1 - - def visit_ClassDef(self, node: cst.ClassDef) -> None: - self.depth += 1 + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + return False - def leave_ClassDef(self, node: cst.ClassDef) -> None: - self.depth -= 1 + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + return False def visit_If(self, node: cst.If) -> None: - if self.depth == 0: - self._collect_imports_from_block(node.body) + self._collect_imports_from_block(node.body) def visit_Try(self, node: cst.Try) -> None: - if self.depth == 0: - 
self._collect_imports_from_block(node.body) + self._collect_imports_from_block(node.body) def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: @@ -520,14 +406,6 @@ def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.Si return module, collector.global_statements -def find_last_import_line(target_code: str) -> int: - """Find the line number of the last import statement.""" - module = cst.parse_module(target_code) - finder = LastImportFinder() - module.visit(finder) - return finder.last_import_line - - class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom @@ -561,11 +439,6 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: continue unique_global_statements.append(stmt) - # Reuse already-parsed dst_module - original_module = dst_module - - # Parse the src_module_code once only (already done above: src_module) - # Collect assignments from the new file new_assignment_collector = GlobalAssignmentCollector() src_module.visit(new_assignment_collector) @@ -574,7 +447,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: src_module.visit(src_function_collector) dst_function_collector = GlobalFunctionCollector() - original_module.visit(dst_function_collector) + dst_module.visit(dst_function_collector) # Filter out functions that already exist in the destination (only add truly new functions) new_functions = { @@ -584,35 +457,22 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: } new_function_order = [name for name in src_function_collector.function_order if name in new_functions] - # If there are no assignments, no new functions, and no global statements, return unchanged if not new_assignment_collector.assignments and not new_functions and not unique_global_statements: return dst_module_code - # The order of 
transformations matters: - # 1. Functions first - so assignments and statements can reference them - # 2. Assignments second - so they come after functions but before statements - # 3. Global statements last - so they can reference both functions and assignments - - # Transform functions if any + # Transform in order: functions, then assignments, then global statements (so each can reference the previous) if new_functions: - function_transformer = GlobalFunctionTransformer(new_functions, new_function_order) - original_module = original_module.visit(function_transformer) + dst_module = dst_module.visit(GlobalFunctionTransformer(new_functions, new_function_order)) - # Transform assignments if any if new_assignment_collector.assignments: - transformer = GlobalAssignmentTransformer( - new_assignment_collector.assignments, new_assignment_collector.assignment_order + dst_module = dst_module.visit( + GlobalAssignmentTransformer(new_assignment_collector.assignments, new_assignment_collector.assignment_order) ) - original_module = original_module.visit(transformer) - # Insert global statements (like function calls at module level) LAST, - # after all functions and assignments are added, to ensure they can reference any - # functions or variables defined in the module if unique_global_statements: - statement_transformer = GlobalStatementTransformer(unique_global_statements) - original_module = original_module.visit(statement_transformer) + dst_module = dst_module.visit(GlobalStatementTransformer(unique_global_statements)) - return original_module.code + return dst_module.code def resolve_star_import(module_name: str, project_root: Path) -> set[str]: @@ -648,8 +508,6 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]: for elt in node.value.elts: if isinstance(elt, ast.Constant) and isinstance(elt.value, str): all_names.append(elt.value) - elif isinstance(elt, ast.Str): # Python < 3.8 compatibility - all_names.append(elt.s) break if all_names is not 
None: @@ -683,7 +541,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]: def add_needed_imports_from_module( - src_module_code: str, + src_module_code: str | cst.Module, dst_module_code: str | cst.Module, src_path: Path, dst_path: Path, @@ -692,7 +550,6 @@ def add_needed_imports_from_module( helper_functions_fqn: set[str] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" - src_module_code = delete___future___aliased_imports(src_module_code) if not helper_functions_fqn: helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} @@ -714,7 +571,10 @@ def add_needed_imports_from_module( ) ) try: - src_module = cst.parse_module(src_module_code) + if isinstance(src_module_code, cst.Module): + src_module = src_module_code.visit(FutureAliasedImportTransformer()) + else: + src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer()) # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. # Nested imports (inside functions) are part of function logic and must not be # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". 
@@ -938,18 +798,6 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s return class_code + target_code, contextual_dunder_methods -def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]: - edited_code, contextual_dunder_methods = get_code(functions_to_optimize) - if edited_code is None: - return None, set() - try: - compile(edited_code, "edited_code", "exec") - except SyntaxError as e: - logger.exception(f"extract_code - Syntax error in extracted optimization candidate code: {e}") - return None, set() - return edited_code, contextual_dunder_methods - - def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]: """Find all preexisting functions, classes or class methods in the source code.""" preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set() @@ -969,417 +817,6 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP return preexisting_objects -@dataclass -class FunctionCallLocation: - """Represents a location where the target function is called.""" - - calling_function: str - line: int - column: int - - -@dataclass -class FunctionDefinitionInfo: - """Contains information about a function definition.""" - - name: str - node: ast.FunctionDef - source_code: str - start_line: int - end_line: int - is_method: bool - class_name: Optional[str] = None - - -class FunctionCallFinder(ast.NodeVisitor): - """AST visitor that finds all function definitions that call a specific qualified function. 
- - Args: - target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") - target_filepath: The filepath where the target function is defined - - """ - - def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]) -> None: - self.target_function_name = target_function_name - self.target_filepath = target_filepath - self.source_lines = source_lines # Store original source lines for extraction - - # Parse the target function name into parts - self.target_parts = target_function_name.split(".") - self.target_base_name = self.target_parts[-1] - - # Track current context - self.current_function_stack: list[tuple[str, ast.FunctionDef]] = [] - self.current_class_stack: list[str] = [] - - # Track imports to resolve qualified names - self.imports: dict[str, str] = {} # Maps imported names to their full paths - - # Results - self.function_calls: list[FunctionCallLocation] = [] - self.calling_functions: set[str] = set() - self.function_definitions: dict[str, FunctionDefinitionInfo] = {} - - # Track if we found calls in the current function - self.found_call_in_current_function = False - self.functions_with_nested_calls: set[str] = set() - - def visit_Import(self, node: ast.Import) -> None: - """Track regular imports.""" - for alias in node.names: - if alias.asname: - # import module as alias - self.imports[alias.asname] = alias.name - else: - # import module - self.imports[alias.name.split(".")[-1]] = alias.name - self.generic_visit(node) - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - """Track from imports.""" - if node.module: - for alias in node.names: - if alias.name == "*": - # from module import * - self.imports["*"] = node.module - elif alias.asname: - # from module import name as alias - self.imports[alias.asname] = f"{node.module}.{alias.name}" - else: - # from module import name - self.imports[alias.name] = f"{node.module}.{alias.name}" - self.generic_visit(node) - - def 
visit_ClassDef(self, node: ast.ClassDef) -> None: - """Track when entering a class definition.""" - self.current_class_stack.append(node.name) - self.generic_visit(node) - self.current_class_stack.pop() - - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - """Track when entering a function definition.""" - self._visit_function_def(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - """Track when entering an async function definition.""" - self._visit_function_def(node) - - def _visit_function_def(self, node: ast.FunctionDef) -> None: - """Track when entering a function definition.""" - func_name = node.name - - # Build the full qualified name including class if applicable - full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name - - self.current_function_stack.append((full_name, node)) - self.found_call_in_current_function = False - - # Visit the function body - self.generic_visit(node) - - # Process the function after visiting its body - if self.found_call_in_current_function and full_name not in self.function_definitions: - # Extract function source code - source_code = self._extract_source_code(node) - - self.function_definitions[full_name] = FunctionDefinitionInfo( - name=full_name, - node=node, - source_code=source_code, - start_line=node.lineno, - end_line=node.end_lineno if hasattr(node, "end_lineno") else node.lineno, - is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None, - ) - - # Handle nested functions - mark parent as containing nested calls - if self.found_call_in_current_function and len(self.current_function_stack) > 1: - parent_name = self.current_function_stack[-2][0] - self.functions_with_nested_calls.add(parent_name) - - # Also store the parent function if not already stored - if parent_name not in self.function_definitions: - parent_node = self.current_function_stack[-2][1] - 
parent_source = self._extract_source_code(parent_node) - - # Check if parent is a method (excluding current level) - parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else [] - - self.function_definitions[parent_name] = FunctionDefinitionInfo( - name=parent_name, - node=parent_node, - source_code=parent_source, - start_line=parent_node.lineno, - end_line=parent_node.end_lineno if hasattr(parent_node, "end_lineno") else parent_node.lineno, - is_method=bool(parent_class_context), - class_name=parent_class_context[-1] if parent_class_context else None, - ) - - self.current_function_stack.pop() - - # Reset flag for parent function - if self.current_function_stack: - parent_name = self.current_function_stack[-1][0] - self.found_call_in_current_function = parent_name in self.calling_functions - - def visit_Call(self, node: ast.Call) -> None: - """Check if this call matches our target function.""" - if not self.current_function_stack: - # Not inside a function, skip - self.generic_visit(node) - return - - if self._is_target_function_call(node): - current_func_name = self.current_function_stack[-1][0] - - call_location = FunctionCallLocation( - calling_function=current_func_name, line=node.lineno, column=node.col_offset - ) - - self.function_calls.append(call_location) - self.calling_functions.add(current_func_name) - self.found_call_in_current_function = True - - self.generic_visit(node) - - def _is_target_function_call(self, node: ast.Call) -> bool: - """Determine if this call node is calling our target function.""" - call_name = self._get_call_name(node.func) - if not call_name: - return False - - # Check if it matches directly - if call_name == self.target_function_name: - return True - - # Check if it's just the base name matching - if call_name == self.target_base_name: - # Could be imported with a different name, check imports - if call_name in self.imports: - imported_path = self.imports[call_name] - if imported_path == 
self.target_function_name or imported_path.endswith( - f".{self.target_function_name}" - ): - return True - # Could also be a direct call if we're in the same file - return True - - # Check for qualified calls with imports - call_parts = call_name.split(".") - if call_parts[0] in self.imports: - # Resolve the full path using imports - base_import = self.imports[call_parts[0]] - full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import - - if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): - return True - - return False - - def _get_call_name(self, func_node) -> Optional[str]: - """Extract the name being called from a function node.""" - # Fast path short-circuit for ast.Name nodes - if isinstance(func_node, ast.Name): - return func_node.id - - # Fast attribute chain extraction (speed: append, loop, join, NO reversed) - if isinstance(func_node, ast.Attribute): - parts = [] - current = func_node - # Unwind attribute chain as tight as possible (checked at each loop iteration) - while True: - parts.append(current.attr) - val = current.value - if isinstance(val, ast.Attribute): - current = val - continue - if isinstance(val, ast.Name): - parts.append(val.id) - # Join in-place backwards via slice instead of reversed for slight speedup - return ".".join(parts[::-1]) - break - return None - - def _extract_source_code(self, node: ast.FunctionDef) -> str: - """Extract source code for a function node using original source lines.""" - if not self.source_lines or not hasattr(node, "lineno"): - # Fallback to ast.unparse if available (Python 3.9+) - try: - return ast.unparse(node) - except AttributeError: - return f"# Source code extraction not available for {node.name}" - - # Get the lines for this function - start_line = node.lineno - 1 # Convert to 0-based index - end_line = node.end_lineno if hasattr(node, "end_lineno") else len(self.source_lines) - - # Extract the function lines - func_lines 
= self.source_lines[start_line:end_line] - - # Find the minimum indentation (excluding empty lines) - min_indent = float("inf") - for line in func_lines: - if line.strip(): # Skip empty lines - indent = len(line) - len(line.lstrip()) - min_indent = min(min_indent, indent) - - # If this is a method (inside a class), preserve one level of indentation - if self.current_class_stack: - # Keep 4 spaces of indentation for methods - dedent_amount = max(0, min_indent - 4) - result_lines = [] - for line in func_lines: - if line.strip(): # Only dedent non-empty lines - result_lines.append(line[dedent_amount:] if len(line) > dedent_amount else line) - else: - result_lines.append(line) - else: - # For top-level functions, remove all leading indentation - result_lines = [] - for line in func_lines: - if line.strip(): # Only dedent non-empty lines - result_lines.append(line[min_indent:] if len(line) > min_indent else line) - else: - result_lines.append(line) - - return "".join(result_lines).rstrip() - - def get_results(self) -> dict[str, str]: - """Get the results of the analysis. - - Returns: - A dictionary mapping qualified function names to their source code definitions. - - """ - return {info.name: info.source_code for info in self.function_definitions.values()} - - -def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict[str, str]: - """Find all function definitions that call a specific target function. - - Args: - source_code: The Python source code to analyze - target_function_name: The qualified name of the function to find (e.g., "module.function") - target_filepath: The filepath where the target function is defined - - Returns: - A dictionary mapping qualified function names to their source code definitions. 
- Example: {"function_a": "def function_a(): ...", "MyClass.method_one": "def method_one(self): ..."} - - """ - # Parse the source code - tree = ast.parse(source_code) - - # Split source into lines for source extraction - source_lines = source_code.splitlines(keepends=True) - - # Create and run the visitor - visitor = FunctionCallFinder(target_function_name, target_filepath, source_lines) - visitor.visit(tree) - - return visitor.get_results() - - -def find_occurances( - qualified_name: str, file_path: str, fn_matches: list[Path], project_root: Path, tests_root: Path -) -> list[str]: # max chars for context - context_len = 0 - fn_call_context = "" - for cur_file in fn_matches: - if context_len > MAX_CONTEXT_LEN_REVIEW: - break - cur_file_path = Path(cur_file) - # exclude references in tests - try: - if cur_file_path.relative_to(tests_root): - continue - except ValueError: - pass - with cur_file_path.open(encoding="utf8") as f: - file_content = f.read() - results = find_function_calls(file_content, target_function_name=qualified_name, target_filepath=file_path) - if results: - try: - path_relative_to_project_root = cur_file_path.relative_to(project_root) - except Exception as e: - # shouldn't happen but ensuring we don't crash - logger.debug(f"investigate {e}") - continue - fn_call_context += f"```python:{path_relative_to_project_root}\n" - for ( - fn_definition - ) in results.values(): # multiple functions in the file might be calling the desired function - fn_call_context += f"{fn_definition}\n" - context_len += len(fn_definition) - fn_call_context += "```\n" - return fn_call_context - - -def find_specific_function_in_file( - source_code: str, filepath: Union[str, Path], target_function: str, target_class: str | None -) -> Optional[tuple[int, int]]: - """Find a specific function definition in a Python file and return its location. - - Stops searching once the target is found (optimized for performance). 
- - Args: - source_code: Source code string - filepath: Path to the Python file - target_function: Function Name of the function to find - target_class: Class name of the function to find - - Returns: - Tuple of (line_number, column_offset) if found, None otherwise - - """ - script = jedi.Script(code=source_code, path=filepath) - names = script.get_names(all_scopes=True, definitions=True) - for name in names: - if name.type == "function" and name.name == target_function: - # If class name specified, check parent - if target_class: - parent = name.parent() - if parent and parent.name == target_class and parent.type == "class": - return CodePosition(line_no=name.line, col_no=name.column) - else: - # Top-level function match - return CodePosition(line_no=name.line, col_no=name.column) - - return None # Function not found - - -def get_fn_references_jedi( - source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None -) -> list[Path]: - start_time = time.perf_counter() - function_position: CodePosition | None = find_specific_function_in_file( - source_code, file_path, target_function, target_class - ) - if function_position is None: - # Function not found (may be non-Python code) - return [] - try: - script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root)) - # Get references to the function - references = script.get_references(line=function_position.line_no, column=function_position.col_no) - # Collect unique file paths where references are found - end_time = time.perf_counter() - logger.debug(f"Jedi for function references ran in {end_time - start_time:.2f} seconds") - reference_files = set() - for ref in references: - if ref.module_path: - # Convert to string and normalize path - ref_path = str(ref.module_path) - # Skip the definition itself - if not (ref_path == file_path and ref.line == function_position.line_no): - reference_files.add(ref_path) - return sorted(reference_files) - except 
Exception as e: - print(f"Error during Jedi analysis: {e}") - return [] - - has_numba = find_spec("numba") is not None NUMERICAL_MODULES = frozenset({"numpy", "torch", "numba", "jax", "tensorflow", "math", "scipy"}) @@ -1387,163 +824,54 @@ def get_fn_references_jedi( NUMBA_REQUIRED_MODULES = frozenset({"numpy", "math", "scipy"}) -class NumericalUsageChecker(ast.NodeVisitor): - """AST visitor that checks if a function uses numerical computing libraries.""" - - def __init__(self, numerical_names: set[str]) -> None: - self.numerical_names = numerical_names - self.found_numerical = False - - def visit_Call(self, node: ast.Call) -> None: - """Check function calls for numerical library usage.""" - if self.found_numerical: - return - call_name = self._get_root_name(node.func) - if call_name and call_name in self.numerical_names: - self.found_numerical = True - return - self.generic_visit(node) - - def visit_Attribute(self, node: ast.Attribute) -> None: - """Check attribute access for numerical library usage.""" - if self.found_numerical: - return - root_name = self._get_root_name(node) - if root_name and root_name in self.numerical_names: - self.found_numerical = True - return - self.generic_visit(node) - - def visit_Name(self, node: ast.Name) -> None: - """Check name references for numerical library usage.""" - if self.found_numerical: - return - if node.id in self.numerical_names: - self.found_numerical = True - - def _get_root_name(self, node: ast.expr) -> str | None: - """Get the root name from an expression (e.g., 'np' from 'np.array').""" - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - return self._get_root_name(node.value) - return None +def _uses_numerical_names(node: ast.AST, numerical_names: set[str]) -> bool: + return any(isinstance(n, ast.Name) and n.id in numerical_names for n in ast.walk(node)) def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]: - """Collect names that reference numerical 
computing libraries from imports. - - Returns: - A tuple of (numerical_names, modules_used) where: - - numerical_names: set of names/aliases that reference numerical libraries - - modules_used: set of actual module names (e.g., "numpy", "math") being imported - - """ numerical_names: set[str] = set() modules_used: set[str] = set() - - stack: list[ast.AST] = [tree] - while stack: - node = stack.pop() + for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: - # import numpy or import numpy as np module_root = alias.name.split(".")[0] if module_root in NUMERICAL_MODULES: - # Use the alias if present, otherwise the module name - name = alias.asname if alias.asname else alias.name.split(".")[0] - numerical_names.add(name) + numerical_names.add(alias.asname if alias.asname else module_root) modules_used.add(module_root) elif isinstance(node, ast.ImportFrom) and node.module: module_root = node.module.split(".")[0] if module_root in NUMERICAL_MODULES: - # from numpy import array, zeros as z for alias in node.names: if alias.name == "*": - # Can't track star imports, but mark the module as numerical numerical_names.add(module_root) else: - name = alias.asname if alias.asname else alias.name - numerical_names.add(name) + numerical_names.add(alias.asname if alias.asname else alias.name) modules_used.add(module_root) - else: - stack.extend(ast.iter_child_nodes(node)) - return numerical_names, modules_used def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.FunctionDef | None: - """Find a function node in the AST given its qualified name parts. - - Note: This function only finds regular (sync) functions, not async functions. 
- - Args: - tree: The parsed AST module - name_parts: List of name parts, e.g., ["ClassName", "method_name"] or ["function_name"] - - Returns: - The function node if found, None otherwise - - """ - if not name_parts: - return None - - if len(name_parts) == 1: - # Top-level function - func_name = name_parts[0] - for node in tree.body: - if isinstance(node, ast.FunctionDef) and node.name == func_name: - return node + """Find a function node in the AST given its qualified name parts (e.g. ["ClassName", "method"] or ["func"]).""" + if not name_parts or len(name_parts) > 2: return None - - if len(name_parts) == 2: - # Class method: ClassName.method_name - class_name, method_name = name_parts - for node in tree.body: - if isinstance(node, ast.ClassDef) and node.name == class_name: - for class_node in node.body: - if isinstance(class_node, ast.FunctionDef) and class_node.name == method_name: - return class_node - return None - + body: list[ast.stmt] = tree.body + for part in name_parts[:-1]: + for node in body: + if isinstance(node, ast.ClassDef) and node.name == part: + body = node.body + break + else: + return None + for node in body: + if isinstance(node, ast.FunctionDef) and node.name == name_parts[-1]: + return node return None def is_numerical_code(code_string: str, function_name: str | None = None) -> bool: - """Check if a function uses numerical computing libraries. - - Detects usage of numpy, torch, numba, jax, tensorflow, scipy, and math libraries - within the specified function. - - Note: For math, numpy, and scipy usage, this function returns True only if numba - is installed in the environment, as numba is required to optimize such code. - - Args: - code_string: The entire file's content as a string - function_name: The name of the function to check. Can be a simple name like "foo" - or a qualified name like "ClassName.method_name" for methods, - staticmethods, or classmethods. 
- - Returns: - True if the function uses any numerical computing library functions, False otherwise. - Returns False for math/numpy/scipy usage if numba is not installed. - - Examples: - >>> code = ''' - ... import numpy as np - ... def process_data(x): - ... return np.sum(x) - ... ''' - >>> is_numerical_code(code, "process_data") # Returns True only if numba is installed - True - - >>> code = ''' - ... def simple_func(x): - ... return x + 1 - ... ''' - >>> is_numerical_code(code, "simple_func") - False + """Check if a function uses numerical computing libraries (numpy, torch, numba, jax, tensorflow, scipy, math). + Returns False for math/numpy/scipy if numba is not installed. """ try: tree = ast.parse(code_string) @@ -1565,11 +893,7 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo if target_function is None: return False - # Check if the function body uses any numerical library - checker = NumericalUsageChecker(numerical_names) - checker.visit(target_function) - - if not checker.found_numerical: + if not _uses_numerical_names(target_function, numerical_names): return False # If numba is not installed and all modules used require numba for optimization, @@ -1580,22 +904,7 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo def get_opt_review_metrics( source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language ) -> str: - """Get function reference metrics for optimization review. - - Uses the LanguageSupport abstraction to find references, supporting both Python and JavaScript/TypeScript. - - Args: - source_code: Source code of the file containing the function. - file_path: Path to the file. - qualified_name: Qualified name of the function (e.g., "module.ClassName.method"). - project_root: Root of the project. - tests_root: Root of the tests directory. - language: The programming language. 
- - Returns: - Markdown-formatted string with code blocks showing calling functions. - - """ + """Get markdown-formatted calling function context for optimization review.""" from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.registry import get_language_support from codeflash.models.models import FunctionParent @@ -1649,18 +958,7 @@ def get_opt_review_metrics( def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str: - """Format references as markdown code blocks with calling function code. - - Args: - references: List of ReferenceInfo objects. - file_path: Path to the source file (to exclude). - project_root: Root of the project. - language: The programming language. - - Returns: - Markdown-formatted string. - - """ + """Format references as markdown code blocks with calling function code.""" # Group references by file refs_by_file: dict[Path, list] = {} for ref in references: diff --git a/codeflash/languages/python/static_analysis/code_replacer.py b/codeflash/languages/python/static_analysis/code_replacer.py index 6d9a3b3f1..89dc2751e 100644 --- a/codeflash/languages/python/static_analysis/code_replacer.py +++ b/codeflash/languages/python/static_analysis/code_replacer.py @@ -42,53 +42,55 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) +def has_autouse_fixture(node: cst.FunctionDef) -> bool: + for decorator in node.decorators: + dec = decorator.decorator + if not isinstance(dec, cst.Call): + continue + is_fixture = ( + isinstance(dec.func, cst.Attribute) + and isinstance(dec.func.value, cst.Name) + and dec.func.attr.value == "fixture" + and dec.func.value.value == "pytest" + ) or (isinstance(dec.func, cst.Name) and dec.func.value == "fixture") + if is_fixture: + for arg in dec.args: + if ( + arg.keyword + and arg.keyword.value == "autouse" + and isinstance(arg.value, cst.Name) + and arg.value.value == "True" + ): + 
return True + return False + + class AddRequestArgument(cst.CSTTransformer): METADATA_DEPENDENCIES = (PositionProvider,) def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - # Matcher for '@fixture' or '@pytest.fixture' - for decorator in original_node.decorators: - dec = decorator.decorator - - if isinstance(dec, cst.Call): - func_name = "" - if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): - if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": - func_name = "pytest.fixture" - elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": - func_name = "fixture" - - if func_name: - for arg in dec.args: - if ( - arg.keyword - and arg.keyword.value == "autouse" - and isinstance(arg.value, cst.Name) - and arg.value.value == "True" - ): - args = updated_node.params.params - arg_names = {arg.name.value for arg in args} - - # Skip if 'request' is already present - if "request" in arg_names: - return updated_node - - # Create a new 'request' param - request_param = cst.Param(name=cst.Name("request")) - - # Add 'request' as the first argument (after 'self' or 'cls' if needed) - if args: - first_arg = args[0].name.value - if first_arg in {"self", "cls"}: - new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005 - else: - new_params = [request_param] + list(args) # noqa: RUF005 - else: - new_params = [request_param] - - new_param_list = updated_node.params.with_changes(params=new_params) - return updated_node.with_changes(params=new_param_list) - return updated_node + if not has_autouse_fixture(original_node): + return updated_node + + args = updated_node.params.params + arg_names = {arg.name.value for arg in args} + + if "request" in arg_names: + return updated_node + + request_param = cst.Param(name=cst.Name("request")) + + if args: + first_arg = args[0].name.value + if first_arg in {"self", "cls"}: + new_params = [args[0], request_param] + 
list(args[1:]) # noqa: RUF005 + else: + new_params = [request_param] + list(args) # noqa: RUF005 + else: + new_params = [request_param] + + new_param_list = updated_node.params.with_changes(params=new_params) + return updated_node.with_changes(params=new_param_list) class PytestMarkAdder(cst.CSTTransformer): @@ -159,43 +161,15 @@ def _create_pytest_mark(self) -> cst.Decorator: class AutouseFixtureModifier(cst.CSTTransformer): def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - # Matcher for '@fixture' or '@pytest.fixture' - for decorator in original_node.decorators: - dec = decorator.decorator - - if isinstance(dec, cst.Call): - func_name = "" - if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): - if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": - func_name = "pytest.fixture" - elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": - func_name = "fixture" - - if func_name: - for arg in dec.args: - if ( - arg.keyword - and arg.keyword.value == "autouse" - and isinstance(arg.value, cst.Name) - and arg.value.value == "True" - ): - # Found a matching fixture with autouse=True - - # 1. The original body of the function will become the 'else' block. - # updated_node.body is an IndentedBlock, which is what cst.Else expects. - else_block = cst.Else(body=updated_node.body) - - # 2. Create the new 'if' block that will exit the fixture early. - if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') - yield_statement = cst.parse_statement("yield") - if_body = cst.IndentedBlock(body=[yield_statement]) - - # 3. Construct the full if/else statement. - new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) - - # 4. Replace the entire function's body with our new single statement. 
- return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) - return updated_node + if not has_autouse_fixture(original_node): + return updated_node + + else_block = cst.Else(body=updated_node.body) + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') + yield_statement = cst.parse_statement("yield") + if_body = cst.IndentedBlock(body=[yield_statement]) + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) def disable_autouse(test_path: Path) -> str: diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index e9a68aac2..38f205f08 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -957,7 +957,7 @@ def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | try: return ReferenceGraph(project_root, language=self.language.value) except Exception: - logger.debug("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis") + logger.info("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis") return None def instrument_existing_test( diff --git a/codeflash/models/call_graph.py b/codeflash/models/call_graph.py new file mode 100644 index 000000000..9083af2d0 --- /dev/null +++ b/codeflash/models/call_graph.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, NamedTuple + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.models.models import FunctionSource + + +class FunctionNode(NamedTuple): + file_path: Path + qualified_name: str + + +@dataclass(frozen=True) +class CalleeMetadata: + fully_qualified_name: str + only_function_name: str + definition_type: str + 
@dataclass(frozen=True)
class CalleeMetadata:
    """Static details about a callee, carried on an edge for later FunctionSource construction."""

    fully_qualified_name: str
    only_function_name: str
    definition_type: str
    source_line: str


@dataclass(frozen=True)
class CallEdge:
    """One caller→callee edge. Trace fields stay None until augment_with_trace() fills them."""

    caller: FunctionNode
    callee: FunctionNode
    is_cross_file: bool
    call_count: int | None = None
    total_time_ns: int | None = None
    callee_metadata: CalleeMetadata | None = None


@dataclass
class CallGraph:
    """Directed caller→callee graph with forward/reverse adjacency indexes.

    The indexes and node set are derived once from ``edges`` in
    ``__post_init__``; treat instances as immutable after construction.
    """

    edges: list[CallEdge]
    _forward: dict[FunctionNode, list[CallEdge]] = field(default_factory=dict, init=False, repr=False)
    _reverse: dict[FunctionNode, list[CallEdge]] = field(default_factory=dict, init=False, repr=False)
    _nodes: set[FunctionNode] = field(default_factory=set, init=False, repr=False)

    def __post_init__(self) -> None:
        fwd: dict[FunctionNode, list[CallEdge]] = {}
        rev: dict[FunctionNode, list[CallEdge]] = {}
        nodes: set[FunctionNode] = set()
        for edge in self.edges:
            fwd.setdefault(edge.caller, []).append(edge)
            rev.setdefault(edge.callee, []).append(edge)
            nodes.add(edge.caller)
            nodes.add(edge.callee)
        self._forward = fwd
        self._reverse = rev
        self._nodes = nodes

    @property
    def forward(self) -> dict[FunctionNode, list[CallEdge]]:
        """Caller → outgoing edges."""
        return self._forward

    @property
    def reverse(self) -> dict[FunctionNode, list[CallEdge]]:
        """Callee → incoming edges."""
        return self._reverse

    @property
    def nodes(self) -> set[FunctionNode]:
        """All nodes that appear on at least one edge."""
        return self._nodes

    def callees_of(self, node: FunctionNode) -> list[CallEdge]:
        """Return outgoing edges from *node* (empty list if none)."""
        return self.forward.get(node, [])

    def callers_of(self, node: FunctionNode) -> list[CallEdge]:
        """Return incoming edges to *node* (empty list if none)."""
        return self.reverse.get(node, [])

    def descendants(self, node: FunctionNode, max_depth: int | None = None) -> set[FunctionNode]:
        """Return all nodes reachable from *node* via callee edges, optionally depth-limited (BFS)."""
        visited: set[FunctionNode] = set()
        forward_map = self._forward
        if max_depth is None:
            queue: deque[FunctionNode] = deque([node])
            while queue:
                current = queue.popleft()
                for edge in forward_map.get(current, []):
                    if edge.callee not in visited:
                        visited.add(edge.callee)
                        queue.append(edge.callee)
        else:
            depth_queue: deque[tuple[FunctionNode, int]] = deque([(node, 0)])
            while depth_queue:
                current, depth = depth_queue.popleft()
                if depth >= max_depth:
                    continue
                for edge in forward_map.get(current, []):
                    if edge.callee not in visited:
                        visited.add(edge.callee)
                        depth_queue.append((edge.callee, depth + 1))
        return visited

    def ancestors(self, node: FunctionNode, max_depth: int | None = None) -> set[FunctionNode]:
        """Return all nodes that can reach *node* via caller edges, optionally depth-limited.

        Uses BFS (deque), mirroring descendants(). A DFS stack combined with a
        visited-at-push set can first discover a node along a *deep* path and
        then prune its expansion at max_depth, missing ancestors that are
        reachable within the limit via a shorter path.
        """
        visited: set[FunctionNode] = set()
        reverse_map = self._reverse
        if max_depth is None:
            queue: deque[FunctionNode] = deque([node])
            while queue:
                current = queue.popleft()
                for edge in reverse_map.get(current, []):
                    if edge.caller not in visited:
                        visited.add(edge.caller)
                        queue.append(edge.caller)
        else:
            depth_queue: deque[tuple[FunctionNode, int]] = deque([(node, 0)])
            while depth_queue:
                current, depth = depth_queue.popleft()
                if depth >= max_depth:
                    continue
                for edge in reverse_map.get(current, []):
                    if edge.caller not in visited:
                        visited.add(edge.caller)
                        depth_queue.append((edge.caller, depth + 1))
        return visited

    def subgraph(self, nodes: set[FunctionNode]) -> CallGraph:
        """Return a new CallGraph keeping only edges whose both endpoints are in *nodes*."""
        filtered = [e for e in self.edges if e.caller in nodes and e.callee in nodes]
        return CallGraph(edges=filtered)

    def leaf_functions(self) -> set[FunctionNode]:
        """Return nodes with no outgoing edges (call nothing in the graph)."""
        return self.nodes - set(self.forward.keys())

    def root_functions(self) -> set[FunctionNode]:
        """Return nodes with no incoming edges (called by nothing in the graph)."""
        return self.nodes - set(self.reverse.keys())

    def topological_order(self) -> list[FunctionNode]:
        """Return nodes leaves-first via Kahn's algorithm; cycle members are excluded with a warning."""
        in_degree: dict[FunctionNode, int] = {}
        all_nodes = self._nodes
        for node in all_nodes:
            in_degree.setdefault(node, 0)
        for edge in self.edges:
            in_degree[edge.callee] = in_degree.get(edge.callee, 0) + 1

        forward_map = self._forward
        queue = deque(node for node, deg in in_degree.items() if deg == 0)
        result: list[FunctionNode] = []
        while queue:
            node = queue.popleft()
            result.append(node)
            for edge in forward_map.get(node, []):
                in_degree[edge.callee] -= 1
                if in_degree[edge.callee] == 0:
                    queue.append(edge.callee)

        if len(result) < len(all_nodes):
            logger.warning(
                "Call graph contains cycles: %d of %d nodes excluded from topological order",
                len(all_nodes) - len(result),
                len(all_nodes),
            )

        # Leaves-first: reverse the caller-first Kahn order.
        result.reverse()
        return result


def augment_with_trace(graph: CallGraph, trace_db_path: Path) -> CallGraph:
    """Return a copy of *graph* with call_count/total_time_ns filled from a trace database.

    Reads the ``pstats`` table produced by the codeflash tracer; if the table
    is missing, *graph* is returned unchanged.
    """
    import sqlite3

    conn = sqlite3.connect(str(trace_db_path))
    try:
        rows = conn.execute(
            "SELECT filename, function, class_name, call_count_nonrecursive, total_time_ns FROM pstats"
        ).fetchall()
    except sqlite3.OperationalError:
        # Table not present — nothing to augment with.
        return graph
    finally:
        # Always release the handle, including on unexpected query errors
        # (the previous version leaked the connection on non-Operational errors).
        conn.close()

    lookup: dict[tuple[str, str], tuple[int, int]] = {}
    for filename, function, class_name, call_count, total_time in rows:
        qn = f"{class_name}.{function}" if class_name else function
        lookup[(filename, qn)] = (call_count, total_time)

    augmented_edges: list[CallEdge] = []
    for edge in graph.edges:
        stats = lookup.get((str(edge.callee.file_path), edge.callee.qualified_name))
        if stats is not None:
            call_count, total_time = stats
            augmented_edges.append(
                CallEdge(
                    caller=edge.caller,
                    callee=edge.callee,
                    is_cross_file=edge.is_cross_file,
                    call_count=call_count,
                    total_time_ns=total_time,
                    callee_metadata=edge.callee_metadata,
                )
            )
        else:
            augmented_edges.append(edge)

    return CallGraph(edges=augmented_edges)
+ only_function_name=meta.only_function_name, + source_code=meta.source_line, + definition_type=meta.definition_type, + ) + file_path_to_function_source[callee_path].add(fs) + function_source_list.append(fs) + + return file_path_to_function_source, function_source_list diff --git a/codeflash/models/models.py b/codeflash/models/models.py index dd105556b..607708e39 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -340,12 +340,12 @@ def file_to_path(self) -> dict[str, str]: dict[str, str]: Mapping from file path (as string) to code. """ - if self._cache.get("file_to_path") is not None: + try: return self._cache["file_to_path"] - self._cache["file_to_path"] = { - str(code_string.file_path): code_string.code for code_string in self.code_strings - } - return self._cache["file_to_path"] + except KeyError: + mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings} + self._cache["file_to_path"] = mapping + return mapping @staticmethod def parse_markdown_code(markdown_code: str, expected_language: str = "python") -> CodeStringsMarkdown: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index b96d4ca4d..35fa4d315 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -10,15 +10,10 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email -from codeflash.cli_cmds.console import ( # noqa: F401 - call_graph_live_display, - call_graph_summary, - console, - logger, - progress_bar, -) +from codeflash.cli_cmds.console import call_graph_live_display, call_graph_summary, console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file +from codeflash.code_utils.config_consts import HIGH_EFFORT_TOP_N, EffortLevel from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft from 
codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir from codeflash.code_utils.git_worktree_utils import ( @@ -70,6 +65,7 @@ def __init__(self, args: Namespace) -> None: self.current_worktree: Path | None = None self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None self.patch_files: list[Path] = [] + self._cached_callee_counts: dict[tuple[Path, str], int] = {} def run_benchmarks( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int @@ -194,6 +190,7 @@ def create_function_optimizer( function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, call_graph: DependencyResolver | None = None, + effort_override: str | None = None, ) -> FunctionOptimizer | None: qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) @@ -223,6 +220,7 @@ def create_function_optimizer( total_benchmark_timings=total_benchmark_timings if function_specific_timings else None, replay_tests_dir=self.replay_tests_dir, call_graph=call_graph, + effort_override=effort_override, ) if function_optimizer.function_to_optimize_ast is None and function_optimizer.requires_function_ast(): logger.info( @@ -334,6 +332,7 @@ def rank_all_functions_globally( file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None, call_graph: DependencyResolver | None = None, + test_count_cache: dict[tuple[Path, str], int] | None = None, ) -> list[tuple[Path, FunctionToOptimize]]: """Rank all functions globally across all files based on trace data. 
@@ -356,7 +355,7 @@ def rank_all_functions_globally( # If no trace file, rank by dependency count if call graph is available if not trace_file_path or not trace_file_path.exists(): if call_graph is not None: - return self.rank_by_dependency_count(all_functions, call_graph) + return self.rank_by_dependency_count(all_functions, call_graph, test_count_cache=test_count_cache) logger.debug("No trace file available, using original function order") return all_functions @@ -383,12 +382,23 @@ def rank_all_functions_globally( # Use a tuple of unique identifiers as the key key: tuple[Path, str, int | None] = (func.file_path, func.qualified_name, func.starting_line) func_to_file_map[key] = file_path - globally_ranked = [] - for func in ranked_functions: + ranked_with_metadata: list[tuple[Path, FunctionToOptimize, float, int]] = [] + for rank_index, func in enumerate(ranked_functions): key = (func.file_path, func.qualified_name, func.starting_line) file_path = func_to_file_map.get(key) if file_path: - globally_ranked.append((file_path, func)) + ranked_with_metadata.append( + (file_path, func, ranker.get_function_addressable_time(func), rank_index) + ) + + if test_count_cache: + ranked_with_metadata.sort( + key=lambda item: (-item[2], -test_count_cache.get((item[0], item[1].qualified_name), 0), item[3]) + ) + + globally_ranked = [ + (file_path, func) for file_path, func, _addressable_time, _rank_index in ranked_with_metadata + ] console.rule() logger.info( @@ -408,15 +418,30 @@ def rank_all_functions_globally( return globally_ranked def rank_by_dependency_count( - self, all_functions: list[tuple[Path, FunctionToOptimize]], call_graph: DependencyResolver + self, + all_functions: list[tuple[Path, FunctionToOptimize]], + call_graph: DependencyResolver, + test_count_cache: dict[tuple[Path, str], int] | None = None, ) -> list[tuple[Path, FunctionToOptimize]]: file_to_qns: dict[Path, set[str]] = defaultdict(set) for file_path, func in all_functions: 
file_to_qns[file_path].add(func.qualified_name) callee_counts = call_graph.count_callees_per_function(dict(file_to_qns)) - ranked = sorted( - enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0]) - ) + self._cached_callee_counts = callee_counts + + if test_count_cache: + ranked = sorted( + enumerate(all_functions), + key=lambda x: ( + -callee_counts.get((x[1][0], x[1][1].qualified_name), 0), + -test_count_cache.get((x[1][0], x[1][1].qualified_name), 0), + x[0], + ), + ) + else: + ranked = sorted( + enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0]) + ) logger.debug(f"Ranked {len(ranked)} functions by dependency count (most complex first)") return [item for _, item in ranked] @@ -473,17 +498,16 @@ def run(self) -> None: # Skip in CI — the cache DB doesn't persist between runs on ephemeral runners lang_support = current_language_support() resolver = None - # CURRENTLY DISABLED: The resolver is currently not used for anything until i clean up the repo structure for python - # if lang_support and not env_utils.is_ci(): - # resolver = lang_support.create_dependency_resolver(self.args.project_root) - - # if resolver is not None and lang_support is not None and file_to_funcs_to_optimize: - # supported_exts = lang_support.file_extensions - # source_files = [f for f in file_to_funcs_to_optimize if f.suffix in supported_exts] - # with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress: - # resolver.build_index(source_files, on_progress=on_progress) - # console.rule() - # call_graph_summary(resolver, file_to_funcs_to_optimize) + if lang_support and not env_utils.is_ci(): + resolver = lang_support.create_dependency_resolver(self.args.project_root) + + if resolver is not None and lang_support is not None and file_to_funcs_to_optimize: + supported_exts = lang_support.file_extensions + source_files = [f for f in 
file_to_funcs_to_optimize if f.suffix in supported_exts] + with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress: + resolver.build_index(source_files, on_progress=on_progress) + console.rule() + call_graph_summary(resolver, file_to_funcs_to_optimize) optimizations_found: int = 0 self.test_cfg.concolic_test_root_dir = Path( @@ -499,13 +523,34 @@ def run(self) -> None: if self.args.all and not self.args.subagent: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) + # Pre-compute test counts once for ranking and logging + test_count_cache: dict[tuple[Path, str], int] + if function_to_tests: + from codeflash.discovery.discover_unit_tests import existing_unit_test_count + + test_count_cache = { + (fp, fn.qualified_name): existing_unit_test_count(fn, self.args.project_root, function_to_tests) + for fp, fns in file_to_funcs_to_optimize.items() + for fn in fns + } + else: + test_count_cache = {} + # GLOBAL RANKING: Rank all functions together before optimizing globally_ranked_functions = self.rank_all_functions_globally( - file_to_funcs_to_optimize, trace_file_path, call_graph=resolver + file_to_funcs_to_optimize, trace_file_path, call_graph=resolver, test_count_cache=test_count_cache ) # Cache for module preparation (avoid re-parsing same files) prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {} + # Reuse callee counts from rank_by_dependency_count if available, otherwise compute + callee_counts = self._cached_callee_counts + if not callee_counts and resolver is not None: + file_to_qns: dict[Path, set[str]] = defaultdict(set) + for fp, fn in globally_ranked_functions: + file_to_qns[fp].add(fn.qualified_name) + callee_counts = resolver.count_callees_per_function(dict(file_to_qns)) + # Optimize functions in globally ranked order for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions): # Prepare module if not already cached @@ -520,9 
+565,24 @@ def run(self) -> None: function_iterator_count = i + 1 line_suffix = f":{function_to_optimize.starting_line}" if function_to_optimize.starting_line else "" + + callee_count = callee_counts.get((original_module_path, function_to_optimize.qualified_name), 0) + callee_suffix = f", {callee_count} callees" if callee_count else "" + + test_count = test_count_cache.get((original_module_path, function_to_optimize.qualified_name), 0) + test_suffix = f", {test_count} tests" if test_count else "" + + effort_override: str | None = None + if i < HIGH_EFFORT_TOP_N and self.args.effort == EffortLevel.MEDIUM.value: + effort_override = EffortLevel.HIGH.value + logger.debug( + f"Escalating effort for {function_to_optimize.qualified_name} from medium to high" + f" (top {HIGH_EFFORT_TOP_N} ranked)" + ) + logger.info( f"Optimizing function {function_iterator_count} of {len(globally_ranked_functions)}: " - f"{function_to_optimize.qualified_name} (in {original_module_path}{line_suffix})" + f"{function_to_optimize.qualified_name} (in {original_module_path}{line_suffix}{callee_suffix}{test_suffix})" ) console.rule() function_optimizer = None @@ -534,6 +594,7 @@ def run(self) -> None: function_benchmark_timings=function_benchmark_timings, total_benchmark_timings=total_benchmark_timings, call_graph=resolver, + effort_override=effort_override, ) if function_optimizer is None: continue diff --git a/pyproject.toml b/pyproject.toml index 36e3aabf2..92c699d1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -295,6 +295,7 @@ ignore = [ "F841", # Unused variable (often intentional) "ANN202", # Missing return type for private functions "B009", # getattr-with-constant - needed to avoid mypy [misc] on dunder access + "PTH119", # os.path.basename — faster than Path().name for string paths ] [tool.ruff.lint.flake8-type-checking] diff --git a/tests/test_add_language_metadata.py b/tests/test_add_language_metadata.py index 10eac7deb..91a4f66c6 100644 --- a/tests/test_add_language_metadata.py 
+++ b/tests/test_add_language_metadata.py @@ -6,7 +6,7 @@ from __future__ import annotations -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -80,17 +80,21 @@ def test_typescript_same_as_javascript(self, _mock_lang: object) -> None: assert payload["module_system"] == "commonjs" @patch("codeflash.api.aiservice.current_language", return_value=Language.PYTHON) - def test_none_language_version_python(self, _mock_lang: object) -> None: - """When language_version is None for Python, payload should still have the keys.""" + @patch("codeflash.api.aiservice.current_language_support") + def test_none_language_version_python_auto_detects(self, mock_support: MagicMock, _mock_lang: object) -> None: + """When language_version is None for Python, it should auto-detect from language support.""" + mock_support.return_value.language_version = "3.12.0" payload: dict = {} AiServiceClient.add_language_metadata(payload, language_version=None) - assert payload["language_version"] is None - assert payload["python_version"] is None + assert payload["language_version"] == "3.12.0" + assert payload["python_version"] == "3.12.0" @patch("codeflash.api.aiservice.current_language", return_value=Language.JAVA) - def test_none_language_version_java(self, _mock_lang: object) -> None: - """When language_version is None for Java, payload should still have the keys.""" + @patch("codeflash.api.aiservice.current_language_support") + def test_none_language_version_java_auto_detects(self, mock_support: MagicMock, _mock_lang: object) -> None: + """When language_version is None for Java, it should auto-detect from language support.""" + mock_support.return_value.language_version = "17" payload: dict = {} AiServiceClient.add_language_metadata(payload, language_version=None) - assert payload["language_version"] is None + assert payload["language_version"] == "17" assert payload["python_version"] is None diff --git a/tests/test_call_graph.py 
b/tests/test_call_graph.py new file mode 100644 index 000000000..8b038d13e --- /dev/null +++ b/tests/test_call_graph.py @@ -0,0 +1,523 @@ +from __future__ import annotations + +import sqlite3 +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + +from codeflash.languages.python.reference_graph import ReferenceGraph +from codeflash.models.call_graph import ( + CallEdge, + CalleeMetadata, + CallGraph, + FunctionNode, + augment_with_trace, + callees_from_graph, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def node(name: str, file: str = "mod.py") -> FunctionNode: + return FunctionNode(file_path=__import__("pathlib").Path(file), qualified_name=name) + + +def edge(caller: str, callee: str, *, cross: bool = False, file: str = "mod.py") -> CallEdge: + return CallEdge(caller=node(caller, file), callee=node(callee, file), is_cross_file=cross) + + +def make_graph(edges: list[CallEdge]) -> CallGraph: + return CallGraph(edges=edges) + + +# --------------------------------------------------------------------------- +# CallGraph unit tests +# --------------------------------------------------------------------------- + + +class TestCalleesOf: + def test_returns_direct_callees(self) -> None: + g = make_graph([edge("a", "b"), edge("a", "c"), edge("b", "c")]) + callees = g.callees_of(node("a")) + callee_names = {e.callee.qualified_name for e in callees} + assert callee_names == {"b", "c"} + + def test_returns_empty_for_leaf(self) -> None: + g = make_graph([edge("a", "b")]) + assert g.callees_of(node("b")) == [] + + def test_returns_empty_for_unknown_node(self) -> None: + g = make_graph([edge("a", "b")]) + assert g.callees_of(node("z")) == [] + + +class TestCallersOf: + def test_returns_direct_callers(self) -> None: + g = make_graph([edge("a", "c"), edge("b", "c")]) + callers = g.callers_of(node("c")) + 
caller_names = {e.caller.qualified_name for e in callers} + assert caller_names == {"a", "b"} + + def test_returns_empty_for_root(self) -> None: + g = make_graph([edge("a", "b")]) + assert g.callers_of(node("a")) == [] + + +class TestDescendants: + def test_transitive_descendants(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + desc = g.descendants(node("a")) + assert {n.qualified_name for n in desc} == {"b", "c", "d"} + + def test_max_depth_limits_traversal(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + desc = g.descendants(node("a"), max_depth=1) + assert {n.qualified_name for n in desc} == {"b"} + + def test_max_depth_two(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + desc = g.descendants(node("a"), max_depth=2) + assert {n.qualified_name for n in desc} == {"b", "c"} + + def test_handles_cycle(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "a")]) + desc = g.descendants(node("a")) + assert {n.qualified_name for n in desc} == {"b", "a"} + + def test_empty_for_leaf(self) -> None: + g = make_graph([edge("a", "b")]) + assert g.descendants(node("b")) == set() + + +class TestAncestors: + def test_transitive_ancestors(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + anc = g.ancestors(node("d")) + assert {n.qualified_name for n in anc} == {"a", "b", "c"} + + def test_max_depth_limits_traversal(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + anc = g.ancestors(node("d"), max_depth=1) + assert {n.qualified_name for n in anc} == {"c"} + + def test_empty_for_root(self) -> None: + g = make_graph([edge("a", "b")]) + assert g.ancestors(node("a")) == set() + + +class TestLeafAndRootFunctions: + def test_leaf_functions(self) -> None: + g = make_graph([edge("a", "b"), edge("a", "c"), edge("b", "d")]) + leaves = g.leaf_functions() + assert {n.qualified_name for n in leaves} == {"c", "d"} + + 
def test_root_functions(self) -> None: + g = make_graph([edge("a", "b"), edge("a", "c"), edge("b", "d")]) + roots = g.root_functions() + assert {n.qualified_name for n in roots} == {"a"} + + def test_single_edge(self) -> None: + g = make_graph([edge("a", "b")]) + assert {n.qualified_name for n in g.leaf_functions()} == {"b"} + assert {n.qualified_name for n in g.root_functions()} == {"a"} + + +class TestSubgraph: + def test_filters_to_selected_nodes(self) -> None: + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + sub = g.subgraph({node("a"), node("b"), node("c")}) + assert len(sub.edges) == 2 + callee_names = {e.callee.qualified_name for e in sub.edges} + assert "d" not in callee_names + + def test_empty_subgraph(self) -> None: + g = make_graph([edge("a", "b")]) + sub = g.subgraph(set()) + assert sub.edges == [] + + +class TestTopologicalOrder: + def test_linear_chain(self) -> None: + # a -> b -> c -> d + g = make_graph([edge("a", "b"), edge("b", "c"), edge("c", "d")]) + order = g.topological_order() + names = [n.qualified_name for n in order] + # Leaves-first: d before c before b before a + assert names.index("d") < names.index("c") + assert names.index("c") < names.index("b") + assert names.index("b") < names.index("a") + + def test_diamond(self) -> None: + # a -> b, a -> c, b -> d, c -> d + g = make_graph([edge("a", "b"), edge("a", "c"), edge("b", "d"), edge("c", "d")]) + order = g.topological_order() + names = [n.qualified_name for n in order] + assert names.index("d") < names.index("b") + assert names.index("d") < names.index("c") + assert names.index("b") < names.index("a") + assert names.index("c") < names.index("a") + + def test_empty_graph(self) -> None: + g = make_graph([]) + assert g.topological_order() == [] + + +class TestNodes: + def test_collects_all_nodes(self) -> None: + g = make_graph([edge("a", "b"), edge("c", "d")]) + names = {n.qualified_name for n in g.nodes} + assert names == {"a", "b", "c", "d"} + + def 
test_empty_graph(self) -> None: + g = make_graph([]) + assert g.nodes == set() + + +# --------------------------------------------------------------------------- +# augment_with_trace tests +# --------------------------------------------------------------------------- + + +class TestAugmentWithTrace: + def test_overlays_runtime_data(self, tmp_path: Path) -> None: + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + """ + CREATE TABLE pstats ( + filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, + call_count_nonrecursive INTEGER, num_callers INTEGER, + total_time_ns INTEGER, cumulative_time_ns INTEGER, callers BLOB + ) + """ + ) + conn.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ("mod.py", 1, "helper", None, 10, 1, 5000, 5000, b"[]"), + ) + conn.commit() + conn.close() + + g = make_graph([edge("caller", "helper")]) + augmented = augment_with_trace(g, db_path) + + assert len(augmented.edges) == 1 + e = augmented.edges[0] + assert e.call_count == 10 + assert e.total_time_ns == 5000 + + def test_unmatched_edges_preserved(self, tmp_path: Path) -> None: + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + """ + CREATE TABLE pstats ( + filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, + call_count_nonrecursive INTEGER, num_callers INTEGER, + total_time_ns INTEGER, cumulative_time_ns INTEGER, callers BLOB + ) + """ + ) + conn.commit() + conn.close() + + g = make_graph([edge("caller", "helper")]) + augmented = augment_with_trace(g, db_path) + + assert len(augmented.edges) == 1 + e = augmented.edges[0] + assert e.call_count is None + assert e.total_time_ns is None + + def test_missing_pstats_table(self, tmp_path: Path) -> None: + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.close() + + g = make_graph([edge("caller", "helper")]) + result = augment_with_trace(g, db_path) + assert result.edges == g.edges + + 
def test_class_method_matching(self, tmp_path: Path) -> None: + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + """ + CREATE TABLE pstats ( + filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, + call_count_nonrecursive INTEGER, num_callers INTEGER, + total_time_ns INTEGER, cumulative_time_ns INTEGER, callers BLOB + ) + """ + ) + conn.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ("mod.py", 5, "process", "MyClass", 3, 2, 9000, 12000, b"[]"), + ) + conn.commit() + conn.close() + + callee = FunctionNode(file_path=__import__("pathlib").Path("mod.py"), qualified_name="MyClass.process") + caller = FunctionNode(file_path=__import__("pathlib").Path("mod.py"), qualified_name="main") + g = CallGraph(edges=[CallEdge(caller=caller, callee=callee, is_cross_file=False)]) + + augmented = augment_with_trace(g, db_path) + assert augmented.edges[0].call_count == 3 + assert augmented.edges[0].total_time_ns == 9000 + + +# --------------------------------------------------------------------------- +# ReferenceGraph.get_call_graph integration tests +# --------------------------------------------------------------------------- + + +def write_file(project: Path, name: str, content: str) -> Path: + fp = project / name + fp.write_text(content, encoding="utf-8") + return fp + + +@pytest.fixture +def project(tmp_path: Path) -> Path: + project_root = tmp_path / "project" + project_root.mkdir() + return project_root + + +@pytest.fixture +def db_path(tmp_path: Path) -> Path: + return tmp_path / "cache.db" + + +class TestReferenceGraphGetCallGraph: + def test_simple_call_graph(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "mod.py": {"caller"}}) + assert len(graph.edges) == 1 + assert 
graph.edges[0].caller.qualified_name == "caller" + assert graph.edges[0].callee.qualified_name == "helper" + assert not graph.edges[0].is_cross_file + finally: + cg.close() + + def test_cross_file_call_graph(self, project: Path, db_path: Path) -> None: + write_file(project, "utils.py", "def utility():\n return 42\n") + write_file( + project, + "main.py", + """\ +from utils import utility + +def caller(): + return utility() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "main.py": {"caller"}}) + assert len(graph.edges) == 1 + assert graph.edges[0].is_cross_file + assert graph.edges[0].callee.qualified_name == "utility" + finally: + cg.close() + + def test_multiple_callees(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def a(): + return 1 + +def b(): + return 2 + +def caller(): + return a() + b() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "mod.py": {"caller"}}) + callee_names = {e.callee.qualified_name for e in graph.edges} + assert callee_names == {"a", "b"} + finally: + cg.close() + + def test_empty_input(self, project: Path, db_path: Path) -> None: + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({}) + assert graph.edges == [] + finally: + cg.close() + + def test_leaf_has_no_callees(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def leaf(): + return 42 +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "mod.py": {"leaf"}}) + assert graph.edges == [] + finally: + cg.close() + + def test_include_metadata(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "mod.py": {"caller"}}, 
include_metadata=True) + assert len(graph.edges) == 1 + e = graph.edges[0] + assert e.callee_metadata is not None + assert e.callee_metadata.only_function_name == "helper" + assert e.callee_metadata.definition_type == "function" + assert e.callee_metadata.fully_qualified_name != "" + assert e.callee_metadata.source_line != "" + finally: + cg.close() + + def test_include_metadata_false_has_no_metadata(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + graph = cg.get_call_graph({project / "mod.py": {"caller"}}) + assert len(graph.edges) == 1 + assert graph.edges[0].callee_metadata is None + finally: + cg.close() + + def test_get_callees_includes_statement_dependencies(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +X = 1 + +def caller(): + return X + 1 +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, function_sources = cg.get_callees({project / "mod.py": {"caller"}}) + assert [(source.qualified_name, source.definition_type) for source in function_sources] == [ + ("X", "statement") + ] + finally: + cg.close() + + +# --------------------------------------------------------------------------- +# CalleeMetadata unit tests +# --------------------------------------------------------------------------- + + +def test_edge_with_metadata() -> None: + meta = CalleeMetadata( + fully_qualified_name="mod.helper", + only_function_name="helper", + definition_type="function", + source_line="def helper(): ...", + ) + e = CallEdge(caller=node("caller"), callee=node("helper"), is_cross_file=False, callee_metadata=meta) + assert e.callee_metadata is meta + assert e.callee_metadata.only_function_name == "helper" + + +def test_edge_without_metadata() -> None: + e = CallEdge(caller=node("caller"), callee=node("helper"), is_cross_file=False) + assert 
e.callee_metadata is None + + +# --------------------------------------------------------------------------- +# callees_from_graph unit tests +# --------------------------------------------------------------------------- + + +def test_callees_from_graph_extracts_function_sources() -> None: + meta = CalleeMetadata( + fully_qualified_name="mod.helper", + only_function_name="helper", + definition_type="function", + source_line="def helper(): ...", + ) + e = CallEdge(caller=node("caller"), callee=node("helper"), is_cross_file=False, callee_metadata=meta) + g = CallGraph(edges=[e]) + + file_map, source_list = callees_from_graph(g) + assert len(source_list) == 1 + fs = source_list[0] + assert fs.qualified_name == "helper" + assert fs.fully_qualified_name == "mod.helper" + assert fs.only_function_name == "helper" + assert fs.source_code == "def helper(): ..." + assert fs.definition_type == "function" + + from pathlib import Path + + assert Path("mod.py") in file_map + assert fs in file_map[Path("mod.py")] + + +def test_callees_from_graph_skips_edges_without_metadata() -> None: + e1 = CallEdge(caller=node("a"), callee=node("b"), is_cross_file=False) + meta = CalleeMetadata( + fully_qualified_name="mod.c", only_function_name="c", definition_type="function", source_line="def c(): ..." 
+ ) + e2 = CallEdge(caller=node("a"), callee=node("c"), is_cross_file=False, callee_metadata=meta) + g = CallGraph(edges=[e1, e2]) + + _, source_list = callees_from_graph(g) + assert len(source_list) == 1 + assert source_list[0].qualified_name == "c" + + +def test_callees_from_graph_empty() -> None: + g = CallGraph(edges=[]) + file_map, source_list = callees_from_graph(g) + assert file_map == {} + assert source_list == [] diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index af466cd3a..a2b31eb94 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -4501,6 +4501,104 @@ def process(w: Widget) -> str: assert "size" in code +def test_extract_parameter_type_constructors_stdlib_type(tmp_path: Path) -> None: + pkg = tmp_path / "mypkg" + pkg.mkdir() + (pkg / "__init__.py").write_text("", encoding="utf-8") + (pkg / "processor.py").write_text( + """from argparse import Namespace + +def process(ns: Namespace) -> str: + return str(ns) +""", + encoding="utf-8", + ) + + fto = FunctionToOptimize( + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 + ) + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) + assert len(result.code_strings) == 1 + code = result.code_strings[0].code + assert "class Namespace:" in code + assert "def __init__(self, **kwargs):" in code + + +def test_extract_parameter_type_constructors_namedtuple_project_type(tmp_path: Path) -> None: + pkg = tmp_path / "mypkg" + pkg.mkdir() + (pkg / "__init__.py").write_text("", encoding="utf-8") + (pkg / "models.py").write_text( + """from pathlib import Path +from typing import NamedTuple + +class FunctionNode(NamedTuple): + file_path: Path + qualified_name: str +""", + encoding="utf-8", + ) + (pkg / "processor.py").write_text( + """from mypkg.models import FunctionNode + +def process(node: FunctionNode) -> str: + return node.qualified_name +""", + 
encoding="utf-8", + ) + + fto = FunctionToOptimize( + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 + ) + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) + assert len(result.code_strings) == 1 + code = result.code_strings[0].code + assert "class FunctionNode(NamedTuple):" in code + assert "file_path: Path" in code + assert "qualified_name: str" in code + + +def test_extract_parameter_type_constructors_uses_raw_project_context_for_small_class(tmp_path: Path) -> None: + pkg = tmp_path / "mypkg" + pkg.mkdir() + (pkg / "__init__.py").write_text("", encoding="utf-8") + (pkg / "models.py").write_text( + """from functools import total_ordering + +@total_ordering +class Rank: + def __init__(self, value: int): + self.value = value + + def __lt__(self, other: "Rank") -> bool: + return self.value < other.value + + def __eq__(self, other: object) -> bool: + return isinstance(other, Rank) and self.value == other.value +""", + encoding="utf-8", + ) + (pkg / "processor.py").write_text( + """from mypkg.models import Rank + +def process(rank: Rank) -> int: + return rank.value +""", + encoding="utf-8", + ) + + fto = FunctionToOptimize( + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 + ) + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) + assert len(result.code_strings) == 1 + code = result.code_strings[0].code + assert "from functools import total_ordering" in code + assert "@total_ordering" in code + assert "def __lt__" in code + assert "def __eq__" in code + + def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None: pkg = tmp_path / "mypkg" pkg.mkdir() @@ -4788,6 +4886,48 @@ def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> Non assert "class BaseProcessor:" in result.code_strings[0].code +def 
test_extract_parameter_type_constructors_attribute_base_prefers_imported_project_class(tmp_path: Path) -> None: + pkg = tmp_path / "mypkg" + pkg.mkdir() + (pkg / "__init__.py").write_text("", encoding="utf-8") + (pkg / "external.py").write_text( + """class Base: + def __init__(self, x: int): + self.x = x +""", + encoding="utf-8", + ) + (pkg / "models.py").write_text( + """import mypkg.external as ext + +class Base: + pass + +class Child(ext.Base): + def __init__(self, x: int): + super().__init__(x) +""", + encoding="utf-8", + ) + (pkg / "processor.py").write_text( + """from mypkg.models import Child + +def process(c: Child) -> int: + return c.x +""", + encoding="utf-8", + ) + + fto = FunctionToOptimize( + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 + ) + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) + combined = "\n".join(cs.code for cs in result.code_strings) + assert "class Child(ext.Base):" in combined + assert "self.x = x" in combined + assert "class Base:\n pass" not in combined + + def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None: """Isinstance with builtins (int, str, etc.) 
should not produce stubs.""" pkg = tmp_path / "mypkg" @@ -4828,6 +4968,51 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None: assert "class Config:" in combined +def test_extract_parameter_type_constructors_uses_raw_project_context_for_dataclass_inheritance(tmp_path: Path) -> None: + pkg = tmp_path / "mypkg" + pkg.mkdir() + (pkg / "__init__.py").write_text("", encoding="utf-8") + (pkg / "base.py").write_text( + """from dataclasses import dataclass +from pathlib import Path + +@dataclass +class BaseConfig: + file_path: Path +""", + encoding="utf-8", + ) + (pkg / "models.py").write_text( + """from dataclasses import dataclass +from mypkg.base import BaseConfig + +@dataclass +class ChildConfig(BaseConfig): + qualified_name: str +""", + encoding="utf-8", + ) + (pkg / "processor.py").write_text( + """from mypkg.models import ChildConfig + +def process(cfg: ChildConfig) -> str: + return cfg.qualified_name +""", + encoding="utf-8", + ) + + fto = FunctionToOptimize( + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 + ) + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) + combined = "\n".join(cs.code for cs in result.code_strings) + assert "@dataclass" in combined + assert "class BaseConfig" in combined + assert "file_path: Path" in combined + assert "class ChildConfig(BaseConfig):" in combined + assert "qualified_name: str" in combined + + def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None: """Third-party classes should produce compact __init__ stubs, not full class source.""" # Use a real third-party package (pydantic) so jedi can actually resolve it diff --git a/tests/test_ranking_boost.py b/tests/test_ranking_boost.py new file mode 100644 index 000000000..c3e6fcd80 --- /dev/null +++ b/tests/test_ranking_boost.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from argparse import Namespace +from pathlib import Path +from 
unittest.mock import patch + +import pytest + +from codeflash.discovery.discover_unit_tests import existing_unit_test_count +from codeflash.models.function_types import FunctionToOptimize +from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile +from codeflash.models.test_type import TestType +from codeflash.optimization.optimizer import Optimizer + + +def make_func(name: str, project_root: Path) -> FunctionToOptimize: + return FunctionToOptimize(function_name=name, file_path=project_root / "mod.py") + + +def make_test(test_type: TestType, test_name: str = "test_something") -> FunctionCalledInTest: + return FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=Path("/tests/test_mod.py"), test_class=None, test_function=test_name, test_type=test_type + ), + position=CodePosition(line_no=1, col_no=0), + ) + + +def build_test_count_cache( + funcs: list[FunctionToOptimize], project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]] +) -> dict[tuple[Path, str], int]: + return { + (func.file_path, func.qualified_name): existing_unit_test_count(func, project_root, function_to_tests) + for func in funcs + } + + +def make_optimizer(project_root: Path) -> Optimizer: + def _noop_display_global_ranking(*_args: object, **_kwargs: object) -> None: + return None + + optimizer = Optimizer.__new__(Optimizer) + optimizer.args = Namespace(project_root=project_root) + optimizer.display_global_ranking = _noop_display_global_ranking + return optimizer + + +@pytest.fixture +def project_root(tmp_path: Path) -> Path: + root = tmp_path / "project" + root.mkdir() + (root / "mod.py").write_text("def foo(): pass\ndef bar(): pass\ndef baz(): pass\n") + return root + + +def test_no_tests(project_root: Path) -> None: + func = make_func("foo", project_root) + assert existing_unit_test_count(func, project_root, {}) == 0 + + +def test_no_matching_key(project_root: Path) -> None: + func = make_func("foo", project_root) + tests = 
{"other_module.bar": {make_test(TestType.EXISTING_UNIT_TEST)}} + assert existing_unit_test_count(func, project_root, tests) == 0 + + +def test_only_replay_tests(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + tests = {key: {make_test(TestType.REPLAY_TEST)}} + assert existing_unit_test_count(func, project_root, tests) == 0 + + +def test_single_existing_test(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + tests = {key: {make_test(TestType.EXISTING_UNIT_TEST)}} + assert existing_unit_test_count(func, project_root, tests) == 1 + + +def test_multiple_existing_tests(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + tests = { + key: { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + } + } + assert existing_unit_test_count(func, project_root, tests) == 3 + + +def test_mixed_test_types(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + tests = { + key: { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.REPLAY_TEST, "test_replay"), + make_test(TestType.GENERATED_REGRESSION, "test_gen"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + } + } + assert existing_unit_test_count(func, project_root, tests) == 2 + + +def test_truthiness_for_boolean_usage(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + assert not existing_unit_test_count(func, project_root, {}) + assert existing_unit_test_count(func, project_root, {key: {make_test(TestType.EXISTING_UNIT_TEST)}}) + + +def 
test_functions_with_more_tests_rank_higher(project_root: Path) -> None: + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[0].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one") + }, + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + }, + # baz has no tests + } + + ranked = sorted(funcs, key=lambda f: -existing_unit_test_count(f, project_root, function_to_tests)) + + assert ranked[0].function_name == "bar" # 3 tests + assert ranked[1].function_name == "foo" # 1 test + assert ranked[2].function_name == "baz" # 0 tests + + +def test_stable_sort_preserves_order_for_equal_counts(project_root: Path) -> None: + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + f.qualified_name_with_modules_from_root(project_root): {make_test(TestType.EXISTING_UNIT_TEST)} for f in funcs + } + + ranked = sorted(funcs, key=lambda f: -existing_unit_test_count(f, project_root, function_to_tests)) + + assert [f.function_name for f in ranked] == ["foo", "bar", "baz"] + + +def test_parametrized_tests_deduplication(project_root: Path) -> None: + func = make_func("foo", project_root) + key = func.qualified_name_with_modules_from_root(project_root) + tests = { + key: { + make_test(TestType.EXISTING_UNIT_TEST, "test_foo[0]"), + make_test(TestType.EXISTING_UNIT_TEST, "test_foo[1]"), + make_test(TestType.EXISTING_UNIT_TEST, "test_foo[2]"), + make_test(TestType.EXISTING_UNIT_TEST, "test_bar"), + } + } + assert existing_unit_test_count(func, project_root, tests) == 2 + + +def test_trace_ranking_keeps_addressable_time_primary_over_test_count(project_root: Path, tmp_path: Path) -> None: + 
optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + trace_file = tmp_path / "trace.db" + trace_file.touch() + + ranked_functions = [funcs[0], funcs[1], funcs[2]] + addressable_times = {"foo": 100.0, "bar": 20.0, "baz": 5.0} + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + } + } + + class FakeRanker: + def __init__(self, _trace_file: Path) -> None: + pass + + def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + return ranked_functions + + def get_function_addressable_time(self, function: FunctionToOptimize) -> float: + return addressable_times[function.function_name] + + with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): + ranked = optimizer.rank_all_functions_globally( + {project_root / "mod.py": funcs}, + trace_file, + test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), + ) + + assert [func.function_name for _, func in ranked] == ["foo", "bar", "baz"] + + +def test_trace_ranking_uses_test_count_as_tiebreaker(project_root: Path, tmp_path: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + trace_file = tmp_path / "trace.db" + trace_file.touch() + + ranked_functions = [funcs[0], funcs[1], funcs[2]] + addressable_times = {"foo": 100.0, "bar": 100.0, "baz": 5.0} + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[0].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one") + }, + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + 
make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + }, + } + + class FakeRanker: + def __init__(self, _trace_file: Path) -> None: + pass + + def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + return ranked_functions + + def get_function_addressable_time(self, function: FunctionToOptimize) -> float: + return addressable_times[function.function_name] + + with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): + ranked = optimizer.rank_all_functions_globally( + {project_root / "mod.py": funcs}, + trace_file, + test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), + ) + + assert [func.function_name for _, func in ranked] == ["bar", "foo", "baz"] + + +def test_dependency_count_ranking_keeps_callee_count_primary(project_root: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + } + } + + class FakeResolver: + def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: + return {(project_root / "mod.py", "foo"): 5, (project_root / "mod.py", "bar"): 1} + + ranked = optimizer.rank_by_dependency_count( + [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], + FakeResolver(), + test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), + ) + + assert [func.function_name for _, func in ranked] == ["foo", "bar"] + + +def test_dependency_count_ranking_uses_test_count_as_tiebreaker(project_root: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, 
project_root) for name in ("foo", "bar")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[0].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one") + }, + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + }, + } + + class FakeResolver: + def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: + return {(project_root / "mod.py", "foo"): 2, (project_root / "mod.py", "bar"): 2} + + ranked = optimizer.rank_by_dependency_count( + [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], + FakeResolver(), + test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), + ) + + assert [func.function_name for _, func in ranked] == ["bar", "foo"] diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 032942f29..3bc237ba4 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -33,7 +33,7 @@ def another_function(): qualified_functions = {"main_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_class_variable_removal() -> None: @@ -84,7 +84,7 @@ def helper_function(): qualified_functions = {"helper_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_complex_variable_dependencies() -> None: @@ -122,7 +122,7 @@ def tuple_user(): qualified_functions = {"main_function"} result = 
remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_type_annotation_usage() -> None: @@ -156,7 +156,7 @@ def unused_function(param: UnusedType) -> UnusedType: qualified_functions = {"main_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_class_method_with_dunder_methods() -> None: @@ -215,7 +215,7 @@ def helper_function(): qualified_functions = {"MyClass.target_method"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Normalize whitespace for comparison - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_complex_type_annotations() -> None: @@ -263,7 +263,7 @@ def unused_function(param: UnusedType) -> None: qualified_functions = {"process_data"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_try_except_finally_variables() -> None: @@ -325,7 +325,7 @@ def unused_function(): qualified_functions = {"use_constants", "use_cleanup"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_base_class_inheritance() -> None: @@ -383,8 +383,9 @@ def test_function(): qualified_functions = {"test_function"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it - assert "class LayoutDumper" in result - assert "class ObjectDetectionLayoutDumper" in result + assert "class LayoutDumper" in result.code + assert "class 
ObjectDetectionLayoutDumper" in result.code + assert result.code.strip() == expected.strip() def test_conditional_and_loop_variables() -> None: @@ -471,7 +472,7 @@ def unused_function(): qualified_functions = {"get_platform_info", "get_loop_result"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() + assert result.code.strip() == expected.strip() def test_enum_attribute_access_dependency() -> None: @@ -519,10 +520,10 @@ def process_message(kind): qualified_functions = {"process_message"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # MessageKind should be preserved because process_message uses MessageKind.VALUE - assert "class MessageKind" in result + assert "class MessageKind" in result.code # UNUSED_VAR should be removed - assert "UNUSED_VAR" not in result - assert result.strip() == expected.strip() + assert "UNUSED_VAR" not in result.code + assert result.code.strip() == expected.strip() def test_attribute_access_does_not_track_attr_name() -> None: @@ -551,7 +552,7 @@ def get_x(self): qualified_functions = {"MyClass.get_x", "MyClass.__init__"} result = remove_unused_definitions_by_function_names(code, qualified_functions) # Module-level x should NOT be kept (self.x doesn't reference it) - assert 'x = "module_level_x"' not in result + assert 'x = "module_level_x"' not in result.code # UNUSED_VAR should also be removed - assert "UNUSED_VAR" not in result - assert result.strip() == expected.strip() + assert "UNUSED_VAR" not in result.code + assert result.code.strip() == expected.strip()