Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion otelriver/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package otelriver
import (
"cmp"
"context"
"encoding/json"
"errors"
"slices"
"time"
Expand All @@ -11,6 +12,7 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"

"github.com/riverqueue/river"
Expand Down Expand Up @@ -45,6 +47,13 @@ type MiddlewareConfig struct {
// metric names, with attributes differentiating them.
EnableSemanticMetrics bool

// EnableTracePropagation injects W3C trace context (traceparent/tracestate)
// into job metadata on insert and extracts it on work, adding a span link
// from the work span back to the span that enqueued the job. A link is used
// rather than a parent so the work span's timeline is independent of the
// insert span (the two may be separated by minutes or hours).
EnableTracePropagation bool

// EnableWorkSpanJobKindSuffix appends the job kind as a suffix to work spans
// so they look like `river.work/my_job` instead of `river.work`.
EnableWorkSpanJobKindSuffix bool
Expand Down Expand Up @@ -208,6 +217,12 @@ func (m *Middleware) InsertMany(ctx context.Context, manyParams []*rivertype.Job
}
}()

if m.config.EnableTracePropagation {
for i := range manyParams {
manyParams[i].Metadata = injectTraceContext(ctx, manyParams[i].Metadata)
}
}

insertRes, err = doInner(ctx)
panicked = false
return insertRes, err
Expand All @@ -219,8 +234,14 @@ func (m *Middleware) Work(ctx context.Context, job *rivertype.JobRow, doInner fu
spanName += "/" + job.Kind
}

var startOpts []trace.SpanStartOption
if m.config.EnableTracePropagation {
if sc := extractSpanContext(ctx, job.Metadata); sc.IsValid() {
startOpts = append(startOpts, trace.WithLinks(trace.Link{SpanContext: sc}))
}
}
ctx, span := m.tracer.Start(ctx, spanName,
trace.WithSpanKind(trace.SpanKindConsumer))
append(startOpts, trace.WithSpanKind(trace.SpanKindConsumer))...)
defer span.End()

attrs := []attribute.KeyValue{
Expand Down Expand Up @@ -336,6 +357,53 @@ func mustInt64Counter(meter metric.Meter, name string, options ...metric.Int64Co
return metric
}

// injectTraceContext injects the current span context from ctx into metadata
// JSON under the W3C "traceparent" (and, when present, "tracestate") key and
// returns the resulting JSON object.
//
// Empty metadata and a JSON `null` document are both treated as an empty
// object. Existing entries are carried over as raw JSON rather than being
// decoded into Go values, so numbers that do not fit a float64 are not rounded
// on the way through. Keys with the same name as an injected key are
// overwritten.
//
// If ctx carries no valid span context, or the metadata is not a JSON object,
// or encoding fails, the metadata is returned exactly as it was passed in.
func injectTraceContext(ctx context.Context, metadata []byte) []byte {
	carrier := make(propagation.MapCarrier)
	propagation.TraceContext{}.Inject(ctx, carrier)
	if len(carrier) == 0 {
		// Nothing to propagate; leave the caller's metadata untouched.
		return metadata
	}

	var meta map[string]json.RawMessage
	if len(metadata) > 0 {
		if err := json.Unmarshal(metadata, &meta); err != nil {
			return metadata
		}
	}
	// Unmarshalling `null` succeeds but leaves the map nil, and nothing is
	// decoded at all for empty input. Writing to a nil map panics, so make sure
	// there is a real map before inserting.
	if meta == nil {
		meta = make(map[string]json.RawMessage, len(carrier))
	}

	for k, v := range carrier {
		encoded, err := json.Marshal(v)
		if err != nil {
			return metadata
		}
		meta[k] = encoded
	}

	injected, err := json.Marshal(meta)
	if err != nil {
		return metadata
	}
	return injected
}

// extractSpanContext reads W3C trace context from metadata JSON and returns the
// remote SpanContext it encodes. Returns a zero SpanContext (IsValid() == false)
// if no valid traceparent is present or the metadata cannot be parsed.
//
// Only the span context found in metadata is ever returned. A span that is
// already active in ctx is never reported, so callers cannot end up linking a
// span to its own ambient parent.
func extractSpanContext(ctx context.Context, metadata []byte) trace.SpanContext {
	if len(metadata) == 0 {
		return trace.SpanContext{}
	}
	var meta map[string]any
	if err := json.Unmarshal(metadata, &meta); err != nil {
		return trace.SpanContext{}
	}

	// Copy over only the keys the propagator reads (traceparent/tracestate),
	// and only when they are JSON strings; other metadata is irrelevant here.
	propagator := propagation.TraceContext{}
	carrier := make(propagation.MapCarrier, 2)
	for _, key := range propagator.Fields() {
		if s, ok := meta[key].(string); ok {
			carrier[key] = s
		}
	}

	// Extract hands back its input context unchanged when the carrier holds no
	// valid traceparent. Mask any span already in ctx with an empty span
	// context first, so that case yields a zero SpanContext instead of the
	// ambient span's.
	base := trace.ContextWithSpanContext(ctx, trace.SpanContext{})
	return trace.SpanContextFromContext(propagator.Extract(base, carrier))
}

// Sets success status on the given span and within the set of attributes. The
// index of the status attribute is required ahead of time as a minor
// optimization.
Expand Down
75 changes: 75 additions & 0 deletions otelriver/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package otelriver

import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
Expand Down Expand Up @@ -304,6 +306,34 @@ func TestMiddleware(t *testing.T) {
getAttribute(t, spans[0].Attributes, "kinds").AsStringSlice())
})

t.Run("InsertManyInjectsTraceparent", func(t *testing.T) {
	t.Parallel()

	middleware, bundle := setupConfig(t, &MiddlewareConfig{EnableTracePropagation: true})

	// A single job with no metadata of its own, so anything found in Metadata
	// after the call was written by the middleware.
	insertParams := []*rivertype.JobInsertParams{{Kind: "no_op"}}

	_, err := middleware.InsertMany(ctx, insertParams,
		func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
			return []*rivertype.JobInsertResult{{Job: &rivertype.JobRow{ID: 1}}}, nil
		})
	require.NoError(t, err)

	// Metadata must now be a JSON object carrying a string traceparent.
	require.NotNil(t, insertParams[0].Metadata)
	var decoded map[string]any
	require.NoError(t, json.Unmarshal(insertParams[0].Metadata, &decoded))
	traceparent, isString := decoded["traceparent"].(string)
	require.True(t, isString, "expected traceparent key in job metadata")

	// Exactly one span is exported, and the traceparent must name its trace
	// and span IDs.
	exported := bundle.traceExporter.GetSpans()
	require.Len(t, exported, 1)
	insertSpanContext := exported[0].SpanContext
	require.Contains(t, traceparent, insertSpanContext.TraceID().String())
	require.Contains(t, traceparent, insertSpanContext.SpanID().String())
})

t.Run("InsertManyDurationUnitMS", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -750,6 +780,51 @@ func TestMiddleware(t *testing.T) {
}
})

t.Run("WorkExtractsTraceparent", func(t *testing.T) {
t.Parallel()

middleware, bundle := setupConfig(t, &MiddlewareConfig{EnableTracePropagation: true})

// Build a synthetic traceparent pointing to a remote parent span.
parentTraceID := "4bf92f3577b34da6a3ce929d0e0e4736"
parentSpanID := "00f067aa0ba902b7"
carrier := propagation.MapCarrier{
"traceparent": fmt.Sprintf("00-%s-%s-01", parentTraceID, parentSpanID),
}
metadata, err := json.Marshal(carrier)
require.NoError(t, err)

err = middleware.Work(ctx, &rivertype.JobRow{
Kind: "no_op",
Metadata: metadata,
}, func(ctx context.Context) error { return nil })
require.NoError(t, err)

spans := bundle.traceExporter.GetSpans()
require.Len(t, spans, 1)
workSpan := spans[0]

// The work span must be linked to (not a child of) the insert span.
require.False(t, workSpan.Parent.IsValid(), "work span should not be a child of the insert span")
require.Len(t, workSpan.Links, 1)
require.Equal(t, parentTraceID, workSpan.Links[0].SpanContext.TraceID().String())
require.Equal(t, parentSpanID, workSpan.Links[0].SpanContext.SpanID().String())
})

t.Run("WorkExtractsTraceparentMissingMetadata", func(t *testing.T) {
	t.Parallel()

	middleware, bundle := setupConfig(t, &MiddlewareConfig{EnableTracePropagation: true})

	// The job row carries no metadata at all, so there is no traceparent to
	// extract.
	err := middleware.Work(ctx, &rivertype.JobRow{Kind: "no_op"}, func(ctx context.Context) error { return nil })
	require.NoError(t, err)

	spans := bundle.traceExporter.GetSpans()
	require.Len(t, spans, 1)
	require.False(t, spans[0].Parent.IsValid(), "expected no parent span when metadata has no traceparent")
	// Propagation is expressed through links rather than parents, so the
	// absence of a link is what actually shows nothing was extracted.
	require.Empty(t, spans[0].Links, "expected no span links when metadata has no traceparent")
})

t.Run("WorkEnableWorkSpanJobKindSuffix ", func(t *testing.T) {
t.Parallel()

Expand Down
Loading