diff --git a/backend/internal/api/handlers_test.go b/backend/internal/api/handlers_test.go index 1a4437d..a361c21 100644 --- a/backend/internal/api/handlers_test.go +++ b/backend/internal/api/handlers_test.go @@ -505,7 +505,10 @@ func TestApproveReviewPublishesToSubscriber(t *testing.T) { } // Fake SDK subscriber. - ch, dereg := hub.Register(sessID) + ch, dereg, err := hub.Register(sessID, "test-owner") + if err != nil { + t.Fatalf("Register: %v", err) + } defer dereg() resp := postJSON(t, srv, cookie, "/api/reviews/"+queueID+"/approve", map[string]any{}) diff --git a/backend/internal/ws/frames.go b/backend/internal/ws/frames.go index f7ae604..23c1e91 100644 --- a/backend/internal/ws/frames.go +++ b/backend/internal/ws/frames.go @@ -21,6 +21,7 @@ const schemaVersion = 2 // plus our application-specific 4xxx codes. const ( closeProtocolError = 1002 + closePolicyViolation = 1008 closeInternalServerErr = 1011 closeQuotaExhausted = 4003 ) diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index 29a16a5..6db577d 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -99,7 +99,25 @@ func serve(ctx context.Context, conn *websocket.Conn, sess *session, st *store.S // verdicts and HITL resolutions land here, drained by the writer // goroutine. The LoginAck is written directly inside handleLogin // (single goroutine, pre-register, no concurrency to serialise). - hubCh, deregister := hub.Register(sess.sessionID) + hubCh, deregister, err := hub.Register(sess.sessionID, sess.routeOwner()) + if err != nil { + if errors.Is(err, ErrSessionOwnerConflict) { + slog.WarnContext(ctx, "ws.session_owner_conflict", + "session_id", sess.sessionID, + "api_key_id", sess.apiKey.ID, + "route_owner", sess.routeOwner(), + ) + closeWith(conn, closePolicyViolation, "session_id already active for another owner") + return + } + slog.ErrorContext(ctx, "ws.session_register_failed", + "error", err, + "session_id", sess.sessionID, + "api_key_id", sess.apiKey.ID, + ) + closeWith(conn, closeInternalServerErr, "internal error") + return + } writerDone := make(chan struct{}) go func() { defer close(writerDone) diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index a4f9142..4cdf544 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -235,6 +235,97 @@ func TestAlertModeNoFanOut(t *testing.T) { } } +func TestSessionIDReuseDifferentOwnerDoesNotStealVerdicts(t *testing.T) { + db := openInMemoryDB(t) + t.Cleanup(func() { _ = db.Close() }) + + st := store.New(db) + plaintextKeyA := "adr_local_test_key_owner_a" + plaintextKeyB := "adr_local_test_key_owner_b" + insertAPIKeyWithProfile(t, db, sha256Hex(plaintextKeyA), "agent-profile-a") + insertAPIKeyWithProfile(t, db, sha256Hex(plaintextKeyB), "agent-profile-b") + + if _, err := db.Exec(`UPDATE policies SET mode = 'block' WHERE id = 1`); err != nil { + t.Fatalf("set mode=block: %v", err) + } + + mux := http.NewServeMux() + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{}, ws.NewHub(), nil, nil))) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + connA, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{ + "Authorization": {"Bearer " + plaintextKeyA}, + }) + if err != nil { + t.Fatalf("dial client A: %v", err) + } + t.Cleanup(func() { _ = connA.Close() }) + + const sessionID = "shared-session-takeover-test" + if err := writeProto(connA, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: sessionID, SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login client A: %v", err) + } + if _, err := readServerFrame(connA); err != nil { + t.Fatalf("read login_ack client A: %v", err) + } + + connB, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{ + "Authorization": {"Bearer " + plaintextKeyB}, + }) + if err != nil { + t.Fatalf("dial client B: %v", err) + } + t.Cleanup(func() { _ = connB.Close() }) + if err := writeProto(connB, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: sessionID, SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login client B: %v", err) + } + if _, err := readServerFrame(connB); err != nil { + t.Fatalf("read login_ack client B: %v", err) + } + if err := connB.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("set client B deadline: %v", err) + } + if _, _, err := connB.ReadMessage(); err == nil { + t.Fatal("expected conflicting client B to be closed") + } else if closeErr, ok := err.(*websocket.CloseError); !ok || closeErr.Code != websocket.ClosePolicyViolation { + t.Fatalf("client B close err = %v, want close code %d", err, websocket.ClosePolicyViolation) + } + + eventID := uuid.NewString() + if err := writeProto(connA, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{ + Events: []*bpb.PairedEvent{{ + EventId: eventID, SessionId: sessionID, + PairType: bpb.PairType_PAIR_TYPE_TOOL, + Agent: &bpb.AgentContext{AgentId: "owner-a-agent"}, + Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{ + ToolName: "noop", ToolCallId: "tc-owner-a", Input: "{}", Output: "ok", + }}, + }}, + }}, + }); err != nil { + t.Fatalf("send paired_batch client A: %v", err) + } + + verdict, err := readServerFrame(connA) + if err != nil { + t.Fatalf("read verdict client A: %v", err) + } + if got := verdict.GetVerdict(); got == nil || got.EventId != eventID { + t.Fatalf("client A verdict = %+v, want event_id %q", got, eventID) + } +} + // TestRevokeKicksLiveWS asserts the security guarantee: an open WS // authenticated with key X gets terminated within seconds when X is // revoked, not at next-disconnect-whenever. Drives the path the REST @@ -403,6 +494,17 @@ func insertAPIKey(t *testing.T, db *sql.DB, hashHex string) { } } +func insertAPIKeyWithProfile(t *testing.T, db *sql.DB, hashHex, agentProfileID string) { + t.Helper() + _, err := db.Exec( + `INSERT INTO api_keys (id, key_hash, prefix, label, agent_profile_id) VALUES (?, ?, ?, ?, ?)`, + uuid.NewString(), hashHex, "adr_local_te", "test", agentProfileID, + ) + if err != nil { + t.Fatalf("insert api_keys: %v", err) + } +} + func writeProto(conn *websocket.Conn, msg proto.Message) error { buf, err := proto.Marshal(msg) if err != nil { diff --git a/backend/internal/ws/hub.go b/backend/internal/ws/hub.go index fe2c129..7bf5d86 100644 --- a/backend/internal/ws/hub.go +++ b/backend/internal/ws/hub.go @@ -4,6 +4,7 @@ package ws import ( + "errors" "sync" "google.golang.org/protobuf/proto" @@ -11,40 +12,53 @@ import ( pb "github.com/secureagentics/Adrian/backend/internal/proto" ) +// ErrSessionOwnerConflict is returned when another logical client owner +// already holds the subscriber slot for a session_id. +var ErrSessionOwnerConflict = errors.New("session_id already registered by another owner") + +type subscriber struct { + owner string + ch chan []byte +} + // Hub is a process-local pub/sub keyed by session_id. The WS handler // for each connected SDK registers a write channel; the REST review // approve/reject path publishes a HITL-resolution Verdict frame to it. // -// Single subscriber per session_id, re-Register replaces any prior -// channel (covers SDK reconnect during a hold). On disconnect the WS -// handler must call deregister to free the slot. +// Single subscriber per session_id. Re-register by the same server-derived +// owner replaces the prior channel (SDK reconnect / key rotation). A +// different owner claiming the same session_id is rejected so it cannot steal +// verdict or HITL routing. type Hub struct { mu sync.Mutex - subs map[string]chan []byte + subs map[string]subscriber } // NewHub returns a fresh hub. func NewHub() *Hub { - return &Hub{subs: make(map[string]chan []byte)} + return &Hub{subs: make(map[string]subscriber)} } -// Register adds a subscriber for sessionID and returns its write +// Register adds a subscriber for sessionID + owner and returns its write // channel plus a deregister callback. The caller spawns a writer // goroutine that drains the channel and calls conn.WriteMessage. // -// If a prior subscriber exists for the same session_id, its channel -// is closed so its writer goroutine exits cleanly. (Concurrent -// connections under one session_id are not a normal case; an SDK -// reconnect during a hold is the realistic path.) -func (h *Hub) Register(sessionID string) (<-chan []byte, func()) { +// If a prior subscriber exists for the same session_id and owner, its channel +// is closed so its writer goroutine exits cleanly. If the existing subscriber +// belongs to another owner, registration fails and the old channel remains +// active. +func (h *Hub) Register(sessionID, owner string) (<-chan []byte, func(), error) { h.mu.Lock() defer h.mu.Unlock() if old, ok := h.subs[sessionID]; ok { - close(old) + if old.owner != owner { + return nil, nil, ErrSessionOwnerConflict + } + close(old.ch) } ch := make(chan []byte, 8) - h.subs[sessionID] = ch + h.subs[sessionID] = subscriber{owner: owner, ch: ch} deregister := func() { h.mu.Lock() @@ -52,12 +66,12 @@ func (h *Hub) Register(sessionID string) (<-chan []byte, func()) { // Only delete + close when the entry is still ours; a later // Register may have replaced it and already closed the prior // channel. - if cur, ok := h.subs[sessionID]; ok && cur == ch { + if cur, ok := h.subs[sessionID]; ok && cur.ch == ch { delete(h.subs, sessionID) close(ch) } } - return ch, deregister + return ch, deregister, nil } // Publish marshals and pushes a frame to the subscriber for sessionID. @@ -70,13 +84,13 @@ func (h *Hub) Publish(sessionID string, frame *pb.ServerFrame) bool { return false } h.mu.Lock() - ch, ok := h.subs[sessionID] - h.mu.Unlock() + defer h.mu.Unlock() + sub, ok := h.subs[sessionID] if !ok { return false } select { - case ch <- buf: + case sub.ch <- buf: return true default: return false diff --git a/backend/internal/ws/hub_test.go b/backend/internal/ws/hub_test.go index 3c33a39..17a5c18 100644 --- a/backend/internal/ws/hub_test.go +++ b/backend/internal/ws/hub_test.go @@ -4,6 +4,7 @@ package ws import ( + "errors" "testing" "google.golang.org/protobuf/proto" @@ -13,7 +14,10 @@ import ( func TestHubPublishDeliversToSubscriber(t *testing.T) { h := NewHub() - ch, dereg := h.Register("sess-1") + ch, dereg, err := h.Register("sess-1", "owner-1") + if err != nil { + t.Fatalf("Register: %v", err) + } defer dereg() frame := &pb.ServerFrame{Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{ @@ -45,11 +49,17 @@ func TestHubPublishNoSubscriberReturnsFalse(t *testing.T) { func TestHubReRegisterClosesPriorChannel(t *testing.T) { h := NewHub() - first, _ := h.Register("sess-x") + first, _, err := h.Register("sess-x", "owner-1") + if err != nil { + t.Fatalf("first Register: %v", err) + } // New register replaces the slot; the old channel must close so a // writer goroutine reading from it exits cleanly. - second, dereg := h.Register("sess-x") + second, dereg, err := h.Register("sess-x", "owner-1") + if err != nil { + t.Fatalf("second Register: %v", err) + } defer dereg() if _, ok := <-first; ok { @@ -67,9 +77,39 @@ func TestHubReRegisterClosesPriorChannel(t *testing.T) { } } +func TestHubRejectsReRegisterFromDifferentOwner(t *testing.T) { + h := NewHub() + first, dereg, err := h.Register("sess-x", "owner-1") + if err != nil { + t.Fatalf("first Register: %v", err) + } + defer dereg() + + second, secondDereg, err := h.Register("sess-x", "owner-2") + if !errors.Is(err, ErrSessionOwnerConflict) { + t.Fatalf("Register err = %v, want ErrSessionOwnerConflict", err) + } + if second != nil || secondDereg != nil { + t.Fatal("conflicting Register returned a subscriber") + } + + if !h.Publish("sess-x", &pb.ServerFrame{ + Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{EventId: "ev-y"}}, + }) { + t.Fatal("Publish to original subscriber should still succeed") + } + got := <-first + if len(got) == 0 { + t.Fatal("expected non-empty frame delivered to original subscriber") + } +} + func TestHubDeregisterRemovesEntry(t *testing.T) { h := NewHub() - _, dereg := h.Register("sess-d") + _, dereg, err := h.Register("sess-d", "owner-1") + if err != nil { + t.Fatalf("Register: %v", err) + } dereg() if h.Publish("sess-d", &pb.ServerFrame{ Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{}}, diff --git a/backend/internal/ws/session.go b/backend/internal/ws/session.go index f10ed69..cb251a3 100644 --- a/backend/internal/ws/session.go +++ b/backend/internal/ws/session.go @@ -25,3 +25,16 @@ func (s *session) agentProfileID() *string { } return s.apiKey.AgentProfileID } + +// routeOwner returns the server-authenticated logical owner for hub routing. +// Agent-profile keys may rotate, so profile ownership is preferred over raw +// key ID to preserve reconnect continuity. Unprofiled keys fall back to key ID. +func (s *session) routeOwner() string { + if s.apiKey == nil { + return "" + } + if s.apiKey.AgentProfileID != nil && *s.apiKey.AgentProfileID != "" { + return "agent_profile:" + *s.apiKey.AgentProfileID + } + return "api_key:" + s.apiKey.ID +}