Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile(
});
}

/**
* This method is equivalent to a while loop, where the condition is checked before each iteration.
* If the condition returns {@code false} on the first check, the body is never executed.
*
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
* @return the composition of this and the looping branch
* @see AsyncCallbackLoop
*/
default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) {
return thenRun(finalCallback -> {
LoopState loopState = new LoopState();
new AsyncCallbackLoop(loopState, iterationCallback -> {

if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) {
return;
}
loopBodyRunnable.finish((result, t) -> {
if (t != null) {
iterationCallback.completeExceptionally(t);
return;
}
iterationCallback.complete(iterationCallback);
});

}).run(finalCallback);
});
}

/**
* This method is equivalent to a do-while loop, where the loop body is executed first and
* then the condition is checked to determine whether the loop should continue.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal.async;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.lang.Nullable;

/**
* A trampoline that converts recursive callback invocations into an iterative loop,
* preventing stack overflow in async loops.
*
* <p>When async loop iterations complete synchronously on the same thread, callback
* recursion occurs: each iteration's {@code callback.onResult()} immediately triggers
* the next iteration, causing unbounded stack growth. For example, a 1000-iteration
* loop would create > 1000 stack frames and cause {@code StackOverflowError}.</p>
*
* <p>The trampoline intercepts this recursion: instead of executing the next iteration
* immediately (which would deepen the stack), it enqueues the continuation and returns, allowing
* the stack to unwind. A flat loop at the top then processes enqueued continuation iteratively,
* maintaining constant stack depth regardless of iteration count.</p>
*
* <p>Since async chains are sequential, at most one task is pending at any time.
* The trampoline uses a single slot rather than a queue.</p>
*
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
* Subsequent (re-entrant) calls on the same thread enqueue their continuation and return immediately.</p>
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
@NotThreadSafe
public final class AsyncTrampoline {

private static final ThreadLocal<ContinuationHolder> TRAMPOLINE = new ThreadLocal<>();

private AsyncTrampoline() {}

/**
* Execute continuation through the trampoline. If no trampoline is active, become the owner
* and drain all enqueued continuations. If a trampoline is already active, enqueue and return.
*/
public static void run(final Runnable continuation) {
ContinuationHolder continuationHolder = TRAMPOLINE.get();
if (continuationHolder != null) {
continuationHolder.enqueue(continuation);
} else {
continuationHolder = new ContinuationHolder();
TRAMPOLINE.set(continuationHolder);
try {
continuation.run();
while (continuationHolder.continuation != null) {
Runnable continuationToRun = continuationHolder.continuation;
continuationHolder.continuation = null;
continuationToRun.run();
}
} finally {
TRAMPOLINE.remove();
}
}
}

/**
* A single-slot container for continuation.
* At most one continuation is pending at any time in a sequential async chain.
*/
@NotThreadSafe
private static final class ContinuationHolder {
@Nullable
private Runnable continuation;

void enqueue(final Runnable continuation) {
if (this.continuation != null) {
throw new AssertionError("Trampoline slot already occupied");
}
this.continuation = continuation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -62,9 +63,11 @@ public void run(final SingleResultCallback<Void> callback) {
@NotThreadSafe
private class LoopingCallback implements SingleResultCallback<Void> {
private final SingleResultCallback<Void> wrapped;
private final Runnable nextIteration;

LoopingCallback(final SingleResultCallback<Void> callback) {
wrapped = callback;
nextIteration = () -> body.run(this);
Copy link
Member Author

@vbabanin vbabanin Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nextIteration is reused to avoid creation of extra objects via LambdaMetafactory as we have the capturing lambda.

bounce.work = task is a write to a heap object's field, which can be considered an automatic escape in the JIT's analysis. Even if the Bounce object is short-lived, the JIT sees "object written to another object's field" and should give up.

The AsyncCallbackLoop JMH GC profiling (OpenJDK 17.0.10 LTS, 64-bit Server VM, mixed mode with compressed oops).

Metric Runnable (this) Lambda
Alloc rate 0.039 MB/sec 96.924 MB/sec
Alloc per op 64 B/op 160,048 B/op
GC count ~ 0 10
GC time 0 ms 9 ms

For Lambda case:
Per iteration: 1 lambda * 16 bytes = 16 B

  • Per op (10,000 iterations): 10,000 * 16 = 160,000 B
  • Plus one-time objects ~ 48 B

}

@Override
Expand All @@ -80,7 +83,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
return;
}
if (continueLooping) {
body.run(this);
AsyncTrampoline.run(nextIteration);
} else {
wrapped.onResult(result, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertEquals;

abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
Expand Down Expand Up @@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() {
});
}

@Test
void testWhile() {
// last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
assertBehavesSameVariations(10,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(callback);
});
}

@Test
void testWhileWithThenRun() {
// while: last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
// trailing sync: 1(exception) + 1(success) = 2
// 6(while exception) + 4(while success) * 2(trailing sync) = 14
assertBehavesSameVariations(14,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
sync(counter + 1);
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRun(c -> {
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(c);
}).thenRun(c -> {
async(counter.get() + 1, c);
}).finish(callback);
});
}

@Test
void testNestedWhileLoops() {
// inner while: 4 success + 6 exception = 10
// last inner iteration: 3 < 3 = 1
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232
assertBehavesSameVariations(232,
() -> {
int outer = 0;
while (outer < 3 && plainTest(outer)) {
int inner = 0;
while (inner < 3 && plainTest(inner)) {
sync(outer + inner);
inner++;
}
outer++;
}
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(
() -> inner.get() < 3 && plainTest(inner.get()),
c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}).finish(callback);
});
}

@Test
void testWhileLoopStackConstant() {
int depthWith100 = maxStackDepthForIterations(100);
int depthWith10000 = maxStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)");
}

private int maxStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}).finish((v, t) -> {});

assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testRetryLoop() {
assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1,
Expand Down Expand Up @@ -768,6 +883,65 @@ void testDoWhileLoop() {
});
}

@Test
void testNestedDoWhileLoops() {
// inner do-while: 3 success + 5 exception = 8
// last outer iteration: 3 < 3 = 1
// 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration)) = 35
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration)) = 116
assertBehavesSameVariations(116,
() -> {
int outer = 0;
do {
int inner = 0;
do {
sync(outer + inner);
inner++;
} while (inner < 3 && plainTest(inner));
outer++;
} while (outer < 3 && plainTest(outer));
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}, () -> inner.get() < 3 && plainTest(inner.get())
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback);
});
}

@Test
void testDoWhileLoopStackConstant() {
int depthWith100 = maxDoWhileStackDepthForIterations(100);
int depthWith10000 = maxDoWhileStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000,
"Stack depth should be constant regardless of iteration count");
}

private int maxDoWhileStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}, () -> counter.get() < iterations).finish((v, t) -> {});
assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testFinallyWithPlainInsideTry() {
// (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import static java.lang.String.format;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -272,14 +273,16 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
}

assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
"both or neither should have produced an exception");
format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s",
expectedException,
actualException));
if (expectedException != null) {
assertEquals(expectedException.getMessage(), actualException.get().getMessage());
assertEquals(expectedException.getClass(), actualException.get().getClass());
}
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());

listener.clear();
}
Expand Down