Skip to content
Open
5 changes: 4 additions & 1 deletion backend/internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,10 @@ func TestApproveReviewPublishesToSubscriber(t *testing.T) {
}

// Fake SDK subscriber.
ch, dereg := hub.Register(sessID)
ch, dereg, err := hub.Register(sessID, "test-owner")
if err != nil {
t.Fatalf("Register: %v", err)
}
defer dereg()

resp := postJSON(t, srv, cookie, "/api/reviews/"+queueID+"/approve", map[string]any{})
Expand Down
1 change: 1 addition & 0 deletions backend/internal/ws/frames.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const schemaVersion = 2
// plus our application-specific 4xxx codes.
const (
closeProtocolError = 1002
closePolicyViolation = 1008
closeInternalServerErr = 1011
closeQuotaExhausted = 4003
)
Expand Down
20 changes: 19 additions & 1 deletion backend/internal/ws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,25 @@ func serve(ctx context.Context, conn *websocket.Conn, sess *session, st *store.S
// verdicts and HITL resolutions land here, drained by the writer
// goroutine. The LoginAck is written directly inside handleLogin
// (single goroutine, pre-register, no concurrency to serialise).
hubCh, deregister := hub.Register(sess.sessionID)
hubCh, deregister, err := hub.Register(sess.sessionID, sess.routeOwner())
if err != nil {
if errors.Is(err, ErrSessionOwnerConflict) {
slog.WarnContext(ctx, "ws.session_owner_conflict",
"session_id", sess.sessionID,
"api_key_id", sess.apiKey.ID,
"route_owner", sess.routeOwner(),
)
closeWith(conn, closePolicyViolation, "session_id already active for another owner")
return
}
slog.ErrorContext(ctx, "ws.session_register_failed",
"error", err,
"session_id", sess.sessionID,
"api_key_id", sess.apiKey.ID,
)
closeWith(conn, closeInternalServerErr, "internal error")
return
}
writerDone := make(chan struct{})
go func() {
defer close(writerDone)
Expand Down
102 changes: 102 additions & 0 deletions backend/internal/ws/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,97 @@ func TestAlertModeNoFanOut(t *testing.T) {
}
}

func TestSessionIDReuseDifferentOwnerDoesNotStealVerdicts(t *testing.T) {
db := openInMemoryDB(t)
t.Cleanup(func() { _ = db.Close() })

st := store.New(db)
plaintextKeyA := "adr_local_test_key_owner_a"
plaintextKeyB := "adr_local_test_key_owner_b"
insertAPIKeyWithProfile(t, db, sha256Hex(plaintextKeyA), "agent-profile-a")
insertAPIKeyWithProfile(t, db, sha256Hex(plaintextKeyB), "agent-profile-b")

if _, err := db.Exec(`UPDATE policies SET mode = 'block' WHERE id = 1`); err != nil {
t.Fatalf("set mode=block: %v", err)
}

mux := http.NewServeMux()
mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, &fakeClassifier{}, ws.NewHub(), nil, nil)))
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)

wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"
connA, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{
"Authorization": {"Bearer " + plaintextKeyA},
})
if err != nil {
t.Fatalf("dial client A: %v", err)
}
t.Cleanup(func() { _ = connA.Close() })

const sessionID = "shared-session-takeover-test"
if err := writeProto(connA, &bpb.ClientFrame{
Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{
SessionId: sessionID, SchemaVersion: 2,
}},
}); err != nil {
t.Fatalf("send login client A: %v", err)
}
if _, err := readServerFrame(connA); err != nil {
t.Fatalf("read login_ack client A: %v", err)
}

connB, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{
"Authorization": {"Bearer " + plaintextKeyB},
})
if err != nil {
t.Fatalf("dial client B: %v", err)
}
t.Cleanup(func() { _ = connB.Close() })
if err := writeProto(connB, &bpb.ClientFrame{
Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{
SessionId: sessionID, SchemaVersion: 2,
}},
}); err != nil {
t.Fatalf("send login client B: %v", err)
}
if _, err := readServerFrame(connB); err != nil {
t.Fatalf("read login_ack client B: %v", err)
}
if err := connB.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("set client B deadline: %v", err)
}
if _, _, err := connB.ReadMessage(); err == nil {
t.Fatal("expected conflicting client B to be closed")
} else if closeErr, ok := err.(*websocket.CloseError); !ok || closeErr.Code != websocket.ClosePolicyViolation {
t.Fatalf("client B close err = %v, want close code %d", err, websocket.ClosePolicyViolation)
}

eventID := uuid.NewString()
if err := writeProto(connA, &bpb.ClientFrame{
Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{
Events: []*bpb.PairedEvent{{
EventId: eventID, SessionId: sessionID,
PairType: bpb.PairType_PAIR_TYPE_TOOL,
Agent: &bpb.AgentContext{AgentId: "owner-a-agent"},
Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{
ToolName: "noop", ToolCallId: "tc-owner-a", Input: "{}", Output: "ok",
}},
}},
}},
}); err != nil {
t.Fatalf("send paired_batch client A: %v", err)
}

verdict, err := readServerFrame(connA)
if err != nil {
t.Fatalf("read verdict client A: %v", err)
}
if got := verdict.GetVerdict(); got == nil || got.EventId != eventID {
t.Fatalf("client A verdict = %+v, want event_id %q", got, eventID)
}
}

// TestRevokeKicksLiveWS asserts the security guarantee: an open WS
// authenticated with key X gets terminated within seconds when X is
// revoked, not at next-disconnect-whenever. Drives the path the REST
Expand Down Expand Up @@ -403,6 +494,17 @@ func insertAPIKey(t *testing.T, db *sql.DB, hashHex string) {
}
}

func insertAPIKeyWithProfile(t *testing.T, db *sql.DB, hashHex, agentProfileID string) {
t.Helper()
_, err := db.Exec(
`INSERT INTO api_keys (id, key_hash, prefix, label, agent_profile_id) VALUES (?, ?, ?, ?, ?)`,
uuid.NewString(), hashHex, "adr_local_te", "test", agentProfileID,
)
if err != nil {
t.Fatalf("insert api_keys: %v", err)
}
}

func writeProto(conn *websocket.Conn, msg proto.Message) error {
buf, err := proto.Marshal(msg)
if err != nil {
Expand Down
50 changes: 32 additions & 18 deletions backend/internal/ws/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,74 @@
package ws

import (
"errors"
"sync"

"google.golang.org/protobuf/proto"

pb "github.com/secureagentics/Adrian/backend/internal/proto"
)

// ErrSessionOwnerConflict is returned when another logical client owner
// already holds the subscriber slot for a session_id.
var ErrSessionOwnerConflict = errors.New("session_id already registered by another owner")

type subscriber struct {
owner string
ch chan []byte
}

// Hub is a process-local pub/sub keyed by session_id. The WS handler
// for each connected SDK registers a write channel; the REST review
// approve/reject path publishes a HITL-resolution Verdict frame to it.
//
// Single subscriber per session_id, re-Register replaces any prior
// channel (covers SDK reconnect during a hold). On disconnect the WS
// handler must call deregister to free the slot.
// Single subscriber per session_id. Re-register by the same server-derived
// owner replaces the prior channel (SDK reconnect / key rotation). A
// different owner claiming the same session_id is rejected so it cannot steal
// verdict or HITL routing.
type Hub struct {
mu sync.Mutex
subs map[string]chan []byte
subs map[string]subscriber
}

// NewHub returns a fresh hub.
func NewHub() *Hub {
return &Hub{subs: make(map[string]chan []byte)}
return &Hub{subs: make(map[string]subscriber)}
}

// Register adds a subscriber for sessionID and returns its write
// Register adds a subscriber for sessionID + owner and returns its write
// channel plus a deregister callback. The caller spawns a writer
// goroutine that drains the channel and calls conn.WriteMessage.
//
// If a prior subscriber exists for the same session_id, its channel
// is closed so its writer goroutine exits cleanly. (Concurrent
// connections under one session_id are not a normal case; an SDK
// reconnect during a hold is the realistic path.)
func (h *Hub) Register(sessionID string) (<-chan []byte, func()) {
// If a prior subscriber exists for the same session_id and owner, its channel
// is closed so its writer goroutine exits cleanly. If the existing subscriber
// belongs to another owner, registration fails and the old channel remains
// active.
func (h *Hub) Register(sessionID, owner string) (<-chan []byte, func(), error) {
h.mu.Lock()
defer h.mu.Unlock()

if old, ok := h.subs[sessionID]; ok {
close(old)
if old.owner != owner {
return nil, nil, ErrSessionOwnerConflict
}
close(old.ch)
}
ch := make(chan []byte, 8)
h.subs[sessionID] = ch
h.subs[sessionID] = subscriber{owner: owner, ch: ch}

deregister := func() {
h.mu.Lock()
defer h.mu.Unlock()
// Only delete + close when the entry is still ours; a later
// Register may have replaced it and already closed the prior
// channel.
if cur, ok := h.subs[sessionID]; ok && cur == ch {
if cur, ok := h.subs[sessionID]; ok && cur.ch == ch {
delete(h.subs, sessionID)
close(ch)
}
}
return ch, deregister
return ch, deregister, nil
}

// Publish marshals and pushes a frame to the subscriber for sessionID.
Expand All @@ -70,13 +84,13 @@ func (h *Hub) Publish(sessionID string, frame *pb.ServerFrame) bool {
return false
}
h.mu.Lock()
ch, ok := h.subs[sessionID]
h.mu.Unlock()
defer h.mu.Unlock()
sub, ok := h.subs[sessionID]
if !ok {
return false
}
select {
case ch <- buf:
case sub.ch <- buf:
return true
default:
return false
Expand Down
48 changes: 44 additions & 4 deletions backend/internal/ws/hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package ws

import (
"errors"
"testing"

"google.golang.org/protobuf/proto"
Expand All @@ -13,7 +14,10 @@ import (

func TestHubPublishDeliversToSubscriber(t *testing.T) {
h := NewHub()
ch, dereg := h.Register("sess-1")
ch, dereg, err := h.Register("sess-1", "owner-1")
if err != nil {
t.Fatalf("Register: %v", err)
}
defer dereg()

frame := &pb.ServerFrame{Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{
Expand Down Expand Up @@ -45,11 +49,17 @@ func TestHubPublishNoSubscriberReturnsFalse(t *testing.T) {

func TestHubReRegisterClosesPriorChannel(t *testing.T) {
h := NewHub()
first, _ := h.Register("sess-x")
first, _, err := h.Register("sess-x", "owner-1")
if err != nil {
t.Fatalf("first Register: %v", err)
}

// New register replaces the slot; the old channel must close so a
// writer goroutine reading from it exits cleanly.
second, dereg := h.Register("sess-x")
second, dereg, err := h.Register("sess-x", "owner-1")
if err != nil {
t.Fatalf("second Register: %v", err)
}
defer dereg()

if _, ok := <-first; ok {
Expand All @@ -67,9 +77,39 @@ func TestHubReRegisterClosesPriorChannel(t *testing.T) {
}
}

func TestHubRejectsReRegisterFromDifferentOwner(t *testing.T) {
h := NewHub()
first, dereg, err := h.Register("sess-x", "owner-1")
if err != nil {
t.Fatalf("first Register: %v", err)
}
defer dereg()

second, secondDereg, err := h.Register("sess-x", "owner-2")
if !errors.Is(err, ErrSessionOwnerConflict) {
t.Fatalf("Register err = %v, want ErrSessionOwnerConflict", err)
}
if second != nil || secondDereg != nil {
t.Fatal("conflicting Register returned a subscriber")
}

if !h.Publish("sess-x", &pb.ServerFrame{
Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{EventId: "ev-y"}},
}) {
t.Fatal("Publish to original subscriber should still succeed")
}
got := <-first
if len(got) == 0 {
t.Fatal("expected non-empty frame delivered to original subscriber")
}
}

func TestHubDeregisterRemovesEntry(t *testing.T) {
h := NewHub()
_, dereg := h.Register("sess-d")
_, dereg, err := h.Register("sess-d", "owner-1")
if err != nil {
t.Fatalf("Register: %v", err)
}
dereg()
if h.Publish("sess-d", &pb.ServerFrame{
Frame: &pb.ServerFrame_Verdict{Verdict: &pb.Verdict{}},
Expand Down
13 changes: 13 additions & 0 deletions backend/internal/ws/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,16 @@ func (s *session) agentProfileID() *string {
}
return s.apiKey.AgentProfileID
}

// routeOwner returns the server-authenticated logical owner for hub routing.
// Agent-profile keys may rotate, so profile ownership is preferred over raw
// key ID to preserve reconnect continuity. Unprofiled keys fall back to key ID.
func (s *session) routeOwner() string {
if s.apiKey == nil {
return ""
}
if s.apiKey.AgentProfileID != nil && *s.apiKey.AgentProfileID != "" {
return "agent_profile:" + *s.apiKey.AgentProfileID
}
return "api_key:" + s.apiKey.ID
}