Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions backend/internal/store/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,25 @@ type EventFilters struct {
MinMAD string
}

// InsertEvent persists one paired event. The payload column holds the
// full JSON-encoded PairedEvent blob; downstream readers (dashboard,
// engine) re-decode it as needed. SDK retries can replay the same
// event_id, so duplicate primary-key inserts are ignored.
func (s *Store) InsertEvent(ctx context.Context, e *Event) error {
_, err := s.db.ExecContext(ctx,
`INSERT OR IGNORE INTO events
// InsertEvent persists one paired event and reports whether a new row
// was inserted. The payload column holds the full JSON-encoded
// PairedEvent blob; downstream readers (dashboard, engine) re-decode it
// as needed. SDK retries can replay the same event_id.
func (s *Store) InsertEvent(ctx context.Context, e *Event) (bool, error) {
res, err := s.db.ExecContext(ctx,
`INSERT INTO events
(id, session_id, agent_id, agent_profile_id, event_type, run_id, payload, tokens_used)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO NOTHING`,
e.ID, e.SessionID, e.AgentID, e.AgentProfileID, e.EventType, e.RunID, e.PayloadJSON, e.TokensUsed)
return err
if err != nil {
return false, err
}
n, err := res.RowsAffected()
if err != nil {
return false, err
}
return n > 0, nil
}

// ListEvents returns a page of events matching the filters, plus the
Expand Down
34 changes: 30 additions & 4 deletions backend/internal/ws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,31 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla
if err != nil {
return err
}
if err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON)); err != nil {
inserted, err := st.InsertEvent(ctx, newEventRow(sess, ev, payloadJSON))
if err != nil {
return err
}
if !inserted {
// A retry can race with the first in-flight delivery: event row
// already exists, but verdict insert has not landed yet. Poll a
// few times before giving up so we don't fail the WS batch on a
// benign duplicate.
for i := 0; i < 3; i++ {
existing, err := st.GetVerdictByEventID(ctx, ev.EventId)
if err == nil {
return dispatchVerdict(ctx, sess, st, hub, ev, snap, existing.ID, existing.MADCode)
}
if !errors.Is(err, store.ErrNotFound) {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(25 * time.Millisecond):
}
}
return nil
}

// Refresh the runtime agents row so the dashboard's /agents view
// reflects activity. Best-effort: a write failure is logged but
Expand Down Expand Up @@ -296,6 +318,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla
verdict.MADCode, verdict.Classification)
}

return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, verdict.MADCode)
}

func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error {
// Mode-gated dispatch:
// alert: persist verdict, do NOT notify the SDK (dashboard-only).
// hitl + in-scope + actionable: persist + queue for human review,
Expand All @@ -304,7 +330,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla
// no-op for the operator since the SDK never blocks on it).
// hitl + out-of-scope: forward (no review queued for this code).
// block: forward all verdicts; SDK is the enforcement point.
inScope := shouldFanOut(snap, verdict.MADCode)
inScope := shouldFanOut(snap, madCode)
switch snap.GetMode() {
case pb.Mode_MODE_ALERT:
// Dashboard-only: the verdict is persisted, the SDK is not
Expand All @@ -315,7 +341,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla
// Queue for human review and hold the verdict. The reviews
// REST handler resumes the SDK with a HitlResponse-bearing
// Verdict on approve/reject via the same hub channel.
if err := st.InsertHitlQueue(ctx, ev.EventId, vrow.ID, sess.sessionID, verdict.MADCode); err != nil {
if err := st.InsertHitlQueue(ctx, ev.EventId, verdictID, sess.sessionID, madCode); err != nil {
slog.ErrorContext(ctx, "hitl.insert_failed",
"error", err, "event_id", ev.EventId)
}
Expand All @@ -338,7 +364,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla
Verdict: &pb.Verdict{
EventId: ev.EventId,
SessionId: sess.sessionID,
MadCode: verdict.MADCode,
MadCode: madCode,
Policy: snap,
},
},
Expand Down
22 changes: 20 additions & 2 deletions backend/internal/ws/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -170,8 +171,9 @@ func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) {
t.Fatalf("set mode=block: %v", err)
}

var classifyCalls int32
mux := http.NewServeMux()
mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{}, ws.NewHub(), nil, nil)))
mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{calls: &classifyCalls}, ws.NewHub(), nil, nil)))
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)

Expand Down Expand Up @@ -238,6 +240,17 @@ func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) {
if eventRows != 1 {
t.Fatalf("expected duplicate retry to keep 1 event row, got %d", eventRows)
}

var verdictRows int
if err := db.QueryRow("SELECT count(*) FROM verdicts WHERE event_id = ?", eventID).Scan(&verdictRows); err != nil {
t.Fatalf("query verdicts: %v", err)
}
if verdictRows != 1 {
t.Fatalf("expected duplicate retry to keep 1 verdict row, got %d", verdictRows)
}
if got := atomic.LoadInt32(&classifyCalls); got != 1 {
t.Fatalf("expected classifier to run once for duplicate retry, got %d calls", got)
}
}

// TestAlertModeNoFanOut confirms the mode gate withholds Verdict
Expand Down Expand Up @@ -445,9 +458,14 @@ func TestUnauthDial(t *testing.T) {
// helpers
// -----------------------------------------------------------------

type fakeClassifier struct{}
type fakeClassifier struct {
calls *int32
}

func (f *fakeClassifier) Classify(_ context.Context, _ *bpb.PairedEvent, _ string) (*engine.Verdict, error) {
if f.calls != nil {
atomic.AddInt32(f.calls, 1)
}
return &engine.Verdict{MADCode: "M0", Classification: "benign"}, nil
}

Expand Down