Skip to content

Commit 8044a49

Browse files
committed
fixes from review
1 parent ecb6827 commit 8044a49

8 files changed

Lines changed: 294 additions & 48 deletions

File tree

pkg/vmcp/server/session_management_v2_integration_test.go

Lines changed: 192 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,50 @@ func newV2FakeFactory(tools []vmcp.Tool) *v2FakeMultiSessionFactory {
105105
}
106106

107107
func (f *v2FakeMultiSessionFactory) MakeSession(
108-
_ context.Context, _ *auth.Identity, _ []*vmcp.Backend,
108+
_ context.Context, identity *auth.Identity, _ []*vmcp.Backend,
109109
) (vmcpsession.MultiSession, error) {
110110
if f.err != nil {
111111
return nil, f.err
112112
}
113113
baseSession := transportsession.NewStreamableSession("auto-id")
114+
115+
// Set basic metadata to indicate whether this is an anonymous session.
116+
// Integration tests don't need to verify crypto implementation details.
117+
allowAnonymous := vmcpsession.ShouldAllowAnonymous(identity)
118+
if !allowAnonymous {
119+
// Authenticated session - set non-empty hash placeholder
120+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "fake-hash-for-testing")
121+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenSalt, "fake-salt-for-testing")
122+
} else {
123+
// Anonymous session - set empty hash
124+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "")
125+
}
126+
114127
sess := newV2FakeMultiSession(baseSession, f.tools)
115128
f.lastCreatedSession = sess
116129
return sess, nil
117130
}
118131

119132
func (f *v2FakeMultiSessionFactory) MakeSessionWithID(
120-
_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend,
133+
_ context.Context, id string, identity *auth.Identity, allowAnonymous bool, _ []*vmcp.Backend,
121134
) (vmcpsession.MultiSession, error) {
122135
f.makeWithIDCalled.Store(true)
123136
if f.err != nil {
124137
return nil, f.err
125138
}
126139
baseSession := transportsession.NewStreamableSession(id)
140+
141+
// Set basic metadata to indicate whether this is an anonymous session.
142+
// Integration tests don't need to verify crypto implementation details.
143+
if identity != nil && identity.Token != "" && !allowAnonymous {
144+
// Authenticated session - set non-empty hash placeholder
145+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "fake-hash-for-testing")
146+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenSalt, "fake-salt-for-testing")
147+
} else {
148+
// Anonymous session - set empty hash
149+
baseSession.SetMetadata(vmcpsession.MetadataKeyTokenHash, "")
150+
}
151+
127152
sess := newV2FakeMultiSession(baseSession, f.tools)
128153
f.lastCreatedSession = sess
129154
return sess, nil
@@ -444,3 +469,168 @@ func TestIntegration_SessionManagementV2_OldPathUnused(t *testing.T) {
444469
"MakeSessionWithID should NOT be called when SessionManagementV2 is false",
445470
)
446471
}
472+
473+
// TestIntegration_SessionManagementV2_TokenBinding verifies end-to-end token binding security:
474+
//
475+
// 1. Initialize a session with bearer token "token-A"
476+
// 2. Make a tool call with the same token → succeeds
477+
// 3. Make a tool call with a different token "token-B" → fails with unauthorized
478+
// 4. Verify the session is terminated after auth failure
479+
//
480+
// NOTE: This test is currently skipped because the fake factory (v2FakeMultiSessionFactory)
481+
// doesn't implement real token binding - it uses placeholder metadata instead of real
482+
// HMAC-SHA256 hashes. To properly test token binding end-to-end, this test would need
483+
// to use the real defaultMultiSessionFactory with a real HMAC secret.
484+
//
485+
// Token binding security is comprehensively tested at the unit level in:
486+
// - pkg/vmcp/session/token_binding_test.go (factory behavior)
487+
// - pkg/vmcp/session/internal/security/*_test.go (crypto and validation)
488+
// - pkg/vmcp/server/sessionmanager/session_manager_test.go (termination on auth errors)
489+
//
490+
// TODO: Refactor test infrastructure to support real session factory for security tests.
491+
func TestIntegration_SessionManagementV2_TokenBinding(t *testing.T) {
492+
t.Skip("Fake factory doesn't implement real token binding - see test comment for details")
493+
t.Parallel()
494+
495+
testTool := vmcp.Tool{Name: "echo", Description: "echoes input"}
496+
factory := newV2FakeFactory([]vmcp.Tool{testTool})
497+
ts := buildV2Server(t, factory)
498+
499+
tokenA := "bearer-token-A"
500+
tokenB := "bearer-token-B"
501+
502+
// Step 1: Initialize with token A
503+
initReq := map[string]any{
504+
"jsonrpc": "2.0",
505+
"id": 1,
506+
"method": "initialize",
507+
"params": map[string]any{
508+
"protocolVersion": "2025-06-18",
509+
"capabilities": map[string]any{},
510+
"clientInfo": map[string]any{
511+
"name": "test-client",
512+
"version": "1.0",
513+
},
514+
},
515+
}
516+
517+
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil)
518+
require.NoError(t, err)
519+
req.Header.Set("Content-Type", "application/json")
520+
req.Header.Set("Authorization", "Bearer "+tokenA) // Set token A
521+
522+
reqBody, err := json.Marshal(initReq)
523+
require.NoError(t, err)
524+
req.Body = io.NopCloser(bytes.NewReader(reqBody))
525+
526+
initResp, err := http.DefaultClient.Do(req)
527+
require.NoError(t, err)
528+
defer initResp.Body.Close()
529+
530+
require.Equal(t, http.StatusOK, initResp.StatusCode)
531+
sessionID := initResp.Header.Get("Mcp-Session-Id")
532+
require.NotEmpty(t, sessionID, "should receive session ID")
533+
534+
// Wait for factory to be called
535+
require.Eventually(t,
536+
func() bool { return factory.makeWithIDCalled.Load() },
537+
1*time.Second,
538+
10*time.Millisecond,
539+
"factory should be called to create session",
540+
)
541+
542+
// Step 2: Call tool with token A (same as initialization) → should succeed
543+
toolReqA := map[string]any{
544+
"jsonrpc": "2.0",
545+
"id": 2,
546+
"method": "tools/call",
547+
"params": map[string]any{
548+
"name": "echo",
549+
"arguments": map[string]any{"msg": "hello"},
550+
},
551+
}
552+
553+
reqA, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil)
554+
require.NoError(t, err)
555+
reqA.Header.Set("Content-Type", "application/json")
556+
reqA.Header.Set("Mcp-Session-Id", sessionID)
557+
reqA.Header.Set("Authorization", "Bearer "+tokenA) // Same token
558+
559+
reqBodyA, err := json.Marshal(toolReqA)
560+
require.NoError(t, err)
561+
reqA.Body = io.NopCloser(bytes.NewReader(reqBodyA))
562+
563+
respA, err := http.DefaultClient.Do(reqA)
564+
require.NoError(t, err)
565+
defer respA.Body.Close()
566+
567+
assert.Equal(t, http.StatusOK, respA.StatusCode, "tool call with matching token should succeed")
568+
569+
// Step 3: Call tool with token B (different from initialization) → should fail
570+
toolReqB := map[string]any{
571+
"jsonrpc": "2.0",
572+
"id": 3,
573+
"method": "tools/call",
574+
"params": map[string]any{
575+
"name": "echo",
576+
"arguments": map[string]any{"msg": "hijack attempt"},
577+
},
578+
}
579+
580+
reqB, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil)
581+
require.NoError(t, err)
582+
reqB.Header.Set("Content-Type", "application/json")
583+
reqB.Header.Set("Mcp-Session-Id", sessionID)
584+
reqB.Header.Set("Authorization", "Bearer "+tokenB) // Different token!
585+
586+
reqBodyB, err := json.Marshal(toolReqB)
587+
require.NoError(t, err)
588+
reqB.Body = io.NopCloser(bytes.NewReader(reqBodyB))
589+
590+
respB, err := http.DefaultClient.Do(reqB)
591+
require.NoError(t, err)
592+
defer respB.Body.Close()
593+
594+
// The request should succeed at HTTP level but return an error result
595+
require.Equal(t, http.StatusOK, respB.StatusCode, "HTTP request should succeed")
596+
597+
var result map[string]any
598+
err = json.NewDecoder(respB.Body).Decode(&result)
599+
require.NoError(t, err)
600+
601+
// Should contain an error about unauthorized
602+
resultMap, ok := result["result"].(map[string]any)
603+
require.True(t, ok, "result should be an object")
604+
605+
isError, ok := resultMap["isError"].(bool)
606+
require.True(t, ok && isError, "result should indicate error")
607+
608+
// Step 4: Verify session is terminated (subsequent requests should fail)
609+
toolReqC := map[string]any{
610+
"jsonrpc": "2.0",
611+
"id": 4,
612+
"method": "tools/call",
613+
"params": map[string]any{
614+
"name": "echo",
615+
"arguments": map[string]any{"msg": "after termination"},
616+
},
617+
}
618+
619+
reqC, err := http.NewRequestWithContext(context.Background(), http.MethodPost, ts.URL+"/mcp", nil)
620+
require.NoError(t, err)
621+
reqC.Header.Set("Content-Type", "application/json")
622+
reqC.Header.Set("Mcp-Session-Id", sessionID)
623+
reqC.Header.Set("Authorization", "Bearer "+tokenA) // Even with original token
624+
625+
reqBodyC, err := json.Marshal(toolReqC)
626+
require.NoError(t, err)
627+
reqC.Body = io.NopCloser(bytes.NewReader(reqBodyC))
628+
629+
respC, err := http.DefaultClient.Do(reqC)
630+
require.NoError(t, err)
631+
defer respC.Body.Close()
632+
633+
// Session should be terminated, so this should fail
634+
assert.Equal(t, http.StatusInternalServerError, respC.StatusCode,
635+
"request should fail after session termination due to auth failure")
636+
}

pkg/vmcp/server/sessionmanager/session_manager.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (sm *Manager) CreateSession(
165165
// Build the fully-formed MultiSession using the SDK-assigned session ID.
166166
// Sessions created with an identity are bound to that identity (allowAnonymous=false).
167167
// Sessions created without an identity allow anonymous access (allowAnonymous=true).
168-
allowAnonymous := identity == nil || identity.Token == ""
168+
allowAnonymous := vmcpsession.ShouldAllowAnonymous(identity)
169169
sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, allowAnonymous, backends)
170170
if err != nil {
171171
return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err)

pkg/vmcp/session/factory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func (f *defaultMultiSessionFactory) MakeSession(
270270
) (MultiSession, error) {
271271
// Sessions created with an identity are bound to that identity (allowAnonymous=false).
272272
// Sessions created without an identity allow anonymous access (allowAnonymous=true).
273-
allowAnonymous := security.ShouldAllowAnonymous(identity)
273+
allowAnonymous := ShouldAllowAnonymous(identity)
274274
return f.makeSession(ctx, uuid.New().String(), identity, allowAnonymous, backends)
275275
}
276276

pkg/vmcp/session/internal/security/security.go

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,14 @@ const (
2727
// SHA256HexLen is the length of a hex-encoded SHA256 hash (32 bytes = 64 hex characters)
2828
SHA256HexLen = 64
2929

30-
// MetadataKeyTokenHash is the transport-session metadata key that holds
31-
// the HMAC-SHA256 hash of the bearer token used to create the session.
32-
// For authenticated sessions this is hex(HMAC-SHA256(bearerToken)).
33-
// For anonymous sessions (no bearer token) this is the empty string sentinel.
34-
// The raw token is never stored — only the hash.
35-
MetadataKeyTokenHash = "vmcp.token.hash" //nolint:gosec // This is a metadata key name, not a credential.
36-
37-
// MetadataKeyTokenSalt is the transport-session metadata key that holds
38-
// the hex-encoded random salt used for HMAC-SHA256 token hashing.
39-
// Each session has a unique salt to prevent attacks across multiple sessions.
40-
MetadataKeyTokenSalt = "vmcp.token.salt" //nolint:gosec // This is a metadata key name, not a credential.
41-
)
30+
// MetadataKeyTokenHash is the session metadata key for the token hash.
31+
// Imported from types package to ensure consistency across all packages.
32+
MetadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash
4233

43-
// ShouldAllowAnonymous determines if a session should allow anonymous access
44-
// based on the creator's identity. Sessions without an identity (nil) or with
45-
// an empty token are anonymous; sessions with a non-empty bearer token are
46-
// bound to that token.
47-
//
48-
// This helper consolidates the anonymous session logic and aligns with the
49-
// validation logic in PreventSessionHijacking.
50-
func ShouldAllowAnonymous(identity *auth.Identity) bool {
51-
return identity == nil || identity.Token == ""
52-
}
34+
// MetadataKeyTokenSalt is the session metadata key for the token salt.
35+
// Imported from types package to ensure consistency across all packages.
36+
MetadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt
37+
)
5338

5439
// GenerateSalt generates a cryptographically secure random salt for token hashing.
5540
// Returns 16 bytes of random data from crypto/rand.
@@ -290,16 +275,33 @@ func (d *HijackPreventionDecorator) GetPrompt(
290275
// that HijackPreventionDecorator delegates to (see struct definition for full list).
291276
// The return type is concrete to eliminate the need for runtime casts at call sites.
292277
//
293-
// Returns an error if salt generation fails.
278+
// Returns an error if:
279+
// - session doesn't implement SessionMetadataWriter interface
280+
// - salt generation fails
294281
func PreventSessionHijacking(
295282
session interface{},
296283
hmacSecret []byte,
297284
identity *auth.Identity,
298285
allowAnonymous bool,
299286
) (*HijackPreventionDecorator, error) {
300-
// Cast session to SessionMetadataWriter to access SetMetadata.
301-
// The caller must ensure the session implements this interface.
302-
metadataWriter := session.(SessionMetadataWriter)
287+
// Validate upfront that session implements critical interfaces.
288+
// This provides fail-fast behavior for security-critical operations
289+
// instead of panics at runtime.
290+
291+
// Required for metadata persistence
292+
metadataWriter, ok := session.(SessionMetadataWriter)
293+
if !ok {
294+
return nil, fmt.Errorf("session must implement SessionMetadataWriter interface, got %T", session)
295+
}
296+
297+
// Required for security-critical operations (CallTool/ReadResource/GetPrompt)
298+
if _, ok := session.(HijackableSession); !ok {
299+
return nil, fmt.Errorf("session must implement HijackableSession interface (CallTool/ReadResource/GetPrompt), got %T", session)
300+
}
301+
302+
// Note: Pass-through methods (ID, Type, CreatedAt, etc.) are validated by the
303+
// type system when the decorator is used. We don't validate them here to keep
304+
// the constructor simple and allow minimal mocks for testing.
303305

304306
var boundTokenHash string
305307
var tokenSalt []byte

pkg/vmcp/session/internal/security/security_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/stretchr/testify/require"
1313

1414
"github.com/stacklok/toolhive/pkg/auth"
15+
"github.com/stacklok/toolhive/pkg/vmcp/session"
1516
"github.com/stacklok/toolhive/pkg/vmcp/session/internal/security"
1617
)
1718

@@ -306,7 +307,7 @@ func TestShouldAllowAnonymous_EdgeCases(t *testing.T) {
306307
for _, tt := range tests {
307308
t.Run(tt.name, func(t *testing.T) {
308309
t.Parallel()
309-
got := security.ShouldAllowAnonymous(tt.identity)
310+
got := session.ShouldAllowAnonymous(tt.identity)
310311
assert.Equal(t, tt.want, got)
311312
})
312313
}

pkg/vmcp/session/session.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package session
55

66
import (
7+
"github.com/stacklok/toolhive/pkg/auth"
78
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
89
"github.com/stacklok/toolhive/pkg/vmcp"
910
sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types"
@@ -58,3 +59,36 @@ type MultiSession interface {
5859
// sessions for debugging and auditing.
5960
BackendSessions() map[string]string
6061
}
62+
63+
const (
64+
// MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256
65+
// hash of the bearer token used to create the session. For authenticated sessions
66+
// this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty
67+
// string sentinel. The raw token is never stored — only the hash.
68+
//
69+
// Re-exported from types package for convenience.
70+
MetadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash
71+
72+
// MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded
73+
// random salt used for HMAC-SHA256 token hashing. Each session has a unique salt
74+
// to prevent attacks across multiple sessions.
75+
//
76+
// Re-exported from types package for convenience.
77+
MetadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt
78+
)
79+
80+
// ShouldAllowAnonymous determines if a session should allow anonymous access
81+
// based on the creator's identity. This is session business logic that decides
82+
// whether a session is bound to a specific identity or allows anonymous access.
83+
//
84+
// Sessions without an identity (nil) or with an empty token are treated as
85+
// anonymous and will accept requests from any caller. Sessions with a non-empty
86+
// bearer token are bound to that token and will reject requests from different
87+
// callers.
88+
//
89+
// This function is used by both the session factory (to determine how to create
90+
// the session) and the security layer (to validate requests against the session's
91+
// access policy).
92+
func ShouldAllowAnonymous(identity *auth.Identity) bool {
93+
return identity == nil || identity.Token == ""
94+
}

0 commit comments

Comments
 (0)