diff --git a/backend/internal/store/events.go b/backend/internal/store/events.go index 808bfd7..a388a01 100644 --- a/backend/internal/store/events.go +++ b/backend/internal/store/events.go @@ -65,17 +65,25 @@ type EventFilters struct { MinMAD string } -// InsertEvent persists one paired event. The payload column holds the -// full JSON-encoded PairedEvent blob; downstream readers (dashboard, -// engine) re-decode it as needed. SDK retries can replay the same -// event_id, so duplicate primary-key inserts are ignored. -func (s *Store) InsertEvent(ctx context.Context, e *Event) error { - _, err := s.db.ExecContext(ctx, - `INSERT OR IGNORE INTO events +// InsertEvent persists one paired event and reports whether a new row +// was inserted. The payload column holds the full JSON-encoded +// PairedEvent blob; downstream readers (dashboard, engine) re-decode it +// as needed. SDK retries can replay the same event_id. +func (s *Store) InsertEvent(ctx context.Context, e *Event) (bool, error) { + res, err := s.db.ExecContext(ctx, + `INSERT INTO events (id, session_id, agent_id, agent_profile_id, event_type, run_id, payload, tokens_used) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO NOTHING`, e.ID, e.SessionID, e.AgentID, e.AgentProfileID, e.EventType, e.RunID, e.PayloadJSON, e.TokensUsed) - return err + if err != nil { + return false, err + } + n, err := res.RowsAffected() + if err != nil { + return false, err + } + return n > 0, nil } // ListEvents returns a page of events matching the filters, plus the diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index 29a16a5..fdf5580 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -241,9 +241,31 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla if err != nil { return err } - if err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON)); err != nil { + inserted, err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON)) + if err != nil { return err } + if !inserted { + // A retry can race with the first in-flight delivery: event row + // already exists, but verdict insert has not landed yet. Poll a + // few times before giving up so we don't fail the WS batch on a + // benign duplicate. + for i := 0; i < 3; i++ { + existing, err := st.GetVerdictByEventID(ctx, ev.EventId) + if err == nil { + return dispatchVerdict(ctx, sess, st, hub, ev, snap, existing.ID, existing.MADCode) + } + if !errors.Is(err, store.ErrNotFound) { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(25 * time.Millisecond): + } + } + return nil + } // Refresh the runtime agents row so the dashboard's /agents view // reflects activity. Best-effort: a write failure is logged but @@ -296,6 +318,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla verdict.MADCode, verdict.Classification) } + return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, verdict.MADCode) +} + +func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error { // Mode-gated dispatch: // alert: persist verdict, do NOT notify the SDK (dashboard-only). // hitl + in-scope + actionable: persist + queue for human review, @@ -304,7 +330,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla // no-op for the operator since the SDK never blocks on it). // hitl + out-of-scope: forward (no review queued for this code). // block: forward all verdicts; SDK is the enforcement point. - inScope := shouldFanOut(snap, verdict.MADCode) + inScope := shouldFanOut(snap, madCode) switch snap.GetMode() { case pb.Mode_MODE_ALERT: // Dashboard-only: the verdict is persisted, the SDK is not @@ -315,7 +341,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla // Queue for human review and hold the verdict. The reviews // REST handler resumes the SDK with a HitlResponse-bearing // Verdict on approve/reject via the same hub channel. - if err := st.InsertHitlQueue(ctx, ev.EventId, vrow.ID, sess.sessionID, verdict.MADCode); err != nil { + if err := st.InsertHitlQueue(ctx, ev.EventId, verdictID, sess.sessionID, madCode); err != nil { slog.ErrorContext(ctx, "hitl.insert_failed", "error", err, "event_id", ev.EventId) } @@ -338,7 +364,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla Verdict: &pb.Verdict{ EventId: ev.EventId, SessionId: sess.sessionID, - MadCode: verdict.MADCode, + MadCode: madCode, Policy: snap, }, }, diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index a34e670..3ed4cc6 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -170,8 +171,9 @@ func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { t.Fatalf("set mode=block: %v", err) } + var classifyCalls int32 mux := http.NewServeMux() - mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{}, ws.NewHub(), nil, nil))) + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{calls: &classifyCalls}, ws.NewHub(), nil, nil))) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) @@ -238,6 +240,17 @@ func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { if eventRows != 1 { t.Fatalf("expected duplicate retry to keep 1 event row, got %d", eventRows) } + + var verdictRows int + if err := db.QueryRow("SELECT count(*) FROM verdicts WHERE event_id = ?", eventID).Scan(&verdictRows); err != nil { + t.Fatalf("query verdicts: %v", err) + } + if verdictRows != 1 { + t.Fatalf("expected duplicate retry to keep 1 verdict row, got %d", verdictRows) + } + if got := atomic.LoadInt32(&classifyCalls); got != 1 { + t.Fatalf("expected classifier to run once for duplicate retry, got %d calls", got) + } } // TestAlertModeNoFanOut confirms the mode gate withholds Verdict @@ -445,9 +458,14 @@ func TestUnauthDial(t *testing.T) { // helpers // ----------------------------------------------------------------- -type fakeClassifier struct{} +type fakeClassifier struct { + calls *int32 +} func (f *fakeClassifier) Classify(_ context.Context, _ *bpb.PairedEvent, _ string) (*engine.Verdict, error) { + if f.calls != nil { + atomic.AddInt32(f.calls, 1) + } return &engine.Verdict{MADCode: "M0", Classification: "benign"}, nil }