diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 63c83149f..3b5b1c74c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -95,7 +95,8 @@ jobs: 'codeflash/languages/java/' 'codeflash/languages/base.py' \ 'codeflash/languages/registry.py' 'codeflash/optimization/' \ 'codeflash/verification/' 'codeflash-java-runtime/' \ - 'code_to_optimize/java/' 'tests/scripts/end_to_end_test_java*' + 'code_to_optimize/java/' 'tests/scripts/end_to_end_test_java*' \ + 'tests/test_languages/fixtures/java_tracer_e2e/' env: MERGE_BASE: ${{ steps.merge_base.outputs.sha }} @@ -257,7 +258,7 @@ jobs: - name: init-optimization script: end_to_end_test_init_optimization.py expected_improvement: 10 - environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} + environment: ${{ ((github.event_name == 'workflow_dispatch' && github.actor != 'misrasaurabh1' && github.actor != 'KRRT7') || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} runs-on: ubuntu-latest env: CODEFLASH_AIS_SERVER: prod @@ -344,7 +345,7 @@ jobs: script: end_to_end_test_js_ts_class.py js_project_dir: code_to_optimize/js/code_to_optimize_ts expected_improvement: 30 - environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} + environment: ${{ ((github.event_name == 'workflow_dispatch' && github.actor != 'misrasaurabh1' && 
github.actor != 'KRRT7') || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} runs-on: ubuntu-latest env: CODEFLASH_AIS_SERVER: prod @@ -424,7 +425,7 @@ jobs: script: end_to_end_test_java_void_optimization.py expected_improvement: 70 remove_git: true - environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} + environment: ${{ ((github.event_name == 'workflow_dispatch' && github.actor != 'misrasaurabh1' && github.actor != 'KRRT7') || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} runs-on: ubuntu-latest env: CODEFLASH_AIS_SERVER: prod @@ -435,6 +436,7 @@ jobs: RETRY_DELAY: 5 EXPECTED_IMPROVEMENT_PCT: ${{ matrix.expected_improvement }} CODEFLASH_END_TO_END: 1 + CODEFLASH_LOOPING_TIME: 5 steps: - uses: actions/checkout@v6 with: @@ -468,7 +470,15 @@ jobs: - name: Install dependencies run: uv sync + - name: Cache codeflash-runtime JAR + id: runtime-jar-cache + uses: actions/cache@v4 + with: + path: ~/.m2/repository/io/codeflash + key: codeflash-runtime-${{ hashFiles('codeflash-java-runtime/pom.xml', 'codeflash-java-runtime/src/**') }} + - name: Build and install codeflash-runtime JAR + if: steps.runtime-jar-cache.outputs.cache-hit != 'true' run: | cd codeflash-java-runtime mvn install -q -DskipTests diff --git a/.github/workflows/codeflash-optimize.yaml b/.github/workflows/codeflash-optimize.yaml index 9884665da..ab08aa1f8 100644 --- 
a/.github/workflows/codeflash-optimize.yaml +++ b/.github/workflows/codeflash-optimize.yaml @@ -43,4 +43,4 @@ jobs: - name: ⚡️Codeflash Optimization id: optimize_code run: | - uv run codeflash --benchmark --testgen-review \ No newline at end of file + uv run codeflash --benchmark --testgen-review --no-pr \ No newline at end of file diff --git a/.github/workflows/java-e2e.yaml b/.github/workflows/java-e2e.yaml new file mode 100644 index 000000000..0bfc979b6 --- /dev/null +++ b/.github/workflows/java-e2e.yaml @@ -0,0 +1,77 @@ +name: Java E2E Tests +on: + workflow_dispatch: + +jobs: + e2e-java: + strategy: + fail-fast: false + matrix: + include: + - name: java-fibonacci-nogit + script: end_to_end_test_java_fibonacci.py + expected_improvement: 70 + remove_git: true + - name: java-tracer + script: end_to_end_test_java_tracer.py + expected_improvement: 10 + - name: java-void-optimization-nogit + script: end_to_end_test_java_void_optimization.py + expected_improvement: 70 + remove_git: true + runs-on: ubuntu-latest + env: + CODEFLASH_AIS_SERVER: prod + POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} + CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} + COLUMNS: 110 + MAX_RETRIES: 3 + RETRY_DELAY: 5 + EXPECTED_IMPROVEMENT_PCT: ${{ matrix.expected_improvement }} + CODEFLASH_END_TO_END: 1 + CODEFLASH_LOOPING_TIME: 5 + steps: + - uses: actions/checkout@v6 + + - name: Set up JDK 11 + uses: actions/setup-java@v5 + with: + java-version: '11' + distribution: 'temurin' + cache: maven + + - name: Install uv + uses: astral-sh/setup-uv@v8.0.0 + with: + python-version: 3.11.6 + enable-cache: true + + - name: Install dependencies + run: uv sync + + - name: Cache codeflash-runtime JAR + id: runtime-jar-cache + uses: actions/cache@v4 + with: + path: ~/.m2/repository/io/codeflash + key: codeflash-runtime-${{ hashFiles('codeflash-java-runtime/pom.xml', 'codeflash-java-runtime/src/**') }} + + - name: Build and install codeflash-runtime JAR + if: steps.runtime-jar-cache.outputs.cache-hit 
!= 'true' + run: | + cd codeflash-java-runtime + mvn install -q -DskipTests + + - name: Remove .git + if: matrix.remove_git + run: | + if [ -d ".git" ]; then + sudo rm -rf .git + echo ".git directory removed." + else + echo ".git directory does not exist." + exit 1 + fi + + - name: Run E2E test + run: uv run python tests/scripts/${{ matrix.script }} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java b/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java index 80d400935..e1c177ac9 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java @@ -6,7 +6,6 @@ import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy; import org.objenesis.strategy.StdInstantiatorStrategy; -import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Field; @@ -36,7 +35,11 @@ public final class Serializer { private static final int MAX_COLLECTION_SIZE = 1000; private static final int BUFFER_SIZE = 4096; - // Thread-local Kryo instances (Kryo is not thread-safe) + // Thread-local Kryo, Output, and IdentityHashMap instances for reuse + private static final ThreadLocal OUTPUT = ThreadLocal.withInitial(() -> new Output(BUFFER_SIZE, -1)); + private static final ThreadLocal> SEEN = + ThreadLocal.withInitial(IdentityHashMap::new); + private static final ThreadLocal KRYO = ThreadLocal.withInitial(() -> { Kryo kryo = new Kryo(); kryo.setRegistrationRequired(false); @@ -89,10 +92,78 @@ private Serializer() { * @return Serialized bytes (may contain KryoPlaceholder for unserializable parts) */ public static byte[] serialize(Object obj) { - Object processed = recursiveProcess(obj, new IdentityHashMap<>(), 0, ""); + // Fast path: if args are all safe types, skip recursive processing entirely + if (obj instanceof Object[] && isSafeArgs((Object[]) obj)) { + return directSerialize(obj); + } + + 
IdentityHashMap seen = SEEN.get(); + seen.clear(); + Object processed = recursiveProcess(obj, seen, 0, ""); return directSerialize(processed); } + /** + * Attempt fast-path serialization for args that are all known-safe types. + * Returns serialized bytes if all args are safe, or null if the slow path is needed. + * Callers can use this to avoid executor submission overhead for simple arguments. + */ + public static byte[] serializeFast(Object obj) { + if (obj instanceof Object[] && isSafeArgs((Object[]) obj)) { + return directSerialize(obj); + } + return null; + } + + /** + * Check if all elements of an args array can be serialized directly without recursive processing. + */ + private static boolean isSafeArgs(Object[] args) { + for (Object arg : args) { + if (!isSafeForDirectSerialization(arg)) { + return false; + } + } + return true; + } + + /** + * Check if an object is safe to serialize directly without recursive processing. + * Covers: null, simple types, primitive arrays, and safe containers (up to 3 levels deep). + */ + private static boolean isSafeForDirectSerialization(Object obj) { + return isSafeForDirectSerialization(obj, 3); + } + + private static boolean isSafeForDirectSerialization(Object obj, int depthLeft) { + if (obj == null || isSimpleType(obj)) { + return true; + } + if (depthLeft <= 0) { + return false; + } + Class clazz = obj.getClass(); + if (clazz.isArray() && clazz.getComponentType().isPrimitive()) { + return true; + } + if (isSafeContainerType(clazz)) { + if (obj instanceof Collection) { + for (Object item : (Collection) obj) { + if (!isSafeForDirectSerialization(item, depthLeft - 1)) return false; + } + return true; + } + if (obj instanceof Map) { + for (Map.Entry e : ((Map) obj).entrySet()) { + if (!isSafeForDirectSerialization(e.getKey(), depthLeft - 1) || + !isSafeForDirectSerialization(e.getValue(), depthLeft - 1)) return false; + } + return true; + } + } + return false; + } + /** * Deserialize bytes back to an object. 
* The returned object may contain KryoPlaceholder instances for parts @@ -141,14 +212,15 @@ public static byte[] serializeException(Throwable error) { /** * Direct serialization without recursive processing. + * Reuses a ThreadLocal Output buffer to avoid per-call allocation. */ private static byte[] directSerialize(Object obj) { Kryo kryo = KRYO.get(); - ByteArrayOutputStream baos = new ByteArrayOutputStream(BUFFER_SIZE); - try (Output output = new Output(baos)) { - kryo.writeClassAndObject(output, obj); - } - return baos.toByteArray(); + Output output = OUTPUT.get(); + output.reset(); + kryo.writeClassAndObject(output, obj); + output.flush(); + return output.toBytes(); } /** @@ -201,37 +273,23 @@ private static Object recursiveProcess(Object obj, IdentityHashMap map = (Map) obj; - if (containsOnlySimpleTypes(map)) { - // Simple map - try direct serialization to preserve full size - byte[] serialized = tryDirectSerialize(obj); - if (serialized != null) { - try { - deserialize(serialized); - return obj; // Success - return original - } catch (Exception e) { - // Fall through to recursive handling - } - } + if (isSafeContainerType(clazz) && containsOnlySimpleTypes(map)) { + return obj; } return handleMap(map, seen, depth, path); } if (obj instanceof Collection) { Collection collection = (Collection) obj; - if (containsOnlySimpleTypes(collection)) { - // Simple collection - try direct serialization to preserve full size - byte[] serialized = tryDirectSerialize(obj); - if (serialized != null) { - try { - deserialize(serialized); - return obj; // Success - return original - } catch (Exception e) { - // Fall through to recursive handling - } - } + if (isSafeContainerType(clazz) && containsOnlySimpleTypes(collection)) { + return obj; } return handleCollection(collection, seen, depth, path); } if (clazz.isArray()) { + // Primitive arrays (int[], double[], etc.) 
are directly serializable by Kryo + if (clazz.getComponentType().isPrimitive()) { + return obj; + } return handleArray(obj, seen, depth, path); } @@ -255,6 +313,19 @@ private static Object recursiveProcess(Object obj, IdentityHashMap clazz) { + return clazz == ArrayList.class || + clazz == LinkedList.class || + clazz == HashMap.class || + clazz == LinkedHashMap.class || + clazz == HashSet.class || + clazz == LinkedHashSet.class; + } + /** * Check if a class is known to be unserializable. */ diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java index 28c2d2998..8596d3ee8 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java @@ -31,7 +31,7 @@ public final class TraceRecorder { private TraceRecorder(TracerConfig config) { this.config = config; - this.writer = new TraceWriter(config.getDbPath()); + this.writer = new TraceWriter(config.getDbPath(), config.isInMemoryDb()); this.maxFunctionCount = config.getMaxFunctionCount(); this.serializerExecutor = Executors.newCachedThreadPool(r -> { Thread t = new Thread(r, "codeflash-serializer"); @@ -76,23 +76,27 @@ private void onEntryImpl(String className, String methodName, String descriptor, return; } - // Serialize args with timeout to prevent deep object graph traversal from blocking + // Serialize args — try inline fast path first, fall back to async with timeout byte[] argsBlob; - Future future = serializerExecutor.submit(() -> Serializer.serialize(args)); - try { - argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - future.cancel(true); - droppedCaptures.incrementAndGet(); - System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." 
- + methodName); - return; - } catch (Exception e) { - Throwable cause = e.getCause() != null ? e.getCause() : e; - droppedCaptures.incrementAndGet(); - System.err.println("[codeflash-tracer] Serialization failed for " + className + "." - + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); - return; + argsBlob = Serializer.serializeFast(args); + if (argsBlob == null) { + // Slow path: async serialization with timeout for complex/unknown types + Future future = serializerExecutor.submit(() -> Serializer.serialize(args)); + try { + argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + future.cancel(true); + droppedCaptures.incrementAndGet(); + System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." + + methodName); + return; + } catch (Exception e) { + Throwable cause = e.getCause() != null ? e.getCause() : e; + droppedCaptures.incrementAndGet(); + System.err.println("[codeflash-tracer] Serialization failed for " + className + "." 
+ + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); + return; + } } long timeNs = System.nanoTime(); diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java index a9eeabf60..a75872089 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java @@ -7,15 +7,22 @@ import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; public final class TraceWriter { + private static final int BATCH_SIZE = 256; + private static final int QUEUE_CAPACITY = 65536; + private final Connection connection; + private final Path diskPath; + private final boolean inMemory; private final BlockingQueue writeQueue; private final Thread writerThread; private final AtomicBoolean running; @@ -23,14 +30,20 @@ public final class TraceWriter { private PreparedStatement insertFunctionCall; private PreparedStatement insertMetadata; - public TraceWriter(String dbPath) { - this.writeQueue = new LinkedBlockingQueue<>(); + public TraceWriter(String dbPath, boolean inMemory) { + this.diskPath = Paths.get(dbPath).toAbsolutePath(); + this.diskPath.getParent().toFile().mkdirs(); + this.inMemory = inMemory; + this.writeQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); this.running = new AtomicBoolean(true); try { - Path path = Paths.get(dbPath).toAbsolutePath(); - path.getParent().toFile().mkdirs(); - this.connection = DriverManager.getConnection("jdbc:sqlite:" + path); + if (inMemory) { + // In-memory database for maximum write 
performance; flushed to disk via VACUUM INTO at close() + this.connection = DriverManager.getConnection("jdbc:sqlite::memory:"); + } else { + this.connection = DriverManager.getConnection("jdbc:sqlite:" + this.diskPath); + } initializeSchema(); prepareStatements(); @@ -45,8 +58,12 @@ public TraceWriter(String dbPath) { private void initializeSchema() throws SQLException { try (Statement stmt = connection.createStatement()) { - stmt.execute("PRAGMA journal_mode=WAL"); - stmt.execute("PRAGMA synchronous=NORMAL"); + if (!inMemory) { + stmt.execute("PRAGMA journal_mode=WAL"); + stmt.execute("PRAGMA synchronous=NORMAL"); + stmt.execute("PRAGMA cache_size=-16000"); + stmt.execute("PRAGMA temp_store=MEMORY"); + } stmt.execute( "CREATE TABLE IF NOT EXISTS function_calls(" + @@ -69,6 +86,8 @@ private void initializeSchema() throws SQLException { stmt.execute("CREATE INDEX IF NOT EXISTS idx_fc_class_func ON function_calls(classname, function)"); } + // Keep autocommit off for writer performance — commit explicitly per batch + connection.setAutoCommit(false); } private void prepareStatements() throws SQLException { @@ -95,27 +114,59 @@ public void writeMetadata(Map metadata) { } private void writerLoop() { + List batch = new ArrayList<>(BATCH_SIZE); + while (running.get() || !writeQueue.isEmpty()) { try { WriteTask task = writeQueue.poll(100, TimeUnit.MILLISECONDS); - if (task != null) { - task.execute(this); + if (task == null) { + continue; } + batch.add(task); + writeQueue.drainTo(batch, BATCH_SIZE - 1); + executeBatch(batch); + batch.clear(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); break; - } catch (SQLException e) { - System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); } } // Drain remaining - WriteTask task; - while ((task = writeQueue.poll()) != null) { + writeQueue.drainTo(batch); + if (!batch.isEmpty()) { + executeBatch(batch); + } + } + + private void executeBatch(List batch) { + if (batch.isEmpty()) { + return; + 
} + + boolean hasFunctionCalls = false; + try { + for (WriteTask task : batch) { + if (task instanceof FunctionCallTask) { + ((FunctionCallTask) task).bindParameters(this); + insertFunctionCall.addBatch(); + hasFunctionCalls = true; + } else { + task.execute(this); + } + } + + if (hasFunctionCalls) { + insertFunctionCall.executeBatch(); + } + + connection.commit(); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Batch write error (" + batch.size() + " tasks): " + e.getMessage()); try { - task.execute(this); - } catch (SQLException e) { - System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); + connection.rollback(); + } catch (SQLException re) { + System.err.println("[codeflash-tracer] Rollback failed: " + re.getMessage()); } } } @@ -139,9 +190,27 @@ public void close() { Thread.currentThread().interrupt(); } + // Close prepared statements first — required before VACUUM try { if (insertFunctionCall != null) insertFunctionCall.close(); if (insertMetadata != null) insertMetadata.close(); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Error closing statements: " + e.getMessage()); + } + + if (inMemory) { + try { + connection.commit(); + connection.setAutoCommit(true); + try (Statement stmt = connection.createStatement()) { + stmt.execute("VACUUM INTO '" + diskPath.toString().replace("'", "''") + "'"); + } + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Failed to write trace DB to disk: " + e.getMessage()); + } + } + + try { if (connection != null) connection.close(); } catch (SQLException e) { System.err.println("[codeflash-tracer] Error closing TraceWriter: " + e.getMessage()); @@ -177,8 +246,7 @@ private static class FunctionCallTask implements WriteTask { this.argsBlob = argsBlob; } - @Override - public void execute(TraceWriter writer) throws SQLException { + void bindParameters(TraceWriter writer) throws SQLException { writer.insertFunctionCall.setString(1, type); 
writer.insertFunctionCall.setString(2, function); writer.insertFunctionCall.setString(3, classname); @@ -187,6 +255,11 @@ public void execute(TraceWriter writer) throws SQLException { writer.insertFunctionCall.setString(6, descriptor); writer.insertFunctionCall.setLong(7, timeNs); writer.insertFunctionCall.setBytes(8, argsBlob); + } + + @Override + public void execute(TraceWriter writer) throws SQLException { + bindParameters(writer); writer.insertFunctionCall.executeUpdate(); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java index 8fe799d2f..9e2675c00 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java @@ -30,6 +30,9 @@ public final class TracerConfig { @SerializedName("projectRoot") private String projectRoot = ""; + @SerializedName("inMemoryDb") + private boolean inMemoryDb = false; + private static final Gson GSON = new Gson(); public static TracerConfig parse(String agentArgs) { @@ -89,6 +92,10 @@ public String getProjectRoot() { return projectRoot; } + public boolean isInMemoryDb() { + return inMemoryDb; + } + public boolean shouldInstrumentClass(String internalClassName) { String dotName = internalClassName.replace('/', '.'); diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index ff6494d73..c8cb8d884 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from enum import Enum from typing import Any, Union @@ -17,7 +18,7 @@ CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark MAX_TEST_FUNCTION_RUNS = 50 MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms -TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget +TOTAL_LOOPING_TIME = 
float(os.getenv("CODEFLASH_LOOPING_TIME", "10.0")) # candidate benchmarking budget (seconds) COVERAGE_THRESHOLD = 60.0 MIN_TESTCASE_PASSED_THRESHOLD = 6 REPEAT_OPTIMIZATION_PROBABILITY = 0.1 diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index d9b4918fd..71ad03b18 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -489,6 +489,7 @@ def __init__( else function_to_optimize.file_path.read_text(encoding="utf8") ) self.language_support = current_language_support() + self.language_support.ensure_runtime_environment(self.project_root) if not function_to_optimize_ast: self.function_to_optimize_ast = self._resolve_function_ast( self.function_to_optimize_source_code, function_to_optimize.function_name, function_to_optimize.parents @@ -3253,6 +3254,11 @@ def get_test_env( test_env["CODEFLASH_TEST_ITERATION"] = str(codeflash_test_iteration) test_env["CODEFLASH_TRACER_DISABLE"] = str(codeflash_tracer_disable) test_env["CODEFLASH_LOOP_INDEX"] = str(codeflash_loop_index) + # Pin PYTHONHASHSEED so original and candidate test processes use the same hash seed. + # Without this, each subprocess gets a random seed, which can cause non-deterministic + # iteration order in sets/dicts and lead to flaky return-value comparisons. 
+ if "PYTHONHASHSEED" not in test_env: + test_env["PYTHONHASHSEED"] = "0" return test_env def line_profiler_step( diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py index 50506797e..649369d97 100644 --- a/codeflash/languages/java/tracer.py +++ b/codeflash/languages/java/tracer.py @@ -6,6 +6,7 @@ import subprocess from typing import TYPE_CHECKING +from codeflash.code_utils.env_utils import is_ci from codeflash.languages.java.line_profiler import find_agent_jar from codeflash.languages.java.replay_test import generate_replay_tests @@ -60,7 +61,7 @@ def _run_java_with_graceful_timeout( class JavaTracer: - """Orchestrates two-stage Java tracing: JFR profiling + argument capture.""" + """Orchestrates Java tracing: combined JFR profiling + argument capture in a single JVM invocation.""" def trace( self, @@ -71,29 +72,23 @@ def trace( max_function_count: int = 256, timeout: int = 0, ) -> tuple[Path, Path]: - """Run the Java program twice: once for profiling, once for arg capture. + """Run the Java program once with both JFR profiling and argument capture. Returns (trace_db_path, jfr_file_path). 
""" jfr_file = trace_db_path.with_suffix(".jfr") trace_db_path.parent.mkdir(parents=True, exist_ok=True) - # Stage 1: JFR Profiling - logger.info("Stage 1: Running JFR profiling...") - jfr_env = self.build_jfr_env(jfr_file) - _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling") - - if not jfr_file.exists(): - logger.warning("JFR file was not created at %s", jfr_file) - - # Stage 2: Argument Capture via Tracing Agent - logger.info("Stage 2: Running argument capture...") config_path = self.create_tracer_config( trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout ) - agent_env = self.build_agent_env(config_path) - _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture") + combined_env = self.build_combined_env(jfr_file, config_path) + + logger.info("Running combined JFR profiling + argument capture...") + _run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing") + if not jfr_file.exists(): + logger.warning("JFR file was not created at %s", jfr_file) if not trace_db_path.exists(): logger.error("Trace database was not created at %s", trace_db_path) @@ -114,6 +109,7 @@ def create_tracer_config( "maxFunctionCount": max_function_count, "timeout": timeout, "projectRoot": str(project_root.resolve()) if project_root else "", + "inMemoryDb": is_ci(), } config_path = trace_db_path.with_suffix(".config.json") @@ -122,12 +118,7 @@ def create_tracer_config( def build_jfr_env(self, jfr_file: Path) -> dict[str, str]: env = os.environ.copy() - # Use profile settings with increased sampling frequency (1ms instead of default 10ms) - # This captures more samples for short-running programs - jfr_opts = ( - f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" - ",jdk.ExecutionSample#period=1ms" - ) + jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" existing = 
env.get("JAVA_TOOL_OPTIONS", "") env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip() return env @@ -144,6 +135,19 @@ def build_agent_env(self, config_path: Path, classpath: str | None = None) -> di env["JAVA_TOOL_OPTIONS"] = f"{existing} {agent_opts}".strip() return env + def build_combined_env(self, jfr_file: Path, config_path: Path, classpath: str | None = None) -> dict[str, str]: + """Build env with both JFR recording and tracing agent in a single JAVA_TOOL_OPTIONS.""" + env = os.environ.copy() + jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + agent_jar = find_agent_jar(classpath=classpath) + if agent_jar is None: + msg = "codeflash-runtime JAR not found, cannot run tracing agent" + raise FileNotFoundError(msg) + agent_opts = f"{ADD_OPENS_FLAGS} -javaagent:{agent_jar}=trace={config_path.resolve()}" + existing = env.get("JAVA_TOOL_OPTIONS", "") + env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts} {agent_opts}".strip() + return env + @staticmethod def detect_packages_from_source(module_root: Path) -> list[str]: """Scan Java files for package declarations and return unique package prefixes.""" diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 48920be8c..4a7d24585 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -349,10 +349,10 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: max_function_count = getattr(config, "max_function_count", 256) timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0) - console.print("[bold]Java project detected[/]") - console.print(f" Project root: {project_root}") - console.print(f" Module root: {getattr(config, 'module_root', '?')}") - console.print(f" Tests root: {getattr(config, 'tests_root', '?')}") + logger.info("Java project detected") + logger.info(" Project root: %s", project_root) + logger.info(" Module root: %s", getattr(config, "module_root", "?")) + logger.info(" Tests 
root: %s", getattr(config, "tests_root", "?")) from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.languages.java.tracer import JavaTracer, run_java_tracer diff --git a/tests/scripts/end_to_end_test_java_tracer.py b/tests/scripts/end_to_end_test_java_tracer.py index 0f9f8a2ff..5d92662ec 100644 --- a/tests/scripts/end_to_end_test_java_tracer.py +++ b/tests/scripts/end_to_end_test_java_tracer.py @@ -51,6 +51,8 @@ def run_test(expected_improvement_pct: int) -> bool: "-m", "codeflash.main", "--no-pr", + "--effort", + "low", "optimize", "java", "-cp", diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java new file mode 100644 index 000000000..b7c48c625 --- /dev/null +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java @@ -0,0 +1,91 @@ +package com.example; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Profiling workload for benchmarking the codeflash tracing agent. + * Exercises different argument types to stress serialization paths. + */ +public class ProfilingWorkload { + + // 1. Primitives only — cheapest to serialize + public static int addInts(int a, int b) { + return a + b; + } + + // 2. String arguments — moderate serialization cost + public static String concatStrings(String a, String b) { + return a + b; + } + + // 3. Array argument — requires element-by-element serialization + public static int sumArray(int[] values) { + int sum = 0; + for (int v : values) { + sum += v; + } + return sum; + } + + // 4. Collection argument — triggers recursive Kryo processing + public static int sumList(List values) { + int sum = 0; + for (int v : values) { + sum += v; + } + return sum; + } + + // 5. 
Nested map — deep object graph, expensive serialization + public static int countMapEntries(Map> data) { + int count = 0; + for (List list : data.values()) { + count += list.size(); + } + return count; + } + + public static void main(String[] args) { + int iterations = 1000; + + // 1. Primitives + for (int i = 0; i < iterations; i++) { + addInts(i, i + 1); + } + + // 2. Strings + for (int i = 0; i < iterations; i++) { + concatStrings("hello-" + i, "-world"); + } + + // 3. Arrays + int[] arr = new int[100]; + for (int i = 0; i < arr.length; i++) arr[i] = i; + for (int i = 0; i < iterations; i++) { + sumArray(arr); + } + + // 4. Lists + List list = new ArrayList<>(100); + for (int i = 0; i < 100; i++) list.add(i); + for (int i = 0; i < iterations; i++) { + sumList(list); + } + + // 5. Nested maps + Map> map = new HashMap<>(); + for (int i = 0; i < 10; i++) { + List vals = new ArrayList<>(); + for (int j = 0; j < 10; j++) vals.add(j); + map.put("key-" + i, vals); + } + for (int i = 0; i < iterations; i++) { + countMapEntries(map); + } + + System.out.println("ProfilingWorkload complete."); + } +} diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java index 7beb2a4ea..7dfdad95f 100644 --- a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java @@ -1,18 +1,7 @@ package com.example; -import java.util.ArrayList; -import java.util.List; - public class Workload { - public static int computeSum(int n) { - int sum = 0; - for (int i = 0; i < n; i++) { - sum += i; - } - return sum; - } - public static String repeatString(String s, int count) { String result = ""; for (int i = 0; i < count; i++) { @@ -21,46 +10,15 @@ public static String repeatString(String s, int count) { return result; } - public static List filterEvens(List 
numbers) { - List<Integer> result = new ArrayList<>(); - for (int n : numbers) { - if (n % 2 == 0) { - result.add(n); - } - } - return result; - } - - public int instanceMethod(int x, int y) { - return x * y + computeSum(x); - } - public static void main(String[] args) { - // Run methods with large inputs so JFR can capture CPU samples. - // Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval. + // Run with large inputs so JFR can capture CPU samples. for (int round = 0; round < 1000; round++) { - computeSum(100_000); repeatString("hello world ", 1000); - - List<Integer> nums = new ArrayList<>(); - for (int i = 1; i <= 10_000; i++) nums.add(i); - filterEvens(nums); - - Workload w = new Workload(); - w.instanceMethod(100_000, 42); } // Also call with small inputs for variety in traced args - System.out.println("computeSum(100) = " + computeSum(100)); System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3)); - List<Integer> small = new ArrayList<>(); - for (int i = 1; i <= 10; i++) small.add(i); - System.out.println("filterEvens(1..10) = " + filterEvens(small)); - - Workload w = new Workload(); - System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3)); - System.out.println("Workload complete."); } } diff --git a/tests/test_languages/test_java/test_java_tracer_e2e.py b/tests/test_languages/test_java/test_java_tracer_e2e.py index 157f23eb6..f16f19aa2 100644 --- a/tests/test_languages/test_java/test_java_tracer_e2e.py +++ b/tests/test_languages/test_java/test_java_tracer_e2e.py @@ -81,14 +81,11 @@ def test_agent_captures_invocations(self, compiled_workload: Path, trace_db: Pat conn = sqlite3.connect(str(trace_db)) try: rows = conn.execute("SELECT function, classname, descriptor, length(args) FROM function_calls").fetchall() - assert len(rows) >= 5, f"Expected at least 5 captured invocations, got {len(rows)}" + assert len(rows) >= 2, f"Expected at least 2 captured invocations, got {len(rows)}" # Check that specific methods were captured functions = 
{row[0] for row in rows} - assert "computeSum" in functions assert "repeatString" in functions - assert "filterEvens" in functions - assert "instanceMethod" in functions # Verify all rows have non-empty args blobs for row in rows: @@ -97,7 +94,7 @@ def test_agent_captures_invocations(self, compiled_workload: Path, trace_db: Pat # Verify metadata metadata = dict(conn.execute("SELECT key, value FROM metadata").fetchall()) assert "totalCaptures" in metadata - assert int(metadata["totalCaptures"]) >= 5 + assert int(metadata["totalCaptures"]) >= 2 finally: conn.close() @@ -136,11 +133,11 @@ def test_max_function_count_limit(self, compiled_workload: Path, trace_db: Path) conn = sqlite3.connect(str(trace_db)) try: - # computeSum is called 4 times (2 direct + 2 from instanceMethod) - compute_count = conn.execute( - "SELECT COUNT(*) FROM function_calls WHERE function = 'computeSum'" + # repeatString is called 1000+ times; with maxFunctionCount=2, at most 2 should be captured + repeat_count = conn.execute( + "SELECT COUNT(*) FROM function_calls WHERE function = 'repeatString'" ).fetchone()[0] - assert compute_count <= 2, f"Expected at most 2 computeSum captures, got {compute_count}" + assert repeat_count <= 2, f"Expected at most 2 repeatString captures, got {repeat_count}" finally: conn.close() @@ -198,7 +195,6 @@ def test_generates_test_files(self, compiled_workload: Path, trace_db: Path, tmp assert "package codeflash.replay;" in content assert "import org.junit.jupiter.api.Test;" in content assert "ReplayHelper" in content - assert "replay_computeSum_0" in content assert "replay_repeatString_0" in content def test_metadata_parsing(self, compiled_workload: Path, trace_db: Path, tmp_path: Path) -> None: @@ -243,7 +239,7 @@ def test_metadata_parsing(self, compiled_workload: Path, trace_db: Path, tmp_pat assert "functions" in metadata assert "trace_file" in metadata assert "classname" in metadata - assert "computeSum" in metadata["functions"] + assert "repeatString" in 
metadata["functions"] assert metadata["classname"] == "com.example.Workload" assert metadata["trace_file"] == trace_db.as_posix() @@ -267,7 +263,7 @@ def test_two_stage_trace(self, compiled_workload: Path, tmp_path: Path) -> None: conn = sqlite3.connect(str(trace_db)) try: count = conn.execute("SELECT COUNT(*) FROM function_calls").fetchone()[0] - assert count >= 5, f"Expected at least 5 captured invocations, got {count}" + assert count >= 2, f"Expected at least 2 captured invocations, got {count}" finally: conn.close() @@ -295,8 +291,7 @@ def test_full_trace_and_replay_generation(self, compiled_workload: Path, tmp_pat workload_files = [f for f in test_files if "Workload" in f.name and "ConstructorAccess" not in f.name] assert len(workload_files) == 1 content = workload_files[0].read_text(encoding="utf-8") - assert "replay_computeSum" in content - assert "replay_instanceMethod" in content + assert "replay_repeatString" in content def test_package_detection(self) -> None: """Test that package detection finds Java packages from source files.""" diff --git a/tests/test_languages/test_java/test_java_tracer_integration.py b/tests/test_languages/test_java/test_java_tracer_integration.py index f6ffefdf2..6927faba4 100644 --- a/tests/test_languages/test_java/test_java_tracer_integration.py +++ b/tests/test_languages/test_java/test_java_tracer_integration.py @@ -87,7 +87,6 @@ def test_discover_functions_from_replay_tests(self, traced_workload: tuple) -> N assert func.language == "java", f"Expected language='java', got '{func.language}'" assert func.file_path == file_path - assert "computeSum" in all_func_names assert "repeatString" in all_func_names def test_discover_tests_for_replay_tests(self, traced_workload: tuple) -> None: @@ -111,7 +110,6 @@ def test_discover_tests_for_replay_tests(self, traced_workload: tuple) -> None: func_name = qualified_name.split(".")[-1] if "." 
in qualified_name else qualified_name matched_func_names.add(func_name) - assert "computeSum" in matched_func_names, f"computeSum not found in: {result.keys()}" assert "repeatString" in matched_func_names, f"repeatString not found in: {result.keys()}" # Each function should have at least one test @@ -222,8 +220,8 @@ def test_full_pipeline(self, compiled_workload: Path, tmp_path: Path) -> None: assert len(function_to_tests) > 0, "No function-to-test mappings" # Verify function_to_tests has entries for our traced functions - has_compute_sum = any("computeSum" in key for key in function_to_tests) - assert has_compute_sum, f"computeSum not in function_to_tests keys: {list(function_to_tests.keys())}" + has_repeat_string = any("repeatString" in key for key in function_to_tests) + assert has_repeat_string, f"repeatString not in function_to_tests keys: {list(function_to_tests.keys())}" # Step 4: Rank functions (like optimizer.rank_all_functions_globally) if jfr_file.exists(): @@ -280,7 +278,7 @@ def test_instrument_and_compile_replay_tests(self, compiled_workload: Path, tmp_ source_code = WORKLOAD_SOURCE.read_text(encoding="utf-8") source_functions = discover_functions_from_source(source_code, file_path=WORKLOAD_SOURCE) # Pick the first function with a return type for instrumentation - target_func = next(f for f in source_functions if f.function_name == "computeSum") + target_func = next(f for f in source_functions if f.function_name == "repeatString") replay_test_file = replay_test_paths[0] test_source = replay_test_file.read_text(encoding="utf-8")