@@ -245,7 +245,7 @@ func (s *ADSServer) StreamAggregatedResources(stream ads.SotWStream) (err error)
245245 }
246246
247247 err = h .loop ()
248- slog .DebugContext (h .streamCtx , "Closing stream" , "err" , err )
248+ slog .DebugContext (h .stream . Context () , "Closing stream" , "err" , err )
249249 return err
250250}
251251
@@ -281,7 +281,7 @@ func (s *ADSServer) DeltaAggregatedResources(stream ads.DeltaStream) (err error)
281281 }
282282
283283 err = h .loop ()
284- slog .DebugContext (h .streamCtx , "Closing stream" , "err" , err )
284+ slog .DebugContext (h .stream . Context () , "Closing stream" , "err" , err )
285285 return err
286286}
287287
@@ -293,19 +293,23 @@ type adsDiscoveryRequest interface {
293293 GetNode () * ads.Node
294294}
295295
296+ type adsDiscoveryResponse interface {
297+ proto.Message
298+ GetNonce () string
299+ }
300+
296301type adsStream [REQ adsDiscoveryRequest , RES proto.Message ] interface {
297302 Context () context.Context
298303 Recv () (REQ , error )
299304 Send (RES ) error
300305}
301306
302307// streamHandler captures the various elements required to handle an ADS stream.
303- type streamHandler [REQ adsDiscoveryRequest , RES proto. Message ] struct {
308+ type streamHandler [REQ adsDiscoveryRequest , RES adsDiscoveryResponse ] struct {
304309 sendLock sync.Mutex
305310
306311 server * ADSServer
307312 stream adsStream [REQ , RES ]
308- streamCtx context.Context
309313 streamType ads.StreamType
310314 newHandler func (
311315 ctx context.Context ,
@@ -331,7 +335,13 @@ func (h *streamHandler[REQ, RES]) send(res RES) (err error) {
331335 h .sendLock .Lock ()
332336 defer h .sendLock .Unlock ()
333337 h .setControlPlane (res , h .server .controlPlane )
334- slog .DebugContext (h .streamCtx , "Sending" , "msg" , res )
338+ slog .DebugContext (h .stream .Context (), "Sending" , "msg" , res )
339+ if h .server .statsHandler != nil {
340+ h .server .statsHandler .HandleServerEvent (h .stream .Context (), & serverstats.SendingResponse {
341+ Res : res ,
342+ Nonce : res .GetNonce (),
343+ })
344+ }
335345 return h .stream .Send (res )
336346}
337347
@@ -354,11 +364,11 @@ func (h *streamHandler[REQ, RES]) getSubscriptionManager(typeURL string) interna
354364 }
355365
356366 manager := h .newManager (
357- h .streamCtx ,
367+ h .stream . Context () ,
358368 h .server .locator ,
359369 typeURL ,
360370 h .newHandler (
361- h .streamCtx ,
371+ h .stream . Context () ,
362372 h .server .newGranularRateLimiter (),
363373 h .server .statsHandler ,
364374 typeURL ,
@@ -378,11 +388,6 @@ func (h *streamHandler[REQ, RES]) loop() error {
378388 return err
379389 }
380390
381- // initialize the stream context with the node on the first request
382- if h .streamCtx == nil {
383- h .streamCtx = context .WithValue (h .stream .Context (), nodeContextKey {}, req .GetNode ())
384- }
385-
386391 err = h .handleRequest (req )
387392 if err != nil {
388393 return err
@@ -391,19 +396,38 @@ func (h *streamHandler[REQ, RES]) loop() error {
391396}
392397
393398func (h * streamHandler [REQ , RES ]) handleRequest (req REQ ) (err error ) {
394- slog .DebugContext (h .streamCtx , "Received request" , "req" , req )
395-
396- var stat * serverstats.RequestReceived
397399 if h .server .statsHandler != nil {
398400 start := time .Now ()
399- stat = & serverstats.RequestReceived {Req : req }
400401 defer func () {
401- stat .Duration = time .Since (start )
402- h .server .statsHandler .HandleServerEvent (h .streamCtx , stat )
402+ h .server .statsHandler .HandleServerEvent (h .stream .Context (), & serverstats.RequestProcessed {
403+ Req : req ,
404+ Duration : time .Since (start ),
405+ })
403406 }()
404407 }
405408
406- err = h .server .requestLimiter .Wait (h .streamCtx )
409+ slog .DebugContext (h .stream .Context (), "Received request" , "req" , req )
410+
411+ var isACK , isNACK bool
412+ switch {
413+ case req .GetErrorDetail () != nil :
414+ slog .WarnContext (h .stream .Context (), "Got client NACK" , "req" , req )
415+ isNACK = true
416+ case req .GetResponseNonce () != "" :
417+ slog .DebugContext (h .stream .Context (), "ACKED" , "req" , req )
418+ isACK = true
419+ }
420+
421+ if h .server .statsHandler != nil {
422+ h .server .statsHandler .HandleServerEvent (h .stream .Context (), & serverstats.RequestReceived {
423+ Req : req ,
424+ IsACK : isACK ,
425+ IsNACK : isNACK ,
426+ Nonce : req .GetResponseNonce (),
427+ })
428+ }
429+
430+ err = h .server .requestLimiter .Wait (h .stream .Context ())
407431 if err != nil {
408432 return err
409433 }
@@ -412,19 +436,6 @@ func (h *streamHandler[REQ, RES]) handleRequest(req REQ) (err error) {
412436 h .aggregateSubscriptions = make (map [string ]internal.SubscriptionManager [REQ ])
413437 }
414438
415- switch {
416- case req .GetErrorDetail () != nil :
417- slog .WarnContext (h .streamCtx , "Got client NACK" , "req" , req )
418- if stat != nil {
419- stat .IsNACK = true
420- }
421- case req .GetResponseNonce () != "" :
422- slog .DebugContext (h .streamCtx , "ACKED" , "req" , req )
423- if stat != nil {
424- stat .IsACK = true
425- }
426- }
427-
428439 h .getSubscriptionManager (req .GetTypeUrl ()).ProcessSubscriptions (req )
429440
430441 return nil
@@ -463,13 +474,3 @@ type ResourceLocator interface {
463474 handler ads.RawSubscriptionHandler ,
464475 ) (unsubscribe func ())
465476}
466-
467- type nodeContextKey struct {}
468-
469- // NodeFromContext returns the [ads.Node] in the given context, if it exists. Note that the
470- // [ADSServer] will always provide the Node in the context when invoking methods on the
471- // [ResourceLocator].
472- func NodeFromContext (streamCtx context.Context ) (* ads.Node , bool ) {
473- node , ok := streamCtx .Value (nodeContextKey {}).(* ads.Node )
474- return node , ok
475- }
0 commit comments