Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 92 additions & 55 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __init__(self, project_root_path: Path) -> None:
)

self.memory_cache = {}
self.pending_rows: list[tuple[str, str, str, str, str, str, int | str, int, int]] = []
self.writes_enabled = True

def insert_test(
self,
Expand All @@ -134,10 +136,8 @@ def insert_test(
col_number: int,
) -> None:
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
self.cur.execute(
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.pending_rows.append(
(
self.project_root_path,
file_path,
file_hash,
qualified_name_with_modules_from_root,
Expand All @@ -147,9 +147,47 @@ def insert_test(
test_type_value,
line_number,
col_number,
),
)
)
self.connection.commit()

def flush(self) -> None:
if not self.writes_enabled or not self.pending_rows:
return
try:
self.cur.executemany(
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
[
(
self.project_root_path,
file_path,
file_hash,
qualified_name_with_modules_from_root,
function_name,
test_class,
test_function,
test_type_value,
line_number,
col_number,
)
for (
file_path,
file_hash,
qualified_name_with_modules_from_root,
function_name,
test_class,
test_function,
test_type_value,
line_number,
col_number,
) in self.pending_rows
],
)
self.connection.commit()
except sqlite3.OperationalError as e:
logger.debug(f"Failed to persist discovered test cache, disabling cache writes: {e}")
self.writes_enabled = False
finally:
self.pending_rows.clear()

def get_function_to_test_map_for_file(
self, file_path: str, file_hash: str
Expand Down Expand Up @@ -196,6 +234,7 @@ def compute_file_hash(path: Path) -> str:
return h.hexdigest()

def close(self) -> None:
self.flush()
self.cur.close()
self.connection.close()

Expand Down Expand Up @@ -833,6 +872,10 @@ def process_test_files(
function_to_test_map = defaultdict(set)
num_discovered_tests = 0
num_discovered_replay_tests = 0
functions_to_optimize_by_name: dict[str, list[FunctionToOptimize]] = defaultdict(list)
if functions_to_optimize:
for function_to_optimize in functions_to_optimize:
functions_to_optimize_by_name[function_to_optimize.function_name].append(function_to_optimize)

# Set up sys_path for Jedi to resolve imports correctly
import sys
Expand Down Expand Up @@ -875,8 +918,8 @@ def process_test_files(
test_functions = set()

all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)
all_defs = [name for name in all_names if name.is_definition()]

top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
Expand Down Expand Up @@ -951,10 +994,9 @@ def process_test_files(

test_function_names_set = set(test_functions_by_name.keys())
relevant_names = []

names_with_full_name = [name for name in all_names if name.full_name is not None]

for name in names_with_full_name:
for name in all_names:
if name.full_name is None:
continue
match = FUNCTION_NAME_REGEX.search(name.full_name)
if match and match.group(1) in test_function_names_set:
relevant_names.append((name, match.group(1)))
Expand All @@ -969,56 +1011,50 @@ def process_test_files(
if not definition or definition[0].type != "function":
# Fallback: Try to match against functions_to_optimize when Jedi can't resolve
# This handles cases where Jedi fails with pytest fixtures
if functions_to_optimize and name.name:
for func_to_opt in functions_to_optimize:
if functions_to_optimize_by_name and name.name:
for func_to_opt in functions_to_optimize_by_name.get(name.name, []):
# Check if this unresolved name matches a function we're looking for
if func_to_opt.function_name == name.name:
# Check if the test file imports the class/module containing this function
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
project_root_path
)
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
project_root_path
)

# Only add if this test actually tests the function we're optimizing
for test_func in test_functions_by_name[scope]:
if test_func.parameters is not None:
if test_framework == "pytest":
scope_test_function = (
f"{test_func.function_name}[{test_func.parameters}]"
)
else: # unittest
scope_test_function = (
f"{test_func.function_name}_{test_func.parameters}"
)
else:
scope_test_function = test_func.function_name

function_to_test_map[qualified_name_with_modules].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=test_func.test_class,
test_function=scope_test_function,
test_type=test_func.test_type,
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules,
function_name=scope,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=name.line,
col_number=name.column,
# Only add if this test actually tests the function we're optimizing
for test_func in test_functions_by_name[scope]:
if test_func.parameters is not None:
if test_framework == "pytest":
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
else: # unittest
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
else:
scope_test_function = test_func.function_name

function_to_test_map[qualified_name_with_modules].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=test_func.test_class,
test_function=scope_test_function,
test_type=test_func.test_type,
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules,
function_name=scope,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=name.line,
col_number=name.column,
)

if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1
if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1

num_discovered_tests += 1
num_discovered_tests += 1
continue
definition_obj = definition[0]
definition_path = str(definition_obj.module_path)
Expand Down Expand Up @@ -1074,6 +1110,7 @@ def process_test_files(
logger.debug(str(e))
continue

tests_cache.flush()
progress.advance(task_id)

tests_cache.close()
Expand Down
Loading
Loading