Skip to content

Commit e051547

Browse files
committed
added tests, made oauth mcp servers workspace scoped
1 parent be03225 commit e051547

21 files changed

Lines changed: 477 additions & 91 deletions

File tree

apps/sim/app/api/mcp/oauth/callback/route.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,23 @@ export const GET = withRouteHandler(async (request: NextRequest) => {
9494

9595
const preregistered = await loadPreregisteredClient(server.id)
9696
const provider = new SimMcpOauthProvider({ row, preregistered })
97-
const result = await mcpAuth(provider, {
98-
serverUrl: server.url,
99-
authorizationCode: code,
100-
})
101-
102-
await clearVerifier(row.id)
97+
let result: Awaited<ReturnType<typeof mcpAuth>>
98+
try {
99+
result = await mcpAuth(provider, {
100+
serverUrl: server.url,
101+
authorizationCode: code,
102+
})
103+
} finally {
104+
await clearVerifier(row.id)
105+
}
103106

104107
if (result !== 'AUTHORIZED') {
105108
return htmlClose('Authorization did not complete.', false, server.id)
106109
}
107110

108111
try {
109112
await mcpService.clearCache(server.workspaceId)
110-
await mcpService.discoverServerTools(row.userId, server.id, server.workspaceId)
113+
await mcpService.discoverServerTools(session.user.id, server.id, server.workspaceId)
111114
} catch (e) {
112115
logger.warn('Post-auth tools refresh failed', toError(e).message)
113116
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/**
2+
* @vitest-environment node
3+
*/
4+
import {
5+
dbChainMock,
6+
dbChainMockFns,
7+
hybridAuthMock,
8+
hybridAuthMockFns,
9+
permissionsMock,
10+
permissionsMockFns,
11+
resetDbChainMock,
12+
schemaMock,
13+
} from '@sim/testing'
14+
import { NextRequest } from 'next/server'
15+
import { beforeEach, describe, expect, it, vi } from 'vitest'
16+
17+
const {
18+
mockMcpAuth,
19+
mockGetOrCreateOauthRow,
20+
mockLoadPreregisteredClient,
21+
mockSetOauthRowUser,
22+
MockMcpOauthRedirectRequired,
23+
} = vi.hoisted(() => ({
24+
mockMcpAuth: vi.fn(),
25+
mockGetOrCreateOauthRow: vi.fn(),
26+
mockLoadPreregisteredClient: vi.fn(),
27+
mockSetOauthRowUser: vi.fn(),
28+
MockMcpOauthRedirectRequired: class MockMcpOauthRedirectRequired extends Error {
29+
constructor(public readonly authorizationUrl: string) {
30+
super('redirect required')
31+
}
32+
},
33+
}))
34+
35+
vi.mock('@sim/db', () => dbChainMock)
36+
vi.mock('@sim/db/schema', () => schemaMock)
37+
vi.mock('drizzle-orm', () => ({
38+
and: vi.fn(),
39+
eq: vi.fn(),
40+
isNull: vi.fn(),
41+
}))
42+
vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
43+
auth: mockMcpAuth,
44+
}))
45+
vi.mock('@/lib/auth/hybrid', () => hybridAuthMock)
46+
vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock)
47+
vi.mock('@/lib/mcp/oauth', () => ({
48+
getOrCreateOauthRow: mockGetOrCreateOauthRow,
49+
loadPreregisteredClient: mockLoadPreregisteredClient,
50+
McpOauthRedirectRequired: MockMcpOauthRedirectRequired,
51+
setOauthRowUser: mockSetOauthRowUser,
52+
SimMcpOauthProvider: vi.fn().mockImplementation((value) => value),
53+
}))
54+
55+
import { GET } from './route'
56+
57+
describe('MCP OAuth start route', () => {
58+
beforeEach(() => {
59+
vi.clearAllMocks()
60+
resetDbChainMock()
61+
hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({
62+
success: true,
63+
userId: 'user-2',
64+
userName: 'User Two',
65+
userEmail: 'user2@example.com',
66+
authType: 'session',
67+
})
68+
permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValue('write')
69+
dbChainMockFns.limit.mockResolvedValue([
70+
{
71+
id: 'server-1',
72+
name: 'Exa',
73+
url: 'https://mcp.exa.ai/mcp',
74+
workspaceId: 'workspace-1',
75+
authType: 'oauth',
76+
deletedAt: null,
77+
},
78+
])
79+
mockGetOrCreateOauthRow.mockResolvedValue({
80+
id: 'oauth-row-1',
81+
mcpServerId: 'server-1',
82+
userId: 'user-1',
83+
workspaceId: 'workspace-1',
84+
clientInformation: null,
85+
tokens: null,
86+
codeVerifier: null,
87+
state: null,
88+
updatedAt: new Date(),
89+
})
90+
mockLoadPreregisteredClient.mockResolvedValue(undefined)
91+
mockMcpAuth.mockRejectedValue(new MockMcpOauthRedirectRequired('https://mcp.exa.ai/authorize'))
92+
})
93+
94+
it('requires workspace write permission via MCP auth middleware', async () => {
95+
const request = new NextRequest(
96+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
97+
)
98+
99+
await GET(request)
100+
101+
expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith(
102+
'user-2',
103+
'workspace',
104+
'workspace-1'
105+
)
106+
})
107+
108+
it('uses a workspace-scoped OAuth row and stamps the latest authorizing user', async () => {
109+
const request = new NextRequest(
110+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
111+
)
112+
113+
const response = await GET(request)
114+
const body = await response.json()
115+
116+
expect(response.status).toBe(200)
117+
expect(body).toEqual({
118+
status: 'redirect',
119+
authorizationUrl: 'https://mcp.exa.ai/authorize',
120+
})
121+
expect(mockGetOrCreateOauthRow).toHaveBeenCalledWith({
122+
mcpServerId: 'server-1',
123+
userId: 'user-2',
124+
workspaceId: 'workspace-1',
125+
})
126+
expect(mockSetOauthRowUser).toHaveBeenCalledWith('oauth-row-1', 'user-2')
127+
})
128+
129+
it('rejects a second user starting OAuth while another authorization is active', async () => {
130+
mockGetOrCreateOauthRow.mockResolvedValueOnce({
131+
id: 'oauth-row-1',
132+
mcpServerId: 'server-1',
133+
userId: 'user-1',
134+
workspaceId: 'workspace-1',
135+
clientInformation: null,
136+
tokens: null,
137+
codeVerifier: null,
138+
state: 'hashed-active-state',
139+
updatedAt: new Date(),
140+
})
141+
const request = new NextRequest(
142+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
143+
)
144+
145+
const response = await GET(request)
146+
const body = await response.json()
147+
148+
expect(response.status).toBe(409)
149+
expect(body.error).toBe('OAuth authorization already in progress for this server')
150+
expect(mockMcpAuth).not.toHaveBeenCalled()
151+
})
152+
})

apps/sim/app/api/mcp/oauth/start/route.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ import {
1515
loadPreregisteredClient,
1616
McpOauthRedirectRequired,
1717
SimMcpOauthProvider,
18+
setOauthRowUser,
1819
} from '@/lib/mcp/oauth'
1920
import { createMcpErrorResponse } from '@/lib/mcp/utils'
2021

2122
const logger = createLogger('McpOauthStartAPI')
23+
const OAUTH_START_TTL_MS = 10 * 60 * 1000
2224

2325
export const dynamic = 'force-dynamic'
2426

@@ -64,6 +66,18 @@ export const GET = withRouteHandler(
6466
userId,
6567
workspaceId,
6668
})
69+
const hasActiveFlow = !!row.state && row.updatedAt.getTime() > Date.now() - OAUTH_START_TTL_MS
70+
if (hasActiveFlow && row.userId && row.userId !== userId) {
71+
return createMcpErrorResponse(
72+
new Error('OAuth authorization already in progress'),
73+
'OAuth authorization already in progress for this server',
74+
409
75+
)
76+
}
77+
if (row.userId !== userId) {
78+
await setOauthRowUser(row.id, userId)
79+
row.userId = userId
80+
}
6781
const preregistered = await loadPreregisteredClient(server.id)
6882
const provider = new SimMcpOauthProvider({ row, preregistered })
6983

apps/sim/app/api/mcp/tools/discover/route.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
12
import { createLogger } from '@sim/logger'
23
import type { NextRequest } from 'next/server'
34
import { mcpToolDiscoveryQuerySchema, refreshMcpToolsBodySchema } from '@/lib/api/contracts/mcp'
45
import { validationErrorResponse } from '@/lib/api/server'
56
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
67
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
78
import { mcpService } from '@/lib/mcp/service'
8-
import type { McpToolDiscoveryResponse } from '@/lib/mcp/types'
9+
import { McpOauthAuthorizationRequiredError, type McpToolDiscoveryResponse } from '@/lib/mcp/types'
910
import { categorizeError, createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
1011

1112
const logger = createLogger('McpToolDiscoveryAPI')
@@ -46,6 +47,12 @@ export const GET = withRouteHandler(
4647
)
4748
return createMcpSuccessResponse(responseData)
4849
} catch (error) {
50+
if (
51+
error instanceof McpOauthAuthorizationRequiredError ||
52+
error instanceof UnauthorizedError
53+
) {
54+
return createMcpErrorResponse(error, 'OAuth re-authorization required', 401)
55+
}
4956
logger.error(`[${requestId}] Error discovering MCP tools:`, error)
5057
const { message, status } = categorizeError(error)
5158
return createMcpErrorResponse(new Error(message), 'Failed to discover MCP tools', status)
@@ -100,6 +107,12 @@ export const POST = withRouteHandler(
100107
},
101108
})
102109
} catch (error) {
110+
if (
111+
error instanceof McpOauthAuthorizationRequiredError ||
112+
error instanceof UnauthorizedError
113+
) {
114+
return createMcpErrorResponse(error, 'OAuth re-authorization required', 401)
115+
}
103116
logger.error(`[${requestId}] Error refreshing tool discovery:`, error)
104117
const { message, status } = categorizeError(error)
105118
return createMcpErrorResponse(new Error(message), 'Failed to refresh tool discovery', status)

apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/form-field/form-field.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ interface FormFieldProps {
99
export function FormField({ label, children, optional }: FormFieldProps) {
1010
return (
1111
<div className='flex items-center justify-between gap-3'>
12-
<Label className='w-[100px] shrink-0 font-medium text-[var(--text-secondary)] text-sm'>
12+
<Label className='w-[116px] shrink-0 font-medium text-[var(--text-secondary)] text-sm'>
1313
{label}
1414
{optional && (
1515
<span className='ml-1 font-normal text-[var(--text-muted)] text-xs'>(optional)</span>

apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/mcp-server-form-modal/mcp-server-form-modal.tsx

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
ModalContent,
1313
ModalFooter,
1414
ModalHeader,
15+
SecretInput,
1516
Textarea,
1617
} from '@/components/emcn'
1718
import { cn } from '@/lib/core/utils/cn'
@@ -617,9 +618,9 @@ export function McpServerFormModal({
617618

618619
return (
619620
<Modal open={open} onOpenChange={onOpenChange}>
620-
<ModalContent>
621-
<ModalHeader>{title}</ModalHeader>
622-
<ModalBody>
621+
<ModalContent size='lg' className='max-h-[82vh]'>
622+
<ModalHeader className='border-[var(--border)] border-b pb-3'>{title}</ModalHeader>
623+
<ModalBody className='min-h-0 px-4 pt-4 pb-4'>
623624
{formMode === 'json' ? (
624625
<div className='flex flex-col gap-2'>
625626
<Textarea
@@ -631,12 +632,28 @@ export function McpServerFormModal({
631632
if (testResult) clearTestResult()
632633
if (submitError) setSubmitError(null)
633634
}}
634-
className='min-h-[200px] font-mono text-small'
635+
className='min-h-[280px] font-mono text-small leading-5'
635636
/>
636637
{jsonError && <p className='text-[var(--text-error)] text-caption'>{jsonError}</p>}
637638
</div>
638639
) : (
639-
<div className='flex flex-col gap-2'>
640+
<div className='flex flex-col gap-3'>
641+
<input
642+
type='text'
643+
name='fakeusernameremembered'
644+
autoComplete='username'
645+
style={{ position: 'absolute', left: '-9999px', opacity: 0, pointerEvents: 'none' }}
646+
tabIndex={-1}
647+
readOnly
648+
/>
649+
<input
650+
type='password'
651+
name='fakepasswordremembered'
652+
autoComplete='current-password'
653+
style={{ position: 'absolute', left: '-9999px', opacity: 0, pointerEvents: 'none' }}
654+
tabIndex={-1}
655+
readOnly
656+
/>
640657
<FormField label='Server Name'>
641658
<EmcnInput
642659
placeholder='e.g., My MCP Server'
@@ -675,8 +692,7 @@ export function McpServerFormModal({
675692
)}
676693
</FormField>
677694

678-
<div className='flex flex-col gap-2'>
679-
<span className='font-medium text-[var(--text-secondary)] text-small'>Headers</span>
695+
<FormField label='Headers'>
680696
<div className='flex max-h-[140px] flex-col gap-2 overflow-y-auto'>
681697
{(formData.headers || []).map((header, index) => (
682698
<HeaderRow
@@ -698,7 +714,7 @@ export function McpServerFormModal({
698714
/>
699715
))}
700716
</div>
701-
</div>
717+
</FormField>
702718

703719
<Button
704720
type='button'
@@ -715,10 +731,16 @@ export function McpServerFormModal({
715731
</Button>
716732
{showAdvanced && (
717733
<div className='flex flex-col gap-2'>
718-
<FormField label='OAuth Client ID (optional)'>
734+
<FormField label='Client ID'>
719735
<EmcnInput
720-
placeholder='Pre-registered client ID'
736+
placeholder='OAuth Client ID (optional)'
721737
value={formData.oauthClientId || ''}
738+
name='mcp_oauth_client_id'
739+
autoComplete='off'
740+
autoCorrect='off'
741+
autoCapitalize='off'
742+
data-lpignore='true'
743+
data-form-type='other'
722744
onChange={(e) => {
723745
if (testResult) clearTestResult()
724746
if (submitError) setSubmitError(null)
@@ -727,16 +749,21 @@ export function McpServerFormModal({
727749
className='h-9'
728750
/>
729751
</FormField>
730-
<FormField label='OAuth Client Secret (optional)'>
731-
<EmcnInput
732-
type='password'
733-
placeholder='Pre-registered client secret'
752+
<FormField label='Client Secret'>
753+
<SecretInput
754+
placeholder='OAuth Client Secret (optional)'
734755
value={formData.oauthClientSecret || ''}
735-
onChange={(e) => {
756+
name='mcp_oauth_client_secret'
757+
autoComplete='new-password'
758+
autoCorrect='off'
759+
autoCapitalize='off'
760+
data-lpignore='true'
761+
data-form-type='other'
762+
onChange={(value) => {
736763
if (testResult) clearTestResult()
737764
if (submitError) setSubmitError(null)
738765
setOauthClientSecretTouched(true)
739-
setFormData((prev) => ({ ...prev, oauthClientSecret: e.target.value }))
766+
setFormData((prev) => ({ ...prev, oauthClientSecret: value }))
740767
}}
741768
className='h-9'
742769
/>
@@ -749,9 +776,9 @@ export function McpServerFormModal({
749776
</div>
750777
)}
751778
</ModalBody>
752-
<ModalFooter>
779+
<ModalFooter className='flex-col items-stretch gap-2'>
753780
{submitError && (
754-
<p className='mb-2 w-full text-[var(--text-error)] text-small'>{submitError}</p>
781+
<p className='w-full text-[var(--text-error)] text-small'>{submitError}</p>
755782
)}
756783
<div className='flex w-full items-center justify-between'>
757784
<div className='flex items-center gap-2'>

0 commit comments

Comments
 (0)