diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 62d6896fe..93881f0ad 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -384,22 +384,32 @@ def closest_matching_file_function_name( qualified_fn_to_find_lower = qualified_fn_to_find.lower() - # Cache levenshtein_distance locally for improved lookup speed - _levenshtein = levenshtein_distance + # Prepare a flattened list of candidates with precomputed lowercase names and lengths + # to avoid repeated .lower() and len() calls inside the main loop. + candidates: list[tuple[Path, FunctionToOptimize, str, int]] = [] for file_path, functions in found_fns.items(): for function in functions: - # Compare either full qualified name or just function name - fn_name = function.qualified_name.lower() - # If the absolute length difference is already >= min_distance, skip calculation - if abs(len(qualified_fn_to_find_lower) - len(fn_name)) >= min_distance: - continue - dist = _levenshtein(qualified_fn_to_find_lower, fn_name) + fn_name_lower = function.qualified_name.lower() + candidates.append((file_path, function, fn_name_lower, len(fn_name_lower))) + + # Use a bounded levenshtein variant here to early-exit when the distance cannot be + # smaller than the current min_distance. This avoids expensive full-distance calculations. + _bounded = _bounded_levenshtein + + target_len = len(qualified_fn_to_find_lower) - if dist < min_distance: - min_distance = dist - closest_match = function - closest_file = file_path + for file_path, function, fn_name, fn_len in candidates: + # If the absolute length difference is already >= min_distance, skip calculation + if abs(target_len - fn_len) >= min_distance: + continue + # compute bounded distance; if result is >= min_distance it won't improve + dist = _bounded(qualified_fn_to_find_lower, fn_name, min_distance - 1) + + if dist < min_distance: + min_distance = dist + closest_match = function + closest_file = file_path if closest_match is not None and closest_file is not None: return closest_file, closest_match @@ -936,3 +946,78 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list file_path in submodule_paths or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) ) + + +def _bounded_levenshtein(s1: str, s2: str, max_distance: int) -> int: + """Compute Levenshtein distance but stop when distance exceeds max_distance. + + Returns a value > max_distance when the true distance is > max_distance. + """ + # Fast path equal + if s1 == s2: + return 0 + + # Ensure s1 is the shorter + if len(s1) > len(s2): + s1, s2 = s2, s1 + + n = len(s1) + m = len(s2) + + # If length difference already exceeds max allowed distance, we can exit + if m - n > max_distance: + return max_distance + 1 + + # Initialize previous row: distances from empty s2 prefix to s1 prefixes + previous = list(range(n + 1)) + current = [0] * (n + 1) + + # We will only compute values within a "band" [start..end] for each row + for i in range(1, m + 1): + # Position in s2 is i (1-based for DP), character is s2[i-1] + char2 = s2[i - 1] + # Compute band boundaries (1-based indices for s1 positions) + start = max(1, i - max_distance) + end = min(n, i + max_distance) + + # If start > end the band is empty -> distance exceeds max_distance + if start > end: + return max_distance + 1 + + # Set current[0] for the empty prefix of s1 + current[0] = i + + # Fill left part outside band with large values + for k in range(1, start): + current[k] = max_distance + 1 + + # Compute values inside the band + for j in range(start, end + 1): + if s1[j - 1] == char2: + current[j] = previous[j - 1] + else: + # deletion = previous[j] + 1 + # insertion = current[j - 1] + 1 + # substitution = previous[j - 1] + 1 + a = previous[j] + 1 + b = current[j - 1] + 1 + c = previous[j - 1] + 1 + # Fast min of three + t = min(b, a) + current[j] = min(t, c) + + # Fill right part outside band with large values + for k in range(end + 1, n + 1): + current[k] = max_distance + 1 + + # Swap rows + previous, current = current, previous + + # Early exit: if the minimum value in the active band is greater than max_distance + # then the final distance must exceed max_distance + # (band width is small: at most 2*max_distance+1). + row_min = min(previous[start : end + 1]) + if row_min > max_distance: + return max_distance + 1 + + return previous[n]