Skip to content

Commit 0d2c37c

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Introducing Tracing.withContext()
Tracing.withContext() will improve how RxJava + Tracing works. Here's an example of how it will be used: ``` this.pluginManager .onUserMessageCallback(initialContext, newMessage) .compose(Tracing.<Content>withContext(capturedContext)) ``` The `.compose()` is a standin for calling `capturedContext.makeCurrent()` in the event handler. PiperOrigin-RevId: 883160268
1 parent a47b651 commit 0d2c37c

2 files changed

Lines changed: 267 additions & 3 deletions

File tree

core/src/main/java/com/google/adk/telemetry/Tracing.java

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,20 @@
3737
import io.opentelemetry.context.Context;
3838
import io.opentelemetry.context.Scope;
3939
import io.reactivex.rxjava3.core.Completable;
40+
import io.reactivex.rxjava3.core.CompletableObserver;
4041
import io.reactivex.rxjava3.core.CompletableSource;
4142
import io.reactivex.rxjava3.core.CompletableTransformer;
4243
import io.reactivex.rxjava3.core.Flowable;
4344
import io.reactivex.rxjava3.core.FlowableTransformer;
4445
import io.reactivex.rxjava3.core.Maybe;
46+
import io.reactivex.rxjava3.core.MaybeObserver;
4547
import io.reactivex.rxjava3.core.MaybeSource;
4648
import io.reactivex.rxjava3.core.MaybeTransformer;
4749
import io.reactivex.rxjava3.core.Single;
50+
import io.reactivex.rxjava3.core.SingleObserver;
4851
import io.reactivex.rxjava3.core.SingleSource;
4952
import io.reactivex.rxjava3.core.SingleTransformer;
53+
import io.reactivex.rxjava3.disposables.Disposable;
5054
import java.util.ArrayList;
5155
import java.util.HashMap;
5256
import java.util.List;
@@ -58,6 +62,8 @@
5862
import java.util.function.Consumer;
5963
import java.util.function.Supplier;
6064
import org.reactivestreams.Publisher;
65+
import org.reactivestreams.Subscriber;
66+
import org.reactivestreams.Subscription;
6167
import org.slf4j.Logger;
6268
import org.slf4j.LoggerFactory;
6369

@@ -550,4 +556,185 @@ public CompletableSource apply(Completable upstream) {
550556
});
551557
}
552558
}
559+
560+
/**
561+
* Returns a transformer that re-activates a given context for the duration of the stream's
562+
* subscription.
563+
*
564+
* @param context The context to re-activate.
565+
* @param <T> The type of the stream.
566+
* @return A transformer that re-activates the context.
567+
*/
568+
public static <T> ContextTransformer<T> withContext(Context context) {
569+
return new ContextTransformer<>(context);
570+
}
571+
572+
/**
573+
* A transformer that re-activates a given context for the duration of the stream's subscription.
574+
*
575+
* @param <T> The type of the stream.
576+
*/
577+
public static final class ContextTransformer<T>
578+
implements FlowableTransformer<T, T>,
579+
SingleTransformer<T, T>,
580+
MaybeTransformer<T, T>,
581+
CompletableTransformer {
582+
private final Context context;
583+
584+
private ContextTransformer(Context context) {
585+
this.context = context;
586+
}
587+
588+
@Override
589+
public Publisher<T> apply(Flowable<T> upstream) {
590+
return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber));
591+
}
592+
593+
@Override
594+
public SingleSource<T> apply(Single<T> upstream) {
595+
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
596+
}
597+
598+
@Override
599+
public MaybeSource<T> apply(Maybe<T> upstream) {
600+
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
601+
}
602+
603+
@Override
604+
public CompletableSource apply(Completable upstream) {
605+
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
606+
}
607+
}
608+
609+
/**
610+
* An observer that wraps another observer and ensures that the OpenTelemetry context is active
611+
* during all callback methods.
612+
*
613+
* <p>This implementation only wraps the data-flow callbacks (`onNext`, `onSuccess`, etc.). The
614+
* `Subscription.request/cancel` and `Disposable.dispose` calls are not wrapped in the context. If
615+
* the upstream logic depends on the context during these signals, they might lose trace
616+
* information. Given this is a manual `withContext` utility, this might be an acceptable
617+
* trade-off for simplicity/performance, but worth keeping in mind.
618+
*
619+
* @param <T> The type of the items emitted by the stream.
620+
*/
621+
private static final class TracingObserver<T>
622+
implements Subscriber<T>, SingleObserver<T>, MaybeObserver<T>, CompletableObserver {
623+
private final Context context;
624+
private final Subscriber<? super T> subscriber;
625+
private final SingleObserver<? super T> singleObserver;
626+
private final MaybeObserver<? super T> maybeObserver;
627+
private final CompletableObserver completableObserver;
628+
629+
private TracingObserver(
630+
Context context,
631+
Subscriber<? super T> subscriber,
632+
SingleObserver<? super T> singleObserver,
633+
MaybeObserver<? super T> maybeObserver,
634+
CompletableObserver completableObserver) {
635+
this.context = context;
636+
this.subscriber = subscriber;
637+
this.singleObserver = singleObserver;
638+
this.maybeObserver = maybeObserver;
639+
this.completableObserver = completableObserver;
640+
}
641+
642+
static <T> TracingObserver<T> wrap(Context context, Subscriber<? super T> subscriber) {
643+
return new TracingObserver<>(context, subscriber, null, null, null);
644+
}
645+
646+
static <T> TracingObserver<T> wrap(Context context, SingleObserver<? super T> observer) {
647+
return new TracingObserver<>(context, null, observer, null, null);
648+
}
649+
650+
static <T> TracingObserver<T> wrap(Context context, MaybeObserver<? super T> observer) {
651+
return new TracingObserver<>(context, null, null, observer, null);
652+
}
653+
654+
static <T> TracingObserver<T> wrap(Context context, CompletableObserver observer) {
655+
return new TracingObserver<>(context, null, null, null, observer);
656+
}
657+
658+
private void runInContext(Runnable action) {
659+
try (Scope scope = context.makeCurrent()) {
660+
action.run();
661+
}
662+
}
663+
664+
@Override
665+
public void onSubscribe(Subscription s) {
666+
runInContext(
667+
() -> {
668+
if (subscriber != null) {
669+
subscriber.onSubscribe(s);
670+
}
671+
});
672+
}
673+
674+
@Override
675+
public void onSubscribe(Disposable d) {
676+
runInContext(
677+
() -> {
678+
if (singleObserver != null) {
679+
singleObserver.onSubscribe(d);
680+
} else if (maybeObserver != null) {
681+
maybeObserver.onSubscribe(d);
682+
} else if (completableObserver != null) {
683+
completableObserver.onSubscribe(d);
684+
}
685+
});
686+
}
687+
688+
@Override
689+
public void onNext(T t) {
690+
runInContext(
691+
() -> {
692+
if (subscriber != null) {
693+
subscriber.onNext(t);
694+
}
695+
});
696+
}
697+
698+
@Override
699+
public void onSuccess(T t) {
700+
runInContext(
701+
() -> {
702+
if (singleObserver != null) {
703+
singleObserver.onSuccess(t);
704+
} else if (maybeObserver != null) {
705+
maybeObserver.onSuccess(t);
706+
}
707+
});
708+
}
709+
710+
@Override
711+
public void onError(Throwable t) {
712+
runInContext(
713+
() -> {
714+
if (subscriber != null) {
715+
subscriber.onError(t);
716+
} else if (singleObserver != null) {
717+
singleObserver.onError(t);
718+
} else if (maybeObserver != null) {
719+
maybeObserver.onError(t);
720+
} else if (completableObserver != null) {
721+
completableObserver.onError(t);
722+
}
723+
});
724+
}
725+
726+
@Override
727+
public void onComplete() {
728+
runInContext(
729+
() -> {
730+
if (subscriber != null) {
731+
subscriber.onComplete();
732+
} else if (maybeObserver != null) {
733+
maybeObserver.onComplete();
734+
} else if (completableObserver != null) {
735+
completableObserver.onComplete();
736+
}
737+
});
738+
}
739+
}
553740
}

core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.google.adk.runner.Runner;
3232
import com.google.adk.sessions.InMemorySessionService;
3333
import com.google.adk.sessions.Session;
34+
import com.google.adk.sessions.SessionKey;
3435
import com.google.common.collect.ImmutableList;
3536
import com.google.common.collect.ImmutableMap;
3637
import com.google.genai.types.Content;
@@ -44,12 +45,17 @@
4445
import io.opentelemetry.api.trace.Span;
4546
import io.opentelemetry.api.trace.Tracer;
4647
import io.opentelemetry.context.Context;
48+
import io.opentelemetry.context.ContextKey;
4749
import io.opentelemetry.context.Scope;
4850
import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule;
4951
import io.opentelemetry.sdk.trace.data.SpanData;
52+
import io.reactivex.rxjava3.core.Completable;
5053
import io.reactivex.rxjava3.core.Flowable;
54+
import io.reactivex.rxjava3.core.Maybe;
55+
import io.reactivex.rxjava3.core.Single;
5156
import io.reactivex.rxjava3.schedulers.Schedulers;
5257
import java.util.List;
58+
import java.util.Map;
5359
import java.util.Optional;
5460
import org.junit.After;
5561
import org.junit.Before;
@@ -380,6 +386,70 @@ public void testTraceFlowable() throws InterruptedException {
380386
assertTrue(flowableSpanData.hasEnded());
381387
}
382388

389+
@Test
390+
public void testWithContextFlowable() throws InterruptedException {
391+
ContextKey<String> testKey = ContextKey.named("test-key");
392+
Context testContext = Context.root().with(testKey, "test-value");
393+
394+
Flowable<Integer> flowable =
395+
Flowable.just(1, 2, 3)
396+
.compose(Tracing.withContext(testContext))
397+
.subscribeOn(Schedulers.computation())
398+
.doOnNext(
399+
i -> {
400+
assertEquals("test-value", Context.current().get(testKey));
401+
});
402+
flowable.test().await().assertComplete();
403+
}
404+
405+
@Test
406+
public void testWithContextSingle() throws InterruptedException {
407+
ContextKey<String> testKey = ContextKey.named("test-key");
408+
Context testContext = Context.root().with(testKey, "test-value");
409+
410+
Single<Integer> single =
411+
Single.just(1)
412+
.compose(Tracing.withContext(testContext))
413+
.subscribeOn(Schedulers.computation())
414+
.doOnSuccess(
415+
i -> {
416+
assertEquals("test-value", Context.current().get(testKey));
417+
});
418+
single.test().await().assertComplete();
419+
}
420+
421+
@Test
422+
public void testWithContextMaybe() throws InterruptedException {
423+
ContextKey<String> testKey = ContextKey.named("test-key");
424+
Context testContext = Context.root().with(testKey, "test-value");
425+
426+
Maybe<Integer> maybe =
427+
Maybe.just(1)
428+
.compose(Tracing.withContext(testContext))
429+
.subscribeOn(Schedulers.computation())
430+
.doOnSuccess(
431+
i -> {
432+
assertEquals("test-value", Context.current().get(testKey));
433+
});
434+
maybe.test().await().assertComplete();
435+
}
436+
437+
@Test
438+
public void testWithContextCompletable() throws InterruptedException {
439+
ContextKey<String> testKey = ContextKey.named("test-key");
440+
Context testContext = Context.root().with(testKey, "test-value");
441+
442+
Completable completable =
443+
Completable.complete()
444+
.compose(Tracing.withContext(testContext))
445+
.subscribeOn(Schedulers.computation())
446+
.doOnComplete(
447+
() -> {
448+
assertEquals("test-value", Context.current().get(testKey));
449+
});
450+
completable.test().await().assertComplete();
451+
}
452+
383453
@Test
384454
public void testTraceTransformer() throws InterruptedException {
385455
Span parentSpan = tracer.spanBuilder("parent").startSpan();
@@ -595,7 +665,7 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException {
595665
Session session =
596666
runner
597667
.sessionService()
598-
.createSession("test-app", "test-user", null, "test-session")
668+
.createSession(new SessionKey("test-app", "test-user", "test-session"))
599669
.blockingGet();
600670
Content newMessage = Content.fromParts(Part.fromText("hi"));
601671
RunConfig runConfig = RunConfig.builder().build();
@@ -623,13 +693,20 @@ public void runnerRunLive_propagatesContext() throws InterruptedException {
623693
Span parentSpan = tracer.spanBuilder("parent").startSpan();
624694
try (Scope s = parentSpan.makeCurrent()) {
625695
Session session =
626-
Session.builder("test-session").userId("test-user").appName("test-app").build();
696+
runner
697+
.sessionService()
698+
.createSession("test-app", "test-user", (Map<String, Object>) null, "test-session")
699+
.blockingGet();
627700
Content newMessage = Content.fromParts(Part.fromText("hi"));
628701
RunConfig runConfig = RunConfig.builder().build();
629702
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
630703
liveRequestQueue.content(newMessage);
631704
liveRequestQueue.close();
632-
runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete();
705+
runner
706+
.runLive(session.userId(), session.id(), liveRequestQueue, runConfig)
707+
.test()
708+
.await()
709+
.assertComplete();
633710
} finally {
634711
parentSpan.end();
635712
}

0 commit comments

Comments
 (0)