Skip to content

Commit ed36a75

Browse files
committed
Fix a race condition when tracing Kotlin coroutines.
* When using a `ThreadContextElement` to track suspensions and resumes, its important that we make `updateThreadContext` and `restoreThreadContext` safe, when calls to them are concurrent and interleaved. * For e.g. this is something that can happen when creating a child coroutine that in turn uses a multi threaded dispatcher. ``` coroutine.updateThreadContext() // Thread #1 ... coroutine body ... // suspension + immediate dispatch happen here coroutine.updateThreadContext() // Thread #2, coroutine is already resumed // ... coroutine body after suspension point on Thread #2 ... coroutine.restoreThreadContext() // Thread #1, is invoked late because Thread #1 is slow coroutine.restoreThreadContext() // Thread #2, may happen in parallel with the previous restore ``` * Previously when `updateThreadContext(...)` would get called on `Thread #2`, we won't emit an additional `begin` slice, for the new track because we are in `STATE_BEGIN` on the old thread. * When ending a slice, we could end on `Thread #2` and not on `Thread #1` which ends up corrupting traces. Test: This is something very difficult to to test, but I added additional integration tests, that involve jumping threads. Example trace: https://ui.perfetto.dev/#!/?s=575b99d9a235654f9de85aadaaa0ce63f31e4770 Change-Id: I18a49c3f6b65fa41ab5f9ad855377ac779aa4326
1 parent faeec94 commit ed36a75

4 files changed

Lines changed: 77 additions & 4 deletions

File tree

tracing/tracing-wire/src/desktopTest/kotlin/androidx/tracing/wire/TracingDemoTest.kt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ class TracingDemoTest {
120120
}
121121
}
122122

123+
@Test
124+
internal fun testContextSwitching(): Unit = runBlocking {
125+
var index = 0
126+
driver.use {
127+
tracer.traceCoroutine(category = "category", name = "switch") {
128+
repeat(100) {
129+
index += 1
130+
val dispatcher = if (index % 2 == 0) Dispatchers.Default else Dispatchers.IO
131+
withContext(dispatcher) { delay(100) }
132+
}
133+
}
134+
}
135+
}
136+
123137
@Test
124138
// The amount of time spent sleeping does not affect the outcome of the test.
125139
@Suppress("BanThreadSleep")

tracing/tracing/src/commonMain/kotlin/androidx/tracing/ContextElements.kt

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package androidx.tracing
1818

19+
import androidx.collection.MutableLongIntMap
20+
import androidx.collection.mutableLongIntMapOf
1921
import kotlin.coroutines.AbstractCoroutineContextElement
2022
import kotlin.coroutines.CoroutineContext
2123

@@ -40,14 +42,30 @@ internal constructor(
4042
*/
4143
public open val flowIds: List<Long>,
4244
) : AbstractCoroutineContextElement(key = KEY), PropagationToken, AutoCloseable {
43-
// Always starts in a `begin` state.
44-
@JvmField internal var started: Int = STATE_BEGIN
45+
46+
// Calls to update and restore will happen concurrently on multiple threads, if we happen
47+
// to jump threads between suspends and resumes. Therefore, it's important for us to track
48+
// begin and end per thread id. For more information refer to the `Reentrancy and thread-safety`
49+
// section in the documentation for `CopyableThreadContextElement`.
50+
@JvmField internal val started: MutableLongIntMap = mutableLongIntMapOf()
51+
52+
init {
53+
synchronized(this) {
54+
// Always starts in a `begin` state.
55+
started[currentThreadId()] = STATE_BEGIN
56+
}
57+
}
4558

4659
@Suppress("NOTHING_TO_INLINE")
4760
internal inline fun synchronizedCompareAndSet(expected: Int, newValue: Int): Boolean {
4861
return synchronized(this) {
49-
if (started == expected) {
50-
started = newValue
62+
// Treat the absence of an entry here as `STATE_END` given we have not emitted a
63+
// trace packet yet on this Thread. This is true for every child coroutine, but
64+
// **not** for the coroutine that kicked things off (handled by the init block).
65+
val id = currentThreadId()
66+
val current = started.getOrDefault(key = id, defaultValue = STATE_END)
67+
if (current == expected) {
68+
started[id] = newValue
5169
true
5270
} else {
5371
false
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright 2026 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package androidx.tracing
18+
19+
internal expect inline fun currentThreadId(): Long
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2026 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package androidx.tracing
18+
19+
@Suppress("NOTHING_TO_INLINE", "DEPRECATION")
20+
internal actual inline fun currentThreadId(): Long {
21+
return Thread.currentThread().id
22+
}

0 commit comments

Comments
 (0)