From 40ce89531f6e8bf245d8ede00e217a389928567c Mon Sep 17 00:00:00 2001 From: Arnab Nandy Date: Sun, 21 Jun 2026 01:15:53 +0530 Subject: [PATCH] Fix asymmetric Observation scope propagation in Kotlin Coroutines Introduce a ThreadLocal nesting check inside PropagationContextElement to prevent duplicate/asymmetric thread-local scope updates during nested undispatched coroutine context changes. Closes #36929 Signed-off-by: Arnab Nandy --- .../core/PropagationContextElement.java | 36 ++++++++++++++++++- .../core/PropagationContextElementTests.kt | 19 ++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/spring-core/src/main/java/org/springframework/core/PropagationContextElement.java b/spring-core/src/main/java/org/springframework/core/PropagationContextElement.java index a490179d82b4..522b93863e94 100644 --- a/spring-core/src/main/java/org/springframework/core/PropagationContextElement.java +++ b/spring-core/src/main/java/org/springframework/core/PropagationContextElement.java @@ -70,6 +70,8 @@ public final class PropagationContextElement extends AbstractCoroutineContextEle private static final boolean coroutinesReactorPresent = ClassUtils.isPresent("kotlinx.coroutines.reactor.ReactorContext", PropagationContextElement.class.getClassLoader()); + private static final ThreadLocal nestingDepth = ThreadLocal.withInitial(() -> 0); + private final ContextSnapshot threadLocalContextSnapshot; @@ -85,6 +87,11 @@ public void restoreThreadContext(CoroutineContext context, ContextSnapshot.Scope @Override public ContextSnapshot.Scope updateThreadContext(CoroutineContext context) { + int depth = nestingDepth.get(); + nestingDepth.set(depth + 1); + if (depth > 0) { + return new NoOpScope(); + } ContextSnapshot contextSnapshot; if (coroutinesReactorPresent) { contextSnapshot = ReactorDelegate.captureFrom(context); @@ -95,7 +102,7 @@ public ContextSnapshot.Scope updateThreadContext(CoroutineContext context) { else { contextSnapshot = this.threadLocalContextSnapshot; } - return contextSnapshot.setThreadLocals(); + return new WrapperScope(contextSnapshot.setThreadLocals()); } public static final class Key implements CoroutineContext.Key { @@ -115,4 +122,31 @@ private static final class ReactorDelegate { } } } + + private static final class WrapperScope implements ContextSnapshot.Scope { + + private final ContextSnapshot.Scope delegate; + + WrapperScope(ContextSnapshot.Scope delegate) { + this.delegate = delegate; + } + + @Override + public void close() { + try { + this.delegate.close(); + } + finally { + nestingDepth.set(0); + } + } + } + + private static final class NoOpScope implements ContextSnapshot.Scope { + + @Override + public void close() { + // No-op + } + } } diff --git a/spring-core/src/test/kotlin/org/springframework/core/PropagationContextElementTests.kt b/spring-core/src/test/kotlin/org/springframework/core/PropagationContextElementTests.kt index 544379966681..9d16782dc168 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/PropagationContextElementTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/PropagationContextElementTests.kt @@ -68,6 +68,25 @@ class PropagationContextElementTests { Hooks.disableAutomaticContextPropagation() } + @Test + fun nestedInvocations() { + val observation = Observation.createNotStarted("coroutine", observationRegistry) + observation.observe { + val element = PropagationContextElement() + val scope1 = element.updateThreadContext(kotlin.coroutines.EmptyCoroutineContext) + assertThat(observationRegistry.currentObservation).isEqualTo(observation) + + val scope2 = element.updateThreadContext(kotlin.coroutines.EmptyCoroutineContext) + assertThat(observationRegistry.currentObservation).isEqualTo(observation) + + element.restoreThreadContext(kotlin.coroutines.EmptyCoroutineContext, scope1) + assertThat(observationRegistry.currentObservation).isEqualTo(observation) + + element.restoreThreadContext(kotlin.coroutines.EmptyCoroutineContext, scope2) + assertThat(observationRegistry.currentObservation).isEqualTo(observation) + } + } + suspend fun suspendingFunction(value: String): String? { delay(1) val currentObservation = observationRegistry.currentObservation