From 1cf38e542da7789e8a6157f02d85def8d23577d4 Mon Sep 17 00:00:00 2001 From: Adarsh Tiwari Date: Tue, 2 Jun 2026 13:19:54 +0530 Subject: [PATCH] make event retries end-to-end --- backend/internal/store/events.go | 23 +++++-- backend/internal/ws/handler.go | 34 ++++++++-- backend/internal/ws/handler_test.go | 102 +++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 12 deletions(-) diff --git a/backend/internal/store/events.go b/backend/internal/store/events.go index 5fcce0f..a388a01 100644 --- a/backend/internal/store/events.go +++ b/backend/internal/store/events.go @@ -65,16 +65,25 @@ type EventFilters struct { MinMAD string } -// InsertEvent persists one paired event. The payload column holds the -// full JSON-encoded PairedEvent blob; downstream readers (dashboard, -// engine) re-decode it as needed. -func (s *Store) InsertEvent(ctx context.Context, e *Event) error { - _, err := s.db.ExecContext(ctx, +// InsertEvent persists one paired event and reports whether a new row +// was inserted. The payload column holds the full JSON-encoded +// PairedEvent blob; downstream readers (dashboard, engine) re-decode it +// as needed. SDK retries can replay the same event_id. +func (s *Store) InsertEvent(ctx context.Context, e *Event) (bool, error) { + res, err := s.db.ExecContext(ctx, `INSERT INTO events (id, session_id, agent_id, agent_profile_id, event_type, run_id, payload, tokens_used) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO NOTHING`, e.ID, e.SessionID, e.AgentID, e.AgentProfileID, e.EventType, e.RunID, e.PayloadJSON, e.TokensUsed) - return err + if err != nil { + return false, err + } + n, err := res.RowsAffected() + if err != nil { + return false, err + } + return n > 0, nil } // ListEvents returns a page of events matching the filters, plus the diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index 29a16a5..fdf5580 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -241,9 +241,31 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla if err != nil { return err } - if err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON)); err != nil { + inserted, err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON)) + if err != nil { return err } + if !inserted { + // A retry can race with the first in-flight delivery: event row + // already exists, but verdict insert has not landed yet. Poll a + // few times before giving up so we don't fail the WS batch on a + // benign duplicate. + for i := 0; i < 3; i++ { + existing, err := st.GetVerdictByEventID(ctx, ev.EventId) + if err == nil { + return dispatchVerdict(ctx, sess, st, hub, ev, snap, existing.ID, existing.MADCode) + } + if !errors.Is(err, store.ErrNotFound) { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(25 * time.Millisecond): + } + } + return nil + } // Refresh the runtime agents row so the dashboard's /agents view // reflects activity. Best-effort: a write failure is logged but @@ -296,6 +318,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla verdict.MADCode, verdict.Classification) } + return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, verdict.MADCode) +} + +func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error { // Mode-gated dispatch: // alert: persist verdict, do NOT notify the SDK (dashboard-only). // hitl + in-scope + actionable: persist + queue for human review, @@ -304,7 +330,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla // no-op for the operator since the SDK never blocks on it). // hitl + out-of-scope: forward (no review queued for this code). // block: forward all verdicts; SDK is the enforcement point. - inScope := shouldFanOut(snap, verdict.MADCode) + inScope := shouldFanOut(snap, madCode) switch snap.GetMode() { case pb.Mode_MODE_ALERT: // Dashboard-only: the verdict is persisted, the SDK is not @@ -315,7 +341,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla // Queue for human review and hold the verdict. The reviews // REST handler resumes the SDK with a HitlResponse-bearing // Verdict on approve/reject via the same hub channel. - if err := st.InsertHitlQueue(ctx, ev.EventId, vrow.ID, sess.sessionID, verdict.MADCode); err != nil { + if err := st.InsertHitlQueue(ctx, ev.EventId, verdictID, sess.sessionID, madCode); err != nil { slog.ErrorContext(ctx, "hitl.insert_failed", "error", err, "event_id", ev.EventId) } @@ -338,7 +364,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla Verdict: &pb.Verdict{ EventId: ev.EventId, SessionId: sess.sessionID, - MadCode: verdict.MADCode, + MadCode: madCode, Policy: snap, }, }, diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index a4f9142..3ed4cc6 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -158,6 +159,100 @@ func TestRoundTrip(t *testing.T) { websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) } +func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { + db := openInMemoryDB(t) + t.Cleanup(func() { _ = db.Close() }) + + st := store.New(db) + plaintextKey := "adr_local_test_key_retry" + keyHash := sha256Hex(plaintextKey) + insertAPIKey(t, db, keyHash) + if _, err := db.Exec(`UPDATE policies SET mode = 'block' WHERE id = 1`); err != nil { + t.Fatalf("set mode=block: %v", err) + } + + var classifyCalls int32 + mux := http.NewServeMux() + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{calls: &classifyCalls}, ws.NewHub(), nil, nil))) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + header := http.Header{"Authorization": {"Bearer " + plaintextKey}} + conn, _, err := websocket.DefaultDialer.Dial(wsURL, header) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + if err := writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: "retry-sess", SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login: %v", err) + } + if _, err := readServerFrame(conn); err != nil { + t.Fatalf("read login_ack: %v", err) + } + + eventID := uuid.NewString() + batchFrame := &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{ + Events: []*bpb.PairedEvent{{ + EventId: eventID, SessionId: "retry-sess", + RunId: "run-retry", + PairType: bpb.PairType_PAIR_TYPE_TOOL, + Agent: &bpb.AgentContext{AgentId: "retry-agent"}, + Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{ + ToolName: "noop", ToolCallId: "tc-retry", Input: "{}", Output: "ok", + }}, + }}, + }}, + } + + if err := writeProto(conn, batchFrame); err != nil { + t.Fatalf("send first paired_batch: %v", err) + } + firstVerdict, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read first verdict: %v", err) + } + if got := firstVerdict.GetVerdict().GetEventId(); got != eventID { + t.Fatalf("first verdict event_id = %q, want %q", got, eventID) + } + + if err := writeProto(conn, batchFrame); err != nil { + t.Fatalf("send retry paired_batch: %v", err) + } + retryVerdict, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read retry verdict after duplicate event insert: %v", err) + } + if got := retryVerdict.GetVerdict().GetEventId(); got != eventID { + t.Fatalf("retry verdict event_id = %q, want %q", got, eventID) + } + + var eventRows int + if err := db.QueryRow("SELECT count(*) FROM events WHERE id = ?", eventID).Scan(&eventRows); err != nil { + t.Fatalf("query events: %v", err) + } + if eventRows != 1 { + t.Fatalf("expected duplicate retry to keep 1 event row, got %d", eventRows) + } + + var verdictRows int + if err := db.QueryRow("SELECT count(*) FROM verdicts WHERE event_id = ?", eventID).Scan(&verdictRows); err != nil { + t.Fatalf("query verdicts: %v", err) + } + if verdictRows != 1 { + t.Fatalf("expected duplicate retry to keep 1 verdict row, got %d", verdictRows) + } + if got := atomic.LoadInt32(&classifyCalls); got != 1 { + t.Fatalf("expected classifier to run once for duplicate retry, got %d calls", got) + } +} + // TestAlertModeNoFanOut confirms the mode gate withholds Verdict // frames from the SDK in alert mode (dashboard-only). The verdict row // is still persisted; only the WS write is suppressed. @@ -363,9 +458,14 @@ func TestUnauthDial(t *testing.T) { // helpers // ----------------------------------------------------------------- -type fakeClassifier struct{} +type fakeClassifier struct { + calls *int32 +} func (f *fakeClassifier) Classify(_ context.Context, _ *bpb.PairedEvent, _ string) (*engine.Verdict, error) { + if f.calls != nil { + atomic.AddInt32(f.calls, 1) + } return &engine.Verdict{MADCode: "M0", Classification: "benign"}, nil }