Skip to content

Commit a0a2a85

Browse files
authored
Merge pull request #1660 from codeflash-ai/unstructured-inference
feat: improve function ranking with reference graph and test-based boosting
2 parents 4822d1a + 8af7fdc commit a0a2a85

21 files changed

Lines changed: 2707 additions & 1987 deletions

codeflash/api/aiservice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from codeflash.code_utils.env_utils import get_codeflash_api_key
1515
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1616
from codeflash.code_utils.time_utils import humanize_runtime
17-
from codeflash.languages import Language, current_language
17+
from codeflash.languages import Language, current_language, current_language_support
1818
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1919
from codeflash.models.models import (
2020
AIServiceRefinerRequest,
@@ -58,6 +58,8 @@ def add_language_metadata(
5858
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
5959
) -> None:
6060
"""Add language version and module system metadata to an API payload."""
61+
if language_version is None:
62+
language_version = current_language_support().language_version
6163
payload["language_version"] = language_version
6264
payload["python_version"] = language_version if current_language() == Language.PYTHON else None
6365

codeflash/code_utils/config_consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
MAX_CONTEXT_LEN_REVIEW = 1000
5252

53+
HIGH_EFFORT_TOP_N = 15
54+
5355

5456
class EffortLevel(str, Enum):
5557
LOW = "low"

codeflash/discovery/discover_unit_tests.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import subprocess
1212
import unittest
1313
from collections import defaultdict
14+
from functools import lru_cache
1415
from pathlib import Path
1516
from typing import TYPE_CHECKING, Callable, Optional, final
1617

@@ -35,6 +36,21 @@
3536
from codeflash.verification.verification_utils import TestConfig
3637

3738

39+
def existing_unit_test_count(
40+
func: FunctionToOptimize, project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]]
41+
) -> int:
42+
key = f"{module_name_from_file_path_cached(func.file_path, project_root)}.{func.qualified_name}"
43+
tests = function_to_tests.get(key, set())
44+
seen: set[tuple[Path, str | None, str]] = set()
45+
for t in tests:
46+
if t.tests_in_file.test_type != TestType.EXISTING_UNIT_TEST:
47+
continue
48+
tif = t.tests_in_file
49+
base_name = tif.test_function.split("[", 1)[0]
50+
seen.add((tif.test_file, tif.test_class, base_name))
51+
return len(seen)
52+
53+
3854
@final
3955
class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this
4056
#: Tests passed.
@@ -1079,3 +1095,9 @@ def process_test_files(
10791095
tests_cache.close()
10801096

10811097
return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests
1098+
1099+
1100+
# Cache module name resolution to avoid repeated Path.resolve()/relative_to() calls
1101+
@lru_cache(maxsize=128)
1102+
def module_name_from_file_path_cached(file_path: Path, project_root: Path) -> str:
1103+
return module_name_from_file_path(file_path, project_root)

codeflash/languages/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919

2020
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
21+
from codeflash.models.call_graph import CallGraph
2122
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
2223
from codeflash.verification.verification_utils import TestConfig
2324

@@ -250,6 +251,12 @@ def count_callees_per_function(
250251
"""Return the number of callees for each (file_path, qualified_name) pair."""
251252
...
252253

254+
def get_call_graph(
255+
self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False
256+
) -> CallGraph:
257+
"""Return a CallGraph with full caller→callee edges for the given functions."""
258+
...
259+
253260
def close(self) -> None:
254261
"""Release resources (e.g. database connections)."""
255262
...

codeflash/languages/code_replacer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
from __future__ import annotations
88

9-
from pathlib import Path
9+
import os
1010
from typing import TYPE_CHECKING
1111

1212
from codeflash.cli_cmds.console import logger
1313
from codeflash.languages.base import FunctionFilterCriteria, Language
1414

1515
if TYPE_CHECKING:
16+
from pathlib import Path
17+
1618
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1719
from codeflash.languages.base import LanguageSupport
1820
from codeflash.models.models import CodeStringsMarkdown
@@ -38,7 +40,7 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
3840
# directory prefix but the correct filename
3941
target_name = relative_path.name
4042
basename_matches = [
41-
code for path, code in file_to_code_context.items() if path != "None" and Path(path).name == target_name
43+
code for path, code in file_to_code_context.items() if path != "None" and os.path.basename(path) == target_name
4244
]
4345
if len(basename_matches) == 1:
4446
logger.debug(f"Using basename-matched code block for {relative_path}")

codeflash/languages/function_optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def __init__(
425425
args: Namespace | None = None,
426426
replay_tests_dir: Path | None = None,
427427
call_graph: DependencyResolver | None = None,
428+
effort_override: str | None = None,
428429
) -> None:
429430
self.project_root = test_cfg.project_root_path.resolve()
430431
self.test_cfg = test_cfg
@@ -451,7 +452,8 @@ def __init__(
451452
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
452453
self.test_files = TestFiles(test_files=[])
453454

454-
self.effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value
455+
default_effort = getattr(args, "effort", EffortLevel.MEDIUM.value) if args else EffortLevel.MEDIUM.value
456+
self.effort = effort_override or default_effort
455457

456458
self.args = args # Check defaults for these
457459
self.function_trace_id: str = str(uuid.uuid4())

0 commit comments

Comments
 (0)