Skip to content

Commit 77cddec

Browse files
committed
progress on instrumentation of java code
1 parent 60fefbc commit 77cddec

5 files changed

Lines changed: 615 additions & 69 deletions

File tree

codeflash/languages/java/instrumentation.py

Lines changed: 261 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def instrument_existing_test(
119119
120120
For Java, this:
121121
1. Renames the class to match the new file name (Java requires class name = file name)
122-
2. Adds timing instrumentation to test methods (for performance mode)
122+
2. For behavior mode: adds timing instrumentation that writes to SQLite
123+
3. For performance mode: adds timing instrumentation with stdout markers
123124
124125
Args:
125126
test_path: Path to the test file.
@@ -157,14 +158,21 @@ def instrument_existing_test(
157158
replacement = rf'\1class {new_class_name}'
158159
modified_source = re.sub(pattern, replacement, source)
159160

160-
# For performance mode, add timing instrumentation to test methods
161+
# Add timing instrumentation to test methods
161162
# Use original class name (without suffix) in timing markers for consistency with Python
162163
if mode == "performance":
163164
modified_source = _add_timing_instrumentation(
164165
modified_source,
165166
original_class_name, # Use original name in markers, not the renamed class
166167
func_name,
167168
)
169+
else:
170+
# Behavior mode: add timing instrumentation that also writes to SQLite
171+
modified_source = _add_behavior_instrumentation(
172+
modified_source,
173+
original_class_name,
174+
func_name,
175+
)
168176

169177
logger.debug(
170178
"Java %s testing for %s: renamed class %s -> %s",
@@ -177,6 +185,257 @@ def instrument_existing_test(
177185
return True, modified_source
178186

179187

188+
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str:
189+
"""Add behavior instrumentation to test methods.
190+
191+
For behavior mode, this adds:
192+
1. Gson import for JSON serialization
193+
2. SQLite database connection setup
194+
3. Function call wrapping to capture return values
195+
4. SQLite insert with serialized return values
196+
197+
Args:
198+
source: The test source code.
199+
class_name: Name of the test class.
200+
func_name: Name of the function being tested.
201+
202+
Returns:
203+
Instrumented source code.
204+
205+
"""
206+
# Add necessary imports at the top of the file
207+
import_statements = [
208+
"import java.sql.Connection;",
209+
"import java.sql.DriverManager;",
210+
"import java.sql.PreparedStatement;",
211+
"import java.sql.Statement;",
212+
"import com.google.gson.Gson;",
213+
"import com.google.gson.GsonBuilder;",
214+
]
215+
216+
# Find position to insert imports (after package, before class)
217+
lines = source.split('\n')
218+
result = []
219+
imports_added = False
220+
i = 0
221+
222+
while i < len(lines):
223+
line = lines[i]
224+
stripped = line.strip()
225+
226+
# Add imports after the last existing import or before the class declaration
227+
if not imports_added:
228+
if stripped.startswith('import '):
229+
result.append(line)
230+
i += 1
231+
# Find end of imports
232+
while i < len(lines) and lines[i].strip().startswith('import '):
233+
result.append(lines[i])
234+
i += 1
235+
# Add our imports
236+
for imp in import_statements:
237+
if imp not in source:
238+
result.append(imp)
239+
imports_added = True
240+
continue
241+
elif stripped.startswith('public class') or stripped.startswith('class'):
242+
# No imports found, add before class
243+
for imp in import_statements:
244+
result.append(imp)
245+
result.append("")
246+
imports_added = True
247+
248+
result.append(line)
249+
i += 1
250+
251+
# Now add timing and SQLite instrumentation to test methods
252+
source = '\n'.join(result)
253+
lines = source.split('\n')
254+
result = []
255+
i = 0
256+
iteration_counter = 0
257+
258+
while i < len(lines):
259+
line = lines[i]
260+
stripped = line.strip()
261+
262+
# Look for @Test annotation
263+
if stripped.startswith('@Test'):
264+
result.append(line)
265+
i += 1
266+
267+
# Collect any additional annotations
268+
while i < len(lines) and lines[i].strip().startswith('@'):
269+
result.append(lines[i])
270+
i += 1
271+
272+
# Now find the method signature and opening brace
273+
method_lines = []
274+
while i < len(lines):
275+
method_lines.append(lines[i])
276+
if '{' in lines[i]:
277+
break
278+
i += 1
279+
280+
# Add the method signature lines
281+
for ml in method_lines:
282+
result.append(ml)
283+
i += 1
284+
285+
# We're now inside the method body
286+
iteration_counter += 1
287+
iter_id = iteration_counter
288+
289+
# Detect indentation
290+
method_sig_line = method_lines[-1] if method_lines else ""
291+
base_indent = len(method_sig_line) - len(method_sig_line.lstrip())
292+
indent = " " * (base_indent + 4)
293+
294+
# Collect method body until we find matching closing brace
295+
brace_depth = 1
296+
body_lines = []
297+
298+
while i < len(lines) and brace_depth > 0:
299+
body_line = lines[i]
300+
for ch in body_line:
301+
if ch == '{':
302+
brace_depth += 1
303+
elif ch == '}':
304+
brace_depth -= 1
305+
306+
if brace_depth > 0:
307+
body_lines.append(body_line)
308+
i += 1
309+
else:
310+
# We've hit the closing brace
311+
i += 1
312+
break
313+
314+
# Wrap function calls to capture return values
315+
# Look for patterns like: obj.funcName(args) or new Class().funcName(args)
316+
call_counter = 0
317+
wrapped_body_lines = []
318+
319+
# Use regex to find method calls with the target function
320+
# Pattern matches: receiver.funcName(args) where receiver can be:
321+
# - identifier (counter, calc, etc.)
322+
# - new ClassName()
323+
# - new ClassName(args)
324+
# - this
325+
method_call_pattern = re.compile(
326+
rf'((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)',
327+
re.MULTILINE
328+
)
329+
330+
for body_line in body_lines:
331+
# Check if this line contains a call to the target function
332+
if func_name in body_line and '(' in body_line:
333+
line_indent = len(body_line) - len(body_line.lstrip())
334+
line_indent_str = " " * line_indent
335+
336+
# Find all matches in the line
337+
matches = list(method_call_pattern.finditer(body_line))
338+
if matches:
339+
# Process matches in reverse order to maintain correct positions
340+
new_line = body_line
341+
for match in reversed(matches):
342+
call_counter += 1
343+
var_name = f"_cf_result{iter_id}_{call_counter}"
344+
full_call = match.group(0) # e.g., "new StringUtils().reverse(\"hello\")"
345+
346+
# Replace this occurrence with the variable
347+
new_line = new_line[:match.start()] + var_name + new_line[match.end():]
348+
349+
# Insert capture line
350+
capture_line = f"{line_indent_str}Object {var_name} = {full_call};"
351+
wrapped_body_lines.append(capture_line)
352+
353+
wrapped_body_lines.append(new_line)
354+
else:
355+
wrapped_body_lines.append(body_line)
356+
else:
357+
wrapped_body_lines.append(body_line)
358+
359+
# Build the serialized return value expression
360+
# If we captured any calls, serialize the last one; otherwise serialize null
361+
if call_counter > 0:
362+
result_var = f"_cf_result{iter_id}_{call_counter}"
363+
serialize_expr = f'new GsonBuilder().serializeNulls().create().toJson({result_var})'
364+
else:
365+
serialize_expr = '"null"'
366+
367+
# Add behavior instrumentation code
368+
behavior_start_code = [
369+
f"{indent}// Codeflash behavior instrumentation",
370+
f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));',
371+
f"{indent}int _cf_iter{iter_id} = {iter_id};",
372+
f'{indent}String _cf_mod{iter_id} = "{class_name}";',
373+
f'{indent}String _cf_cls{iter_id} = "{class_name}";',
374+
f'{indent}String _cf_fn{iter_id} = "{func_name}";',
375+
f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");',
376+
f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");',
377+
f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";',
378+
f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");',
379+
f"{indent}long _cf_start{iter_id} = System.nanoTime();",
380+
f"{indent}String _cf_serializedResult{iter_id} = null;",
381+
f"{indent}try {{",
382+
]
383+
result.extend(behavior_start_code)
384+
385+
# Add the wrapped body lines with extra indentation
386+
for bl in wrapped_body_lines:
387+
result.append(" " + bl)
388+
389+
# Add serialization after the body (before finally)
390+
result.append(f"{indent} _cf_serializedResult{iter_id} = {serialize_expr};")
391+
392+
# Add finally block with SQLite write
393+
method_close_indent = " " * base_indent
394+
behavior_end_code = [
395+
f"{indent}}} finally {{",
396+
f"{indent} long _cf_end{iter_id} = System.nanoTime();",
397+
f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};",
398+
f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");',
399+
f"{indent} // Write to SQLite if output file is set",
400+
f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{",
401+
f"{indent} try {{",
402+
f"{indent} Class.forName(\"org.sqlite.JDBC\");",
403+
f"{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection(\"jdbc:sqlite:\" + _cf_outputFile{iter_id})) {{",
404+
f"{indent} try (Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{",
405+
f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +',
406+
f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +',
407+
f'{indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +',
408+
f'{indent} "runtime INTEGER, return_value TEXT, verification_type TEXT)");',
409+
f"{indent} }}",
410+
f'{indent} String _cf_sql{iter_id} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";',
411+
f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{",
412+
f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});",
413+
f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});",
414+
f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");',
415+
f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});",
416+
f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});",
417+
f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});',
418+
f"{indent} _cf_pstmt{iter_id}.setLong(7, _cf_dur{iter_id});",
419+
f"{indent} _cf_pstmt{iter_id}.setString(8, _cf_serializedResult{iter_id});", # Serialized return value
420+
f'{indent} _cf_pstmt{iter_id}.setString(9, "function_call");',
421+
f"{indent} _cf_pstmt{iter_id}.executeUpdate();",
422+
f"{indent} }}",
423+
f"{indent} }}",
424+
f"{indent} }} catch (Exception _cf_e{iter_id}) {{",
425+
f'{indent} System.err.println("CodeflashHelper: SQLite error: " + _cf_e{iter_id}.getMessage());',
426+
f"{indent} }}",
427+
f"{indent} }}",
428+
f"{indent}}}",
429+
f"{method_close_indent}}}", # Method closing brace
430+
]
431+
result.extend(behavior_end_code)
432+
else:
433+
result.append(line)
434+
i += 1
435+
436+
return '\n'.join(result)
437+
438+
180439
def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str:
181440
"""Add timing instrumentation to test methods.
182441

codeflash/languages/java/resources/CodeflashHelper.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package codeflash.runtime;
22

3+
import java.io.ByteArrayOutputStream;
34
import java.io.File;
5+
import java.io.ObjectOutputStream;
6+
import java.io.Serializable;
47
import java.sql.Connection;
58
import java.sql.DriverManager;
69
import java.sql.PreparedStatement;

codeflash/languages/java/test_runner.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pathlib import Path
1818
from typing import TYPE_CHECKING, Any
1919

20+
from codeflash.code_utils.code_utils import get_run_tmp_file
2021
from codeflash.languages.base import TestResult
2122
from codeflash.languages.java.build_tools import (
2223
find_maven_executable,
@@ -58,7 +59,8 @@ def run_behavioral_tests(
5859
"""Run behavioral tests for Java code.
5960
6061
This runs tests and captures behavior (inputs/outputs) for verification.
61-
For Java, verification is based on JUnit test pass/fail results.
62+
For Java, test results are written to a SQLite database via CodeflashHelper,
63+
and JUnit test pass/fail results serve as the primary verification mechanism.
6264
6365
Args:
6466
test_paths: TestFiles object or list of test file paths.
@@ -70,17 +72,21 @@ def run_behavioral_tests(
7072
candidate_index: Index of the candidate being tested.
7173
7274
Returns:
73-
Tuple of (result_xml_path, subprocess_result, coverage_path, config_path).
75+
Tuple of (result_xml_path, subprocess_result, sqlite_db_path, None).
7476
7577
"""
7678
project_root = project_root or cwd
7779

78-
# Set environment variables for timing instrumentation
80+
# Create SQLite database path for behavior capture - use standard path that parse_test_results expects
81+
sqlite_db_path = get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite"))
82+
83+
# Set environment variables for timing instrumentation and behavior capture
7984
run_env = os.environ.copy()
8085
run_env.update(test_env)
8186
run_env["CODEFLASH_LOOP_INDEX"] = "1" # Single loop for behavior tests
8287
run_env["CODEFLASH_MODE"] = "behavior"
8388
run_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
89+
run_env["CODEFLASH_OUTPUT_FILE"] = str(sqlite_db_path) # SQLite output path
8490

8591
# Run Maven tests
8692
result = _run_maven_tests(
@@ -95,7 +101,8 @@ def run_behavioral_tests(
95101
surefire_dir = project_root / "target" / "surefire-reports"
96102
result_xml_path = _get_combined_junit_xml(surefire_dir, candidate_index)
97103

98-
return result_xml_path, result, None, None
104+
# Return sqlite_db_path as the third element (was None before)
105+
return result_xml_path, result, sqlite_db_path, None
99106

100107

101108
def run_benchmarking_tests(

0 commit comments

Comments
 (0)