Skip to content

Commit f123ee8

Browse files
committed
refactor: move generate_concolic_tests and normalize_code into LanguageSupport protocol
- Add generate_concolic_tests as optional protocol method (default no-op, Python overrides with CrossHair). Delete concolic_testing.py. - Remove normalize_code indirection: callers use protocol directly, PythonSupport calls PythonNormalizer without going through deduplicate_code.
1 parent 51c08ba commit f123ee8

5 files changed

Lines changed: 133 additions & 181 deletions

File tree

codeflash/code_utils/deduplicate_code.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,15 @@
1111

1212
from codeflash.code_utils.normalizers import get_normalizer
1313
from codeflash.languages import current_language
14-
from codeflash.languages.base import Language
1514

1615

17-
def normalize_code(
18-
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
19-
) -> str:
16+
def normalize_code(code: str, language: str | None = None) -> str:
2017
"""Normalize code by parsing, cleaning, and normalizing variable names.
2118
2219
Function names, class names, and parameters are preserved.
2320
2421
Args:
2522
code: Source code as string
26-
remove_docstrings: Whether to remove docstrings (Python only)
27-
return_ast_dump: Return AST dump instead of unparsed code (Python only)
2823
language: Language of the code. If None, uses the current session language.
2924
3025
Returns:
@@ -35,30 +30,10 @@ def normalize_code(
3530
language = current_language().value
3631

3732
try:
38-
normalizer = get_normalizer(language)
39-
40-
# Python has additional options
41-
if language == Language.PYTHON:
42-
if return_ast_dump:
43-
return normalizer.normalize_for_hash(code)
44-
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
45-
46-
# For other languages, use standard normalization
47-
return normalizer.normalize(code)
33+
return get_normalizer(language).normalize(code)
4834
except ValueError:
49-
# Unknown language - fall back to basic normalization
5035
return _basic_normalize(code)
5136
except Exception:
52-
# Parsing error - try other languages or fall back
53-
if language == Language.PYTHON:
54-
# Try JavaScript as fallback
55-
try:
56-
js_normalizer = get_normalizer("javascript")
57-
js_result = js_normalizer.normalize(code)
58-
if js_result != _basic_normalize(code):
59-
return js_result
60-
except Exception:
61-
pass
6237
return _basic_normalize(code)
6338

6439

codeflash/languages/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,20 @@ def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
750750

751751
# === Test Execution ===
752752

753+
def generate_concolic_tests(
754+
self,
755+
test_cfg: TestConfig,
756+
project_root: Path,
757+
function_to_optimize: FunctionToOptimize,
758+
function_to_optimize_ast: Any,
759+
) -> tuple[dict, str]:
760+
"""Generate concolic tests for a function.
761+
762+
Default implementation returns empty results. Override for languages
763+
that support concolic testing (e.g. Python via CrossHair).
764+
"""
765+
return {}, ""
766+
753767
def run_behavioral_tests(
754768
self,
755769
test_paths: Any,

codeflash/languages/python/support.py

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -574,21 +574,10 @@ def validate_syntax(self, source: str) -> bool:
574574
return False
575575

576576
def normalize_code(self, source: str) -> str:
577-
"""Normalize Python code for deduplication.
578-
579-
Removes comments, normalizes whitespace, and replaces variable names.
580-
581-
Args:
582-
source: Source code to normalize.
583-
584-
Returns:
585-
Normalized source code.
586-
587-
"""
588-
from codeflash.code_utils.deduplicate_code import normalize_code
577+
from codeflash.code_utils.normalizers import get_normalizer
589578

590579
try:
591-
return normalize_code(source, remove_docstrings=True, language=Language.PYTHON)
580+
return get_normalizer("python").normalize(source, remove_docstrings=True)
592581
except Exception:
593582
return source
594583

@@ -1077,3 +1066,111 @@ def run_line_profile_tests(
10771066
timeout=600,
10781067
)
10791068
return result_file_path, results
1069+
1070+
def generate_concolic_tests(
1071+
self, test_cfg: Any, project_root: Path, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: Any
1072+
) -> tuple[dict, str]:
1073+
import ast
1074+
import importlib.util
1075+
import subprocess
1076+
import tempfile
1077+
import time
1078+
1079+
from codeflash.cli_cmds.console import console
1080+
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1081+
from codeflash.code_utils.shell_utils import make_env_with_project_root
1082+
from codeflash.discovery.discover_unit_tests import discover_unit_tests
1083+
from codeflash.languages.python.static_analysis.concolic_utils import (
1084+
clean_concolic_tests,
1085+
is_valid_concolic_test,
1086+
)
1087+
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
1088+
from codeflash.lsp.helpers import is_LSP_enabled
1089+
from codeflash.telemetry.posthog_cf import ph
1090+
from codeflash.verification.verification_utils import TestConfig
1091+
1092+
crosshair_available = importlib.util.find_spec("crosshair") is not None
1093+
1094+
start_time = time.perf_counter()
1095+
function_to_concolic_tests: dict = {}
1096+
concolic_test_suite_code = ""
1097+
1098+
if not crosshair_available:
1099+
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
1100+
return function_to_concolic_tests, concolic_test_suite_code
1101+
1102+
if is_LSP_enabled():
1103+
logger.debug("Skipping concolic test generation in LSP mode")
1104+
return function_to_concolic_tests, concolic_test_suite_code
1105+
1106+
if (
1107+
test_cfg.concolic_test_root_dir
1108+
and isinstance(function_to_optimize_ast, ast.FunctionDef)
1109+
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
1110+
):
1111+
logger.info("Generating concolic opcode coverage tests for the original code…")
1112+
console.rule()
1113+
try:
1114+
env = make_env_with_project_root(project_root)
1115+
cover_result = subprocess.run(
1116+
[
1117+
SAFE_SYS_EXECUTABLE,
1118+
"-m",
1119+
"crosshair",
1120+
"cover",
1121+
"--example_output_format=pytest",
1122+
"--per_condition_timeout=20",
1123+
".".join(
1124+
[
1125+
function_to_optimize.file_path.relative_to(project_root)
1126+
.with_suffix("")
1127+
.as_posix()
1128+
.replace("/", "."),
1129+
function_to_optimize.qualified_name,
1130+
]
1131+
),
1132+
],
1133+
capture_output=True,
1134+
text=True,
1135+
cwd=project_root,
1136+
check=False,
1137+
timeout=600,
1138+
env=env,
1139+
)
1140+
except subprocess.TimeoutExpired:
1141+
logger.debug("CrossHair Cover test generation timed out")
1142+
return function_to_concolic_tests, concolic_test_suite_code
1143+
1144+
if cover_result.returncode == 0:
1145+
generated_concolic_test: str = cover_result.stdout
1146+
if not is_valid_concolic_test(generated_concolic_test, project_root=str(project_root)):
1147+
logger.debug("CrossHair generated invalid test, skipping")
1148+
console.rule()
1149+
return function_to_concolic_tests, concolic_test_suite_code
1150+
concolic_test_suite_code = clean_concolic_tests(generated_concolic_test)
1151+
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
1152+
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
1153+
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
1154+
1155+
concolic_test_cfg = TestConfig(
1156+
tests_root=concolic_test_suite_dir,
1157+
tests_project_rootdir=test_cfg.concolic_test_root_dir,
1158+
project_root_path=project_root,
1159+
)
1160+
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
1161+
logger.info(
1162+
"Created %d concolic unit test case%s ",
1163+
num_discovered_concolic_tests,
1164+
"s" if num_discovered_concolic_tests != 1 else "",
1165+
)
1166+
console.rule()
1167+
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
1168+
1169+
else:
1170+
logger.debug(
1171+
"Error running CrossHair Cover%s", ": " + cover_result.stderr if cover_result.stderr else "."
1172+
)
1173+
console.rule()
1174+
end_time = time.perf_counter()
1175+
logger.debug("Generated concolic tests in %.2f seconds", end_time - start_time)
1176+
return function_to_concolic_tests, concolic_test_suite_code

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
EffortLevel,
5959
get_effort_value,
6060
)
61-
from codeflash.code_utils.deduplicate_code import normalize_code
6261
from codeflash.code_utils.env_utils import get_pr_number
6362
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
6463
from codeflash.code_utils.git_utils import git_root_dir
@@ -104,7 +103,6 @@
104103
)
105104
from codeflash.result.explanation import Explanation
106105
from codeflash.telemetry.posthog_cf import ph
107-
from codeflash.verification.concolic_testing import generate_concolic_tests
108106
from codeflash.verification.equivalence import compare_test_results
109107
from codeflash.verification.parse_test_output import parse_concurrency_metrics, parse_test_results
110108
from codeflash.verification.verification_utils import get_test_file_path
@@ -965,7 +963,9 @@ def select_best_optimization(
965963
runtimes_list = []
966964

967965
for valid_opt in eval_ctx.valid_optimizations:
968-
valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
966+
valid_opt_normalized_code = self.language_support.normalize_code(
967+
valid_opt.candidate.source_code.flat.strip()
968+
)
969969
new_candidate_with_shorter_code = OptimizedCandidate(
970970
source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
971971
optimization_id=valid_opt.candidate.optimization_id,
@@ -1072,7 +1072,7 @@ def process_single_candidate(
10721072

10731073
candidate = candidate_node.candidate
10741074

1075-
normalized_code = normalize_code(candidate.source_code.flat.strip())
1075+
normalized_code = self.language_support.normalize_code(candidate.source_code.flat.strip())
10761076

10771077
if normalized_code == normalized_original:
10781078
logger.info(f"h3|Candidate {candidate_index}/{total_candidates}: Identical to original code, skipping.")
@@ -1284,7 +1284,7 @@ def determine_best_candidate(
12841284
self.future_adaptive_optimizations,
12851285
)
12861286
candidate_index = 0
1287-
normalized_original = normalize_code(code_context.read_writable_code.flat.strip())
1287+
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
12881288

12891289
# Process candidates using queue-based approach
12901290
while not processor.is_done():
@@ -1679,9 +1679,9 @@ def generate_tests(
16791679
future_concolic_tests = None
16801680
else:
16811681
future_concolic_tests = self.executor.submit(
1682-
generate_concolic_tests,
1682+
self.language_support.generate_concolic_tests,
16831683
self.test_cfg,
1684-
self.args,
1684+
self.args.project_root,
16851685
self.function_to_optimize,
16861686
self.function_to_optimize_ast,
16871687
)

0 commit comments

Comments
 (0)