Skip to content

Commit e7d2273

Browse files
committed
perf: cache jedi project, batch test cache writes, fix Windows relative_to bug
Port valuable improvements from #1846 that remain applicable after #1660: - Cache jedi.Project instances via @cache to avoid recreating across 5 call sites - Fix unguarded relative_to() in get_code_optimization_context (Windows 8.3 paths) - Pre-group references by parent function in get_function_sources_from_jedi for O(1) lookup - Batch TestsCache writes with flush() + executemany instead of per-row commit - Gracefully disable cache writes on sqlite3.OperationalError - Build functions_to_optimize_by_name dict for O(1) fallback lookup in process_test_files - Derive all_defs from all_names via is_definition() to save a redundant Jedi call
1 parent a0a2a85 commit e7d2273

2 files changed

Lines changed: 130 additions & 85 deletions

File tree

codeflash/discovery/discover_unit_tests.py

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def __init__(self, project_root_path: Path) -> None:
136136
)
137137

138138
self.memory_cache = {}
139+
self.pending_rows: list[tuple[str, str, str, str, str, str, int | str, int, int]] = []
140+
self.writes_enabled = True
139141

140142
def insert_test(
141143
self,
@@ -150,10 +152,8 @@ def insert_test(
150152
col_number: int,
151153
) -> None:
152154
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
153-
self.cur.execute(
154-
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
155+
self.pending_rows.append(
155156
(
156-
self.project_root_path,
157157
file_path,
158158
file_hash,
159159
qualified_name_with_modules_from_root,
@@ -163,9 +163,23 @@ def insert_test(
163163
test_type_value,
164164
line_number,
165165
col_number,
166-
),
166+
)
167167
)
168-
self.connection.commit()
168+
169+
def flush(self) -> None:
170+
if not self.writes_enabled or not self.pending_rows:
171+
return
172+
try:
173+
self.cur.executemany(
174+
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
175+
[(self.project_root_path, *row) for row in self.pending_rows],
176+
)
177+
self.connection.commit()
178+
except sqlite3.OperationalError as e:
179+
logger.debug(f"Failed to persist discovered test cache, disabling cache writes: {e}")
180+
self.writes_enabled = False
181+
finally:
182+
self.pending_rows.clear()
169183

170184
def get_function_to_test_map_for_file(
171185
self, file_path: str, file_hash: str
@@ -212,6 +226,7 @@ def compute_file_hash(path: Path) -> str:
212226
return h.hexdigest()
213227

214228
def close(self) -> None:
229+
self.flush()
215230
self.cur.close()
216231
self.connection.close()
217232

@@ -849,6 +864,10 @@ def process_test_files(
849864
function_to_test_map = defaultdict(set)
850865
num_discovered_tests = 0
851866
num_discovered_replay_tests = 0
867+
functions_to_optimize_by_name: dict[str, list[FunctionToOptimize]] = defaultdict(list)
868+
if functions_to_optimize:
869+
for function_to_optimize in functions_to_optimize:
870+
functions_to_optimize_by_name[function_to_optimize.function_name].append(function_to_optimize)
852871

853872
# Set up sys_path for Jedi to resolve imports correctly
854873
import sys
@@ -891,8 +910,8 @@ def process_test_files(
891910
test_functions = set()
892911

893912
all_names = script.get_names(all_scopes=True, references=True)
894-
all_defs = script.get_names(all_scopes=True, definitions=True)
895913
all_names_top = script.get_names(all_scopes=True)
914+
all_defs = [name for name in all_names if name.is_definition()]
896915

897916
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
898917
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
@@ -967,10 +986,9 @@ def process_test_files(
967986

968987
test_function_names_set = set(test_functions_by_name.keys())
969988
relevant_names = []
970-
971-
names_with_full_name = [name for name in all_names if name.full_name is not None]
972-
973-
for name in names_with_full_name:
989+
for name in all_names:
990+
if name.full_name is None:
991+
continue
974992
match = FUNCTION_NAME_REGEX.search(name.full_name)
975993
if match and match.group(1) in test_function_names_set:
976994
relevant_names.append((name, match.group(1)))
@@ -985,56 +1003,49 @@ def process_test_files(
9851003
if not definition or definition[0].type != "function":
9861004
# Fallback: Try to match against functions_to_optimize when Jedi can't resolve
9871005
# This handles cases where Jedi fails with pytest fixtures
988-
if functions_to_optimize and name.name:
989-
for func_to_opt in functions_to_optimize:
990-
# Check if this unresolved name matches a function we're looking for
991-
if func_to_opt.function_name == name.name:
992-
# Check if the test file imports the class/module containing this function
993-
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
994-
project_root_path
995-
)
1006+
if functions_to_optimize_by_name and name.name:
1007+
for func_to_opt in functions_to_optimize_by_name.get(name.name, []):
1008+
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
1009+
project_root_path
1010+
)
9961011

997-
# Only add if this test actually tests the function we're optimizing
998-
for test_func in test_functions_by_name[scope]:
999-
if test_func.parameters is not None:
1000-
if test_framework == "pytest":
1001-
scope_test_function = (
1002-
f"{test_func.function_name}[{test_func.parameters}]"
1003-
)
1004-
else: # unittest
1005-
scope_test_function = (
1006-
f"{test_func.function_name}_{test_func.parameters}"
1007-
)
1008-
else:
1009-
scope_test_function = test_func.function_name
1010-
1011-
function_to_test_map[qualified_name_with_modules].add(
1012-
FunctionCalledInTest(
1013-
tests_in_file=TestsInFile(
1014-
test_file=test_file,
1015-
test_class=test_func.test_class,
1016-
test_function=scope_test_function,
1017-
test_type=test_func.test_type,
1018-
),
1019-
position=CodePosition(line_no=name.line, col_no=name.column),
1020-
)
1021-
)
1022-
tests_cache.insert_test(
1023-
file_path=str(test_file),
1024-
file_hash=file_hash,
1025-
qualified_name_with_modules_from_root=qualified_name_with_modules,
1026-
function_name=scope,
1027-
test_class=test_func.test_class or "",
1028-
test_function=scope_test_function,
1029-
test_type=test_func.test_type,
1030-
line_number=name.line,
1031-
col_number=name.column,
1012+
# Only add if this test actually tests the function we're optimizing
1013+
for test_func in test_functions_by_name[scope]:
1014+
if test_func.parameters is not None:
1015+
if test_framework == "pytest":
1016+
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
1017+
else: # unittest
1018+
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
1019+
else:
1020+
scope_test_function = test_func.function_name
1021+
1022+
function_to_test_map[qualified_name_with_modules].add(
1023+
FunctionCalledInTest(
1024+
tests_in_file=TestsInFile(
1025+
test_file=test_file,
1026+
test_class=test_func.test_class,
1027+
test_function=scope_test_function,
1028+
test_type=test_func.test_type,
1029+
),
1030+
position=CodePosition(line_no=name.line, col_no=name.column),
10321031
)
1032+
)
1033+
tests_cache.insert_test(
1034+
file_path=str(test_file),
1035+
file_hash=file_hash,
1036+
qualified_name_with_modules_from_root=qualified_name_with_modules,
1037+
function_name=scope,
1038+
test_class=test_func.test_class or "",
1039+
test_function=scope_test_function,
1040+
test_type=test_func.test_type,
1041+
line_number=name.line,
1042+
col_number=name.column,
1043+
)
10331044

1034-
if test_func.test_type == TestType.REPLAY_TEST:
1035-
num_discovered_replay_tests += 1
1045+
if test_func.test_type == TestType.REPLAY_TEST:
1046+
num_discovered_replay_tests += 1
10361047

1037-
num_discovered_tests += 1
1048+
num_discovered_tests += 1
10381049
continue
10391050
definition_obj = definition[0]
10401051
definition_path = str(definition_obj.module_path)
@@ -1090,6 +1101,7 @@ def process_test_files(
10901101
logger.debug(str(e))
10911102
continue
10921103

1104+
tests_cache.flush()
10931105
progress.advance(task_id)
10941106

10951107
tests_cache.close()

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from collections import defaultdict
77
from dataclasses import dataclass, field
8+
from functools import cache
89
from itertools import chain
910
from typing import TYPE_CHECKING
1011

@@ -47,6 +48,13 @@
4748
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
4849

4950

51+
@cache
52+
def get_jedi_project(project_root_path: str) -> object:
53+
import jedi
54+
55+
return jedi.Project(path=project_root_path)
56+
57+
5058
@dataclass
5159
class FileContextCache:
5260
original_module: cst.Module
@@ -102,15 +110,22 @@ def get_code_optimization_context(
102110
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
103111
call_graph: DependencyResolver | None = None,
104112
) -> CodeOptimizationContext:
113+
project_root_path = project_root_path.resolve()
114+
jedi_project = get_jedi_project(str(project_root_path))
115+
105116
# Get FunctionSource representation of helpers of FTO
106117
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
107118
if call_graph is not None:
108119
helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input)
109120
else:
110-
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path)
121+
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
122+
fto_input, project_root_path, jedi_project=jedi_project
123+
)
111124

112125
# Add function to optimize into helpers of FTO dict, as they'll be processed together
113-
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
126+
fto_as_function_source = get_function_to_optimize_as_function_source(
127+
function_to_optimize, project_root_path, jedi_project=jedi_project
128+
)
114129
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)
115130

116131
# Format data to search for helpers of helpers using get_function_sources_from_jedi
@@ -124,7 +139,7 @@ def get_code_optimization_context(
124139
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
125140

126141
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
127-
helpers_of_fto_qualified_names_dict, project_root_path
142+
helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project
128143
)
129144

130145
# Extract all code contexts in a single pass (one CST parse per file)
@@ -133,11 +148,14 @@ def get_code_optimization_context(
133148
final_read_writable_code = all_ctx.read_writable
134149

135150
# Ensure the target file is first in the code blocks so the LLM knows which file to optimize
136-
target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
137-
target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative]
138-
other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative]
139-
if target_blocks:
140-
final_read_writable_code.code_strings = target_blocks + other_blocks
151+
try:
152+
target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path)
153+
target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative]
154+
other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative]
155+
if target_blocks:
156+
final_read_writable_code.code_strings = target_blocks + other_blocks
157+
except ValueError:
158+
pass
141159

142160
read_only_code_markdown = all_ctx.read_only
143161

@@ -434,13 +452,13 @@ def re_extract_from_cache(
434452
) -> CodeStringsMarkdown:
435453
"""Re-extract context from cached modules without file I/O or CST parsing."""
436454
result = CodeStringsMarkdown()
437-
for cache in file_caches:
455+
for file_cache in file_caches:
438456
try:
439457
pruned = parse_code_and_prune_cst(
440-
cache.cleaned_module,
458+
file_cache.cleaned_module,
441459
code_context_type,
442-
cache.fto_names,
443-
cache.hoh_names,
460+
file_cache.fto_names,
461+
file_cache.hoh_names,
444462
remove_docstrings=remove_docstrings,
445463
)
446464
except ValueError:
@@ -450,24 +468,25 @@ def re_extract_from_cache(
450468
code = ast.unparse(ast.parse(pruned.code))
451469
else:
452470
code = add_needed_imports_from_module(
453-
src_module_code=cache.original_module,
471+
src_module_code=file_cache.original_module,
454472
dst_module_code=pruned,
455-
src_path=cache.file_path,
456-
dst_path=cache.file_path,
473+
src_path=file_cache.file_path,
474+
dst_path=file_cache.file_path,
457475
project_root=project_root_path,
458-
helper_functions=cache.helper_functions,
476+
helper_functions=file_cache.helper_functions,
459477
)
460-
result.code_strings.append(CodeString(code=code, file_path=cache.relative_path))
478+
result.code_strings.append(CodeString(code=code, file_path=file_cache.relative_path))
461479
return result
462480

463481

464482
def get_function_to_optimize_as_function_source(
465-
function_to_optimize: FunctionToOptimize, project_root_path: Path
483+
function_to_optimize: FunctionToOptimize, project_root_path: Path, *, jedi_project: object | None = None
466484
) -> FunctionSource:
467485
import jedi
468486

469487
# Use jedi to find function to optimize
470-
script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path))
488+
project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path.resolve()))
489+
script = jedi.Script(path=function_to_optimize.file_path, project=project)
471490

472491
# Get all names in the file
473492
names = script.get_names(all_scopes=True, definitions=True, references=False)
@@ -498,22 +517,36 @@ def get_function_to_optimize_as_function_source(
498517

499518

500519
def get_function_sources_from_jedi(
501-
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
520+
file_path_to_qualified_function_names: dict[Path, set[str]],
521+
project_root_path: Path,
522+
*,
523+
jedi_project: object | None = None,
502524
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
503525
import jedi
504526

527+
project_root_path = project_root_path.resolve()
528+
project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path))
505529
file_path_to_function_source = defaultdict(set)
506530
function_source_list: list[FunctionSource] = []
507531
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
508-
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
532+
script = jedi.Script(path=file_path, project=project)
509533
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
510534

535+
# Pre-group references by their parent function's qualified name for O(1) lookup
536+
refs_by_parent: dict[str, list[Name]] = defaultdict(list)
537+
for ref in file_refs:
538+
if not ref.full_name:
539+
continue
540+
try:
541+
parent = ref.parent()
542+
if parent is None or parent.type != "function":
543+
continue
544+
refs_by_parent[get_qualified_name(parent.module_name, parent.full_name)].append(ref)
545+
except (AttributeError, ValueError):
546+
continue
547+
511548
for qualified_function_name in qualified_function_names:
512-
names = [
513-
ref
514-
for ref in file_refs
515-
if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name)
516-
]
549+
names = refs_by_parent.get(qualified_function_name, [])
517550
for name in names:
518551
try:
519552
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
@@ -1103,7 +1136,7 @@ def _resolve_imported_class_reference(
11031136
module_name, class_name = resolved_name.rsplit(".", 1)
11041137
try:
11051138
script_code = f"from {module_name} import {class_name}"
1106-
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
1139+
script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path)))
11071140
definitions = script.goto(1, len(f"from {module_name} import ") + len(class_name), follow_imports=True)
11081141
except Exception:
11091142
return None
@@ -1263,7 +1296,7 @@ def extract_parameter_type_constructors(
12631296
def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None:
12641297
try:
12651298
script_code = f"from {module_name} import {type_name}"
1266-
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
1299+
script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path)))
12671300
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
12681301
if not definitions:
12691302
return
@@ -1429,7 +1462,7 @@ def extract_class_and_bases(
14291462
continue
14301463
try:
14311464
test_code = f"import {module_name}"
1432-
script = jedi.Script(test_code, project=jedi.Project(path=project_root_path))
1465+
script = jedi.Script(test_code, project=get_jedi_project(str(project_root_path)))
14331466
completions = script.goto(1, len(test_code))
14341467

14351468
if not completions:

0 commit comments

Comments
 (0)