diff --git a/context/context.go b/context/context.go index 7cfb1e9..ade8897 100644 --- a/context/context.go +++ b/context/context.go @@ -27,3 +27,12 @@ func ActorFromContext(ctx context.Context) *Actor { return a } + +// ActorIDFromContext safely extracts the actor ID from the context. +// Returns an empty string if no actor is found. +func ActorIDFromContext(ctx context.Context) string { + if actor := ActorFromContext(ctx); actor != nil { + return actor.ID + } + return "" +} diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 0000000..22e4d56 --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,88 @@ +package context + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/recorder" +) + +func TestAsActor(t *testing.T) { + t.Parallel() + + // Given: a metadata map + metadata := recorder.Metadata{"key": "value"} + + // When: storing an actor in the context + ctx := AsActor(context.Background(), "actor-123", metadata) + + // Then: the actor should be retrievable with correct ID and metadata + actor := ActorFromContext(ctx) + require.NotNil(t, actor) + assert.Equal(t, "actor-123", actor.ID) + assert.Equal(t, "value", actor.Metadata["key"]) +} + +func TestActorFromContext(t *testing.T) { + t.Parallel() + + t.Run("returns actor when present", func(t *testing.T) { + t.Parallel() + + // Given: a context with an actor + ctx := AsActor(context.Background(), "test-id", recorder.Metadata{}) + + // When: extracting the actor from context + actor := ActorFromContext(ctx) + + // Then: the actor should be returned with correct ID + require.NotNil(t, actor) + assert.Equal(t, "test-id", actor.ID) + }) + + t.Run("returns nil when no actor", func(t *testing.T) { + t.Parallel() + + // Given: a context without an actor + ctx := context.Background() + + // When: extracting the actor from context + actor := ActorFromContext(ctx) + + // Then: nil should be returned + assert.Nil(t, actor) + }) +} + +func TestActorIDFromContext(t *testing.T) { + t.Parallel() + + t.Run("returns actor ID when present", func(t *testing.T) { + t.Parallel() + + // Given: a context with an actor + ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{}) + + // When: extracting the actor ID from context + got := ActorIDFromContext(ctx) + + // Then: the actor ID should be returned + assert.Equal(t, "test-actor-id", got) + }) + + t.Run("returns empty string when no actor", func(t *testing.T) { + t.Parallel() + + // Given: a context without an actor + ctx := context.Background() + + // When: extracting the actor ID from context + got := ActorIDFromContext(ctx) + + // Then: an empty string should be returned + assert.Empty(t, got) + }) +} diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 2610b3e..02e7f7b 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -61,7 +61,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), - attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID), + attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), diff --git a/intercept/messages/base.go b/intercept/messages/base.go index ba37067..6db64fa 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -75,7 +75,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), - attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID), + attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), attribute.String(tracing.Provider, aibconfig.ProviderAnthropic), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 78c09d0..f43dd79 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -78,7 +78,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, i.id.String()), - attribute.String(tracing.InitiatorID, aibcontext.ActorFromContext(r.Context()).ID), + attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, i.Model()), attribute.Bool(tracing.Streaming, streaming),