diff --git a/apps/web/src/app/discord/webhook/route.ts b/apps/web/src/app/discord/webhook/route.ts index 7b80b678f0..e901eb1db8 100644 --- a/apps/web/src/app/discord/webhook/route.ts +++ b/apps/web/src/app/discord/webhook/route.ts @@ -17,6 +17,10 @@ import { truncateForDiscord, } from '@/lib/discord-bot/discord-utils'; import { getDevUserSuffix } from '@/lib/slack-bot/dev-user-info'; +import { + parseForwardedGatewayMessageEvent, + type ForwardedGatewayEvent, +} from '@/lib/discord-bot/forwarded-gateway-event'; export const maxDuration = 800; @@ -26,24 +30,6 @@ export const maxDuration = 800; const PROCESSING_EMOJI = '\u23f3'; // hourglass const COMPLETE_EMOJI = '\u2705'; // white check mark -/** - * Forwarded Gateway event shape (from the Gateway listener) - */ -type ForwardedGatewayEvent = { - type: string; - timestamp: number; - botUserId: string | null; - data: { - id: string; - content: string; - channel_id: string; - guild_id: string; - author: { id: string; username: string; bot?: boolean }; - mentions?: Array<{ id: string }>; - message_reference?: { message_id: string }; - }; -}; - /** * Discord webhook handler. * Handles: @@ -60,10 +46,19 @@ export async function POST(request: NextRequest) { return new NextResponse('Unauthorized', { status: 401 }); } - const event = JSON.parse(rawBody) as ForwardedGatewayEvent; - if (event.type === 'GATEWAY_MESSAGE_CREATE') { - after(processGatewayMessage(event)); + let parsedBody: unknown; + try { + parsedBody = JSON.parse(rawBody); + } catch { + return new NextResponse('Invalid gateway event', { status: 400 }); } + + const event = parseForwardedGatewayMessageEvent(parsedBody); + if (!event) { + return new NextResponse('Invalid gateway event', { status: 400 }); + } + + after(processGatewayMessage(event)); return new NextResponse(null, { status: 200 }); } diff --git a/apps/web/src/lib/discord-bot/discord-channel-context.test.ts b/apps/web/src/lib/discord-bot/discord-channel-context.test.ts new file mode 100644 index 0000000000..9d8cf8e229 --- /dev/null +++ b/apps/web/src/lib/discord-bot/discord-channel-context.test.ts @@ -0,0 +1,76 @@ +jest.mock('@/lib/config.server', () => ({ + DISCORD_BOT_TOKEN: 'bot-token', +})); + +jest.mock('@sentry/nextjs', () => ({ + captureException: jest.fn(), +})); + +import { getDiscordConversationContext } from './discord-channel-context'; + +describe('getDiscordConversationContext', () => { + beforeEach(() => { + jest.restoreAllMocks(); + }); + + it('does not fetch Discord API data when the channel ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + const result = await getDiscordConversationContext({ + channelId: '../../users/@me', + guildId: '111111111111111111', + userId: '222222222222222222', + messageId: '333333333333333333', + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(result.channel).toBeNull(); + expect(result.recentMessages).toEqual([]); + expect(result.errors).toEqual(['Invalid Discord channel ID']); + }); + + it('uses fixed-origin Discord API URLs for valid context fetches', async () => { + const fetchSpy = jest + .spyOn(globalThis, 'fetch') + .mockResolvedValueOnce( + new Response(JSON.stringify({ id: '111111111111111111', type: 0, name: 'general' }), { + status: 200, + }) + ) + .mockResolvedValueOnce( + new Response( + JSON.stringify([ + { + id: '222222222222222222', + content: 'hello', + timestamp: '2026-06-02T00:00:00.000Z', + author: { id: '333333333333333333', username: 'alice' }, + }, + ]), + { status: 200 } + ) + ); + + const result = await getDiscordConversationContext( + { + channelId: '111111111111111111', + guildId: '444444444444444444', + userId: '333333333333333333', + messageId: '222222222222222222', + }, + { channelMessages: 1 } + ); + + expect(result.errors).toEqual([]); + expect(fetchSpy).toHaveBeenNthCalledWith( + 1, + 'https://discord.com/api/v10/channels/111111111111111111', + { headers: { Authorization: 'Bot bot-token' } } + ); + expect(fetchSpy).toHaveBeenNthCalledWith( + 2, + 'https://discord.com/api/v10/channels/111111111111111111/messages?limit=1', + { headers: { Authorization: 'Bot bot-token' } } + ); + }); +}); diff --git a/apps/web/src/lib/discord-bot/discord-channel-context.ts b/apps/web/src/lib/discord-bot/discord-channel-context.ts index bd30e38a68..5fea4abc7b 100644 --- a/apps/web/src/lib/discord-bot/discord-channel-context.ts +++ b/apps/web/src/lib/discord-bot/discord-channel-context.ts @@ -1,6 +1,7 @@ import 'server-only'; import { DISCORD_BOT_TOKEN } from '@/lib/config.server'; import { captureException } from '@sentry/nextjs'; +import { buildDiscordApiUrl, parseDiscordSnowflake } from './discord-id'; export type DiscordEventContext = { channelId: string; @@ -44,14 +45,15 @@ type DiscordApiMessage = { }; async function fetchDiscordApi( - path: string + pathSegments: string[], + query?: Record ): Promise<{ ok: true; data: T } | { ok: false; error: string }> { if (!DISCORD_BOT_TOKEN) { return { ok: false, error: 'DISCORD_BOT_TOKEN is not configured' }; } try { - const response = await fetch(`https://discord.com/api/v10${path}`, { + const response = await fetch(buildDiscordApiUrl(pathSegments, query), { headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}` }, }); @@ -71,7 +73,14 @@ async function fetchDiscordApi( async function getChannelInfo( channelId: string ): Promise<{ ok: true; channel: DiscordChannelInfo } | { ok: false; error: string }> { - const result = await fetchDiscordApi(`/channels/${channelId}`); + let validatedChannelId: string; + try { + validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + } catch (error) { + return { ok: false, error: error instanceof Error ? error.message : 'Invalid channel ID' }; + } + + const result = await fetchDiscordApi(['channels', validatedChannelId]); if (!result.ok) return result; return { @@ -89,8 +98,22 @@ async function getChannelMessages( channelId: string, limit: number ): Promise<{ ok: true; messages: DiscordMessageForPrompt[] } | { ok: false; error: string }> { + let validatedChannelId: string; + try { + validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + } catch (error) { + return { ok: false, error: error instanceof Error ? error.message : 'Invalid channel ID' }; + } + + if (!Number.isInteger(limit) || limit < 1 || limit > 100) { + return { ok: false, error: 'Invalid Discord channel message limit' }; + } + const result = await fetchDiscordApi( - `/channels/${channelId}/messages?limit=${limit}` + ['channels', validatedChannelId, 'messages'], + { + limit, + } ); if (!result.ok) return result; @@ -111,6 +134,31 @@ export async function getDiscordConversationContext( const channelMessagesLimit = limits?.channelMessages ?? 12; const errors: string[] = []; + const contextIds = [ + { fieldName: 'guild ID', value: context.guildId }, + { fieldName: 'channel ID', value: context.channelId }, + { fieldName: 'user ID', value: context.userId }, + { fieldName: 'message ID', value: context.messageId }, + ]; + + for (const { fieldName, value } of contextIds) { + try { + parseDiscordSnowflake(value, fieldName); + } catch (error) { + errors.push(error instanceof Error ? error.message : `Invalid Discord ${fieldName}`); + } + } + + if (errors.length > 0) { + captureException(new Error('Invalid Discord conversation context'), { + level: 'warning', + tags: { source: 'discord_conversation_context' }, + extra: { errors }, + }); + + return { channel: null, recentMessages: [], errors }; + } + const [channelInfoResult, messagesResult] = await Promise.all([ getChannelInfo(context.channelId), getChannelMessages(context.channelId, channelMessagesLimit), diff --git a/apps/web/src/lib/discord-bot/discord-id.test.ts b/apps/web/src/lib/discord-bot/discord-id.test.ts new file mode 100644 index 0000000000..5ea03b7ac9 --- /dev/null +++ b/apps/web/src/lib/discord-bot/discord-id.test.ts @@ -0,0 +1,35 @@ +import { buildDiscordApiUrl, isDiscordSnowflake, parseDiscordSnowflake } from './discord-id'; + +describe('discord-id', () => { + it('accepts numeric snowflake values', () => { + expect(isDiscordSnowflake('123456789012345678')).toBe(true); + expect(parseDiscordSnowflake('123456789012345678', 'user ID')).toBe('123456789012345678'); + }); + + it.each([ + '', + ' ', + 'abc', + '123/456', + '123?limit=1', + '123#frag', + '%2f', + '..', + '1', + '1234', + '1234567890123456', + '1'.repeat(21), + ])('rejects malformed snowflake value %p', value => { + expect(isDiscordSnowflake(value)).toBe(false); + expect(() => parseDiscordSnowflake(value, 'user ID')).toThrow('Invalid Discord user ID'); + }); + + it('builds fixed-origin Discord API URLs with encoded path segments', () => { + expect(buildDiscordApiUrl(['channels', '123', 'messages'], { limit: 12 })).toBe( + 'https://discord.com/api/v10/channels/123/messages?limit=12' + ); + expect(buildDiscordApiUrl(['reactions', '✅', '@me'])).toBe( + 'https://discord.com/api/v10/reactions/%E2%9C%85/%40me' + ); + }); +}); diff --git a/apps/web/src/lib/discord-bot/discord-id.ts b/apps/web/src/lib/discord-bot/discord-id.ts new file mode 100644 index 0000000000..b287b9678b --- /dev/null +++ b/apps/web/src/lib/discord-bot/discord-id.ts @@ -0,0 +1,32 @@ +const DISCORD_API_BASE_URL = 'https://discord.com/api/v10/'; +const DISCORD_SNOWFLAKE_PATTERN = /^\d{17,20}$/; + +export function isDiscordSnowflake(value: string): boolean { + return DISCORD_SNOWFLAKE_PATTERN.test(value); +} + +export function parseDiscordSnowflake(value: string, fieldName: string): string { + if (isDiscordSnowflake(value)) { + return value; + } + + throw new Error(`Invalid Discord ${fieldName}`); +} + +export function buildDiscordApiUrl( + pathSegments: string[], + query?: Record +): string { + const url = new URL( + pathSegments.map(segment => encodeURIComponent(segment)).join('/'), + DISCORD_API_BASE_URL + ); + + if (query) { + for (const [key, value] of Object.entries(query)) { + url.searchParams.set(key, String(value)); + } + } + + return url.toString(); +} diff --git a/apps/web/src/lib/discord-bot/discord-utils.test.ts b/apps/web/src/lib/discord-bot/discord-utils.test.ts new file mode 100644 index 0000000000..5ab4eec77d --- /dev/null +++ b/apps/web/src/lib/discord-bot/discord-utils.test.ts @@ -0,0 +1,49 @@ +jest.mock('@/lib/config.server', () => ({ + DISCORD_BOT_TOKEN: 'bot-token', +})); + +import { + buildDiscordMessageLink, + replaceDiscordUserMentionsWithNames, + stripDiscordBotMention, +} from './discord-utils'; + +describe('discord-utils', () => { + beforeEach(() => { + jest.restoreAllMocks(); + }); + + it('does not strip mentions when the bot ID is malformed', () => { + expect(stripDiscordBotMention('<@bot/1> hello', 'bot/1')).toBe('<@bot/1> hello'); + }); + + it('rejects malformed message link IDs', () => { + expect(() => buildDiscordMessageLink('111111111111111111', '2/../3', '4')).toThrow( + 'Invalid Discord channel ID' + ); + }); + + it('does not fetch members when the guild ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect( + replaceDiscordUserMentionsWithNames('<@123456789012345678>', 'guild/1') + ).resolves.toBe('<@123456789012345678>'); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('fetches valid mention IDs through the fixed Discord API origin', async () => { + const fetchSpy = jest + .spyOn(globalThis, 'fetch') + .mockResolvedValue(new Response(JSON.stringify({ nick: 'Alice' }), { status: 200 })); + + await expect( + replaceDiscordUserMentionsWithNames('<@123456789012345678>', '234567890123456789') + ).resolves.toBe('@Alice'); + + expect(fetchSpy).toHaveBeenCalledWith( + 'https://discord.com/api/v10/guilds/234567890123456789/members/123456789012345678', + { headers: { Authorization: 'Bot bot-token' } } + ); + }); +}); diff --git a/apps/web/src/lib/discord-bot/discord-utils.ts b/apps/web/src/lib/discord-bot/discord-utils.ts index eaace88edc..d6d1275dcf 100644 --- a/apps/web/src/lib/discord-bot/discord-utils.ts +++ b/apps/web/src/lib/discord-bot/discord-utils.ts @@ -1,12 +1,13 @@ import 'server-only'; import { DISCORD_BOT_TOKEN } from '@/lib/config.server'; +import { buildDiscordApiUrl, isDiscordSnowflake, parseDiscordSnowflake } from './discord-id'; /** * Strip the bot's own mention from a Discord message. * Discord mentions look like <@BOT_ID> or <@!BOT_ID> (nickname mention). */ export function stripDiscordBotMention(text: string, botUserId: string | null): string { - if (!botUserId) return text; + if (!botUserId || !isDiscordSnowflake(botUserId)) return text; // Match both <@ID> and <@!ID> (nickname mention format) return text.replace(new RegExp(`<@!?${botUserId}>`, 'g'), '').trim(); } @@ -21,12 +22,19 @@ export async function replaceDiscordUserMentionsWithNames( ): Promise { if (!DISCORD_BOT_TOKEN) return text; + let validatedGuildId: string; + try { + validatedGuildId = parseDiscordSnowflake(guildId, 'guild ID'); + } catch { + return text; + } + const mentionRegex = /<@!?(\d+)>/g; const mentions = [...text.matchAll(mentionRegex)]; if (mentions.length === 0) return text; // Deduplicate user IDs - const uniqueUserIds = [...new Set(mentions.map(m => m[1]))]; + const uniqueUserIds = [...new Set(mentions.map(m => m[1]).filter(isDiscordSnowflake))]; // Fetch display names in parallel const nameMap = new Map(); @@ -34,7 +42,7 @@ export async function replaceDiscordUserMentionsWithNames( uniqueUserIds.map(async userId => { try { const response = await fetch( - `https://discord.com/api/v10/guilds/${guildId}/members/${userId}`, + buildDiscordApiUrl(['guilds', validatedGuildId, 'members', userId]), { headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}` }, } @@ -72,7 +80,10 @@ export function buildDiscordMessageLink( channelId: string, messageId: string ): string { - return `https://discord.com/channels/${guildId}/${channelId}/${messageId}`; + const validatedGuildId = parseDiscordSnowflake(guildId, 'guild ID'); + const validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + const validatedMessageId = parseDiscordSnowflake(messageId, 'message ID'); + return `https://discord.com/channels/${validatedGuildId}/${validatedChannelId}/${validatedMessageId}`; } /** diff --git a/apps/web/src/lib/discord-bot/forwarded-gateway-event.test.ts b/apps/web/src/lib/discord-bot/forwarded-gateway-event.test.ts new file mode 100644 index 0000000000..e5e90b30f0 --- /dev/null +++ b/apps/web/src/lib/discord-bot/forwarded-gateway-event.test.ts @@ -0,0 +1,75 @@ +import { parseForwardedGatewayMessageEvent } from './forwarded-gateway-event'; + +const validEvent = { + type: 'GATEWAY_MESSAGE_CREATE', + timestamp: 1, + botUserId: '111111111111111111', + data: { + id: '222222222222222222', + content: '<@111111111111111111> hello', + channel_id: '333333333333333333', + guild_id: '444444444444444444', + author: { id: '555555555555555555', username: 'alice', bot: false }, + mentions: [{ id: '111111111111111111' }], + message_reference: { message_id: '666666666666666666' }, + }, +}; + +describe('parseForwardedGatewayMessageEvent', () => { + it('accepts valid forwarded Discord message events', () => { + expect(parseForwardedGatewayMessageEvent(validEvent)).toEqual(validEvent); + }); + + it.each([ + ['message ID', { data: { id: '../../users/@me' } }], + ['channel ID', { data: { channel_id: '333/../../users/@me' } }], + ['guild ID', { data: { guild_id: 'guild?x=1' } }], + ['author ID', { data: { author: { id: 'author#frag' } } }], + ['mention ID', { data: { mentions: [{ id: 'mention/1' }] } }], + ['message reference ID', { data: { message_reference: { message_id: 'ref/1' } } }], + ])('rejects malformed %s', (_name, override) => { + expect(parseForwardedGatewayMessageEvent(mergeEvent(validEvent, override))).toBeNull(); + }); + + it('rejects invalid JSON shapes', () => { + expect(parseForwardedGatewayMessageEvent(null)).toBeNull(); + expect(parseForwardedGatewayMessageEvent({ type: 'GATEWAY_MESSAGE_UPDATE' })).toBeNull(); + }); +}); + +function mergeEvent(base: typeof validEvent, override: Record) { + return { + ...base, + ...override, + data: { + ...base.data, + ...(typeof override.data === 'object' && override.data !== null ? override.data : {}), + author: { + ...base.data.author, + ...getNestedRecord(override, 'data', 'author'), + }, + }, + }; +} + +function getNestedRecord( + value: Record, + firstKey: string, + secondKey: string +): Record { + const firstValue = value[firstKey]; + if (!isTestRecord(firstValue)) { + return {}; + } + + const secondValue = firstValue[secondKey]; + if (!isTestRecord(secondValue)) { + return {}; + } + + return secondValue; +} + +function isTestRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null && !Array.isArray(value); +} diff --git a/apps/web/src/lib/discord-bot/forwarded-gateway-event.ts b/apps/web/src/lib/discord-bot/forwarded-gateway-event.ts new file mode 100644 index 0000000000..9570fe0e73 --- /dev/null +++ b/apps/web/src/lib/discord-bot/forwarded-gateway-event.ts @@ -0,0 +1,113 @@ +import { isDiscordSnowflake } from '@/lib/discord-bot/discord-id'; + +export type ForwardedGatewayEvent = { + type: 'GATEWAY_MESSAGE_CREATE'; + timestamp: number; + botUserId: string | null; + data: { + id: string; + content: string; + channel_id: string; + guild_id: string; + author: { id: string; username: string; bot?: boolean }; + mentions?: Array<{ id: string }>; + message_reference?: { message_id: string }; + }; +}; + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null && !Array.isArray(value); +} + +export function parseForwardedGatewayMessageEvent(value: unknown): ForwardedGatewayEvent | null { + if (!isRecord(value) || value.type !== 'GATEWAY_MESSAGE_CREATE') { + return null; + } + + if (typeof value.timestamp !== 'number') { + return null; + } + + const botUserId = value.botUserId; + if (botUserId !== null && (typeof botUserId !== 'string' || !isDiscordSnowflake(botUserId))) { + return null; + } + + const data = value.data; + if (!isRecord(data)) { + return null; + } + + const author = data.author; + if (!isRecord(author)) { + return null; + } + + if ( + typeof data.id !== 'string' || + !isDiscordSnowflake(data.id) || + typeof data.content !== 'string' || + typeof data.channel_id !== 'string' || + !isDiscordSnowflake(data.channel_id) || + typeof data.guild_id !== 'string' || + !isDiscordSnowflake(data.guild_id) || + typeof author.id !== 'string' || + !isDiscordSnowflake(author.id) || + typeof author.username !== 'string' + ) { + return null; + } + + if (author.bot !== undefined && typeof author.bot !== 'boolean') { + return null; + } + + const mentions = data.mentions; + let validatedMentions: Array<{ id: string }> | undefined; + if (mentions !== undefined) { + if (!Array.isArray(mentions)) { + return null; + } + + const nextMentions: Array<{ id: string }> = []; + for (const mention of mentions) { + if (!isRecord(mention) || typeof mention.id !== 'string' || !isDiscordSnowflake(mention.id)) { + return null; + } + nextMentions.push({ id: mention.id }); + } + validatedMentions = nextMentions; + } + + const messageReference = data.message_reference; + let validatedMessageReference: { message_id: string } | undefined; + if (messageReference !== undefined) { + if ( + !isRecord(messageReference) || + typeof messageReference.message_id !== 'string' || + !isDiscordSnowflake(messageReference.message_id) + ) { + return null; + } + validatedMessageReference = { message_id: messageReference.message_id }; + } + + return { + type: 'GATEWAY_MESSAGE_CREATE', + timestamp: value.timestamp, + botUserId, + data: { + id: data.id, + content: data.content, + channel_id: data.channel_id, + guild_id: data.guild_id, + author: { + id: author.id, + username: author.username, + bot: author.bot, + }, + mentions: validatedMentions, + message_reference: validatedMessageReference, + }, + }; +} diff --git a/apps/web/src/lib/integrations/discord-guild-membership.test.ts b/apps/web/src/lib/integrations/discord-guild-membership.test.ts new file mode 100644 index 0000000000..fb6d3e1a7b --- /dev/null +++ b/apps/web/src/lib/integrations/discord-guild-membership.test.ts @@ -0,0 +1,31 @@ +jest.mock('@/lib/config.server', () => ({ + DISCORD_OAUTH_BOT_TOKEN: 'bot-token', + DISCORD_SERVER_ID: '123456789012345678', +})); + +import { checkDiscordGuildMembership } from './discord-guild-membership'; + +describe('checkDiscordGuildMembership', () => { + beforeEach(() => { + jest.restoreAllMocks(); + }); + + it('rejects malformed Discord user IDs before fetching', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect(checkDiscordGuildMembership('user/1')).rejects.toThrow('Invalid Discord user ID'); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('uses fixed-origin Discord API URLs for valid user IDs', async () => { + const fetchSpy = jest + .spyOn(globalThis, 'fetch') + .mockResolvedValue(new Response(null, { status: 200 })); + + await expect(checkDiscordGuildMembership('234567890123456789')).resolves.toBe(true); + expect(fetchSpy).toHaveBeenCalledWith( + 'https://discord.com/api/v10/guilds/123456789012345678/members/234567890123456789', + expect.objectContaining({ headers: { Authorization: 'Bot bot-token' } }) + ); + }); +}); diff --git a/apps/web/src/lib/integrations/discord-guild-membership.ts b/apps/web/src/lib/integrations/discord-guild-membership.ts index 7f0aa31085..f20345c2aa 100644 --- a/apps/web/src/lib/integrations/discord-guild-membership.ts +++ b/apps/web/src/lib/integrations/discord-guild-membership.ts @@ -1,4 +1,5 @@ import { DISCORD_OAUTH_BOT_TOKEN, DISCORD_SERVER_ID } from '@/lib/config.server'; +import { buildDiscordApiUrl, parseDiscordSnowflake } from '@/lib/discord-bot/discord-id'; /** * Check if a Discord user is a member of the Kilo Discord server. @@ -13,15 +14,15 @@ export async function checkDiscordGuildMembership(discordUserId: string): Promis throw new Error('DISCORD_OAUTH_BOT_TOKEN or DISCORD_SERVER_ID not configured'); } - const response = await fetch( - `https://discord.com/api/v10/guilds/${DISCORD_SERVER_ID}/members/${discordUserId}`, - { - headers: { - Authorization: `Bot ${DISCORD_OAUTH_BOT_TOKEN}`, - }, - signal: AbortSignal.timeout(5_000), - } - ); + const guildId = parseDiscordSnowflake(DISCORD_SERVER_ID, 'server ID'); + const userId = parseDiscordSnowflake(discordUserId, 'user ID'); + + const response = await fetch(buildDiscordApiUrl(['guilds', guildId, 'members', userId]), { + headers: { + Authorization: `Bot ${DISCORD_OAUTH_BOT_TOKEN}`, + }, + signal: AbortSignal.timeout(5_000), + }); if (response.ok) return true; if (response.status === 404) return false; diff --git a/apps/web/src/lib/integrations/discord-service.test.ts b/apps/web/src/lib/integrations/discord-service.test.ts new file mode 100644 index 0000000000..28524516db --- /dev/null +++ b/apps/web/src/lib/integrations/discord-service.test.ts @@ -0,0 +1,191 @@ +jest.mock('@/lib/config.server', () => ({ + DISCORD_BOT_TOKEN: 'bot-token', + DISCORD_CLIENT_ID: 'client-id', + DISCORD_CLIENT_SECRET: 'client-secret', +})); + +const mockLimit = jest.fn(); +const mockUpdateSet = jest.fn(); +const mockUpdateWhere = jest.fn(); +const mockUpdateReturning = jest.fn(); +const mockInsertValues = jest.fn(); +const mockInsertReturning = jest.fn(); + +jest.mock('@/lib/drizzle', () => ({ + db: { + select: jest.fn(() => ({ + from: jest.fn(() => ({ + where: jest.fn(() => ({ + limit: mockLimit, + })), + })), + })), + update: jest.fn(() => ({ + set: mockUpdateSet, + })), + insert: jest.fn(() => ({ + values: mockInsertValues, + })), + delete: jest.fn(() => ({ + where: jest.fn(), + })), + }, +})); + +jest.mock('@/lib/organizations/organizations', () => ({ + getOrganizationById: jest.fn(), +})); + +jest.mock('@/lib/slack-bot/model-allow-list', () => ({ + getDefaultAllowedModel: jest.fn(async () => 'gpt-test'), +})); + +jest.mock('@/lib/model-allow.server', () => ({ + createAllowPredicateFromRestrictions: jest.fn(), + hasActiveModelRestrictions: jest.fn(() => false), +})); + +jest.mock('@/lib/organizations/model-restrictions', () => ({ + getEffectiveModelRestrictions: jest.fn(), +})); + +import type { Owner } from '@/lib/integrations/core/types'; +import { + addDiscordReaction, + postDiscordMessage, + removeDiscordReaction, + testConnection, + upsertDiscordInstallation, +} from './discord-service'; + +const owner = { type: 'user', id: 'user-1' } satisfies Owner; + +function buildDiscordIntegration(overrides: Record = {}) { + return { + id: 'integration-1', + integration_status: 'active', + platform_account_id: '123456789012345678', + platform_installation_id: '123456789012345678', + owned_by_user_id: owner.id, + owned_by_organization_id: null, + metadata: {}, + ...overrides, + }; +} + +describe('discord-service API URL validation', () => { + beforeEach(() => { + jest.restoreAllMocks(); + }); + + it('does not post a message when the channel ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect(postDiscordMessage('123/../456', 'hello')).resolves.toEqual({ + ok: false, + error: 'Invalid Discord channel ID', + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('does not post a reply when the message reference ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect( + postDiscordMessage('123456789012345678', 'hello', { + messageReference: { message_id: '123?x=1' }, + }) + ).resolves.toEqual({ ok: false, error: 'Invalid Discord message reference ID' }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('does not add a reaction when the message ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect(addDiscordReaction('123456789012345678', 'message/1', '✅')).resolves.toEqual({ + ok: false, + error: 'Invalid Discord message ID', + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('does not remove a reaction when the channel ID is malformed', async () => { + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect(removeDiscordReaction('channel#1', '123456789012345678', '✅')).resolves.toEqual({ + ok: false, + error: 'Invalid Discord channel ID', + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('uses fixed-origin Discord API URLs for valid message posts', async () => { + const fetchSpy = jest + .spyOn(globalThis, 'fetch') + .mockResolvedValue( + new Response(JSON.stringify({ id: '234567890123456789' }), { status: 200 }) + ); + + await expect(postDiscordMessage('123456789012345678', 'hello')).resolves.toEqual({ + ok: true, + messageId: '234567890123456789', + }); + + expect(fetchSpy).toHaveBeenCalledWith( + 'https://discord.com/api/v10/channels/123456789012345678/messages', + expect.objectContaining({ method: 'POST' }) + ); + }); +}); + +describe('discord-service persisted guild ID validation', () => { + beforeEach(() => { + jest.restoreAllMocks(); + mockLimit.mockReset(); + mockUpdateSet.mockReset(); + mockUpdateWhere.mockReset(); + mockUpdateReturning.mockReset(); + mockInsertValues.mockReset(); + mockInsertReturning.mockReset(); + }); + + it('does not test a connection when the stored guild ID is malformed', async () => { + mockLimit.mockResolvedValue([buildDiscordIntegration({ platform_account_id: 'guild/1' })]); + const fetchSpy = jest.spyOn(globalThis, 'fetch'); + + await expect(testConnection(owner)).resolves.toEqual({ + success: false, + error: 'Invalid guild ID found for this installation', + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + it('rejects malformed OAuth guild IDs before persistence', async () => { + mockLimit.mockResolvedValue([]); + + await expect( + upsertDiscordInstallation(owner, { + access_token: 'access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh-token', + scope: 'bot guilds', + guild: { id: 'guild/1', name: 'Test Guild', icon: null }, + }) + ).rejects.toThrow('Invalid Discord guild ID'); + + expect(mockInsertValues).not.toHaveBeenCalled(); + }); + + it('tests valid stored guild IDs through the fixed Discord API origin', async () => { + mockLimit.mockResolvedValue([buildDiscordIntegration()]); + const fetchSpy = jest + .spyOn(globalThis, 'fetch') + .mockResolvedValue(new Response(null, { status: 200 })); + + await expect(testConnection(owner)).resolves.toEqual({ success: true }); + expect(fetchSpy).toHaveBeenCalledWith('https://discord.com/api/v10/guilds/123456789012345678', { + headers: { Authorization: 'Bot bot-token' }, + }); + }); +}); diff --git a/apps/web/src/lib/integrations/discord-service.ts b/apps/web/src/lib/integrations/discord-service.ts index b88c6fbc02..3b583bb71c 100644 --- a/apps/web/src/lib/integrations/discord-service.ts +++ b/apps/web/src/lib/integrations/discord-service.ts @@ -16,6 +16,7 @@ import { } from '@/lib/model-allow.server'; import { DEFAULT_BOT_MODEL } from '@/lib/bot/constants'; import { getEffectiveModelRestrictions } from '@/lib/organizations/model-restrictions'; +import { buildDiscordApiUrl, parseDiscordSnowflake } from '@/lib/discord-bot/discord-id'; // Discord OAuth2 scopes for the bot integration // 'bot' scope is needed for the bot to join servers @@ -176,7 +177,7 @@ export async function upsertDiscordInstallation( const existing = await getInstallation(owner); - const guildId = oauthResponse.guild.id; + const guildId = parseDiscordSnowflake(oauthResponse.guild.id, 'guild ID'); const guildName = oauthResponse.guild.name || 'Unknown Server'; const scopes = oauthResponse.scope?.split(' ') || null; @@ -299,9 +300,16 @@ export async function testConnection(owner: Owner): Promise<{ success: boolean; return { success: false, error: 'No guild ID found for this installation' }; } + let validatedGuildId: string; + try { + validatedGuildId = parseDiscordSnowflake(guildId, 'guild ID'); + } catch { + return { success: false, error: 'Invalid guild ID found for this installation' }; + } + try { // Verify the bot can access this guild - const response = await fetch(`https://discord.com/api/v10/guilds/${guildId}`, { + const response = await fetch(buildDiscordApiUrl(['guilds', validatedGuildId]), { headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}`, }, @@ -403,13 +411,32 @@ export async function postDiscordMessage( return { ok: false, error: 'DISCORD_BOT_TOKEN is not configured' }; } + let validatedChannelId: string; + try { + validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + } catch (error) { + return { ok: false, error: error instanceof Error ? error.message : 'Invalid channel ID' }; + } + try { const body: Record = { content }; if (options?.messageReference) { - body.message_reference = options.messageReference; + try { + body.message_reference = { + message_id: parseDiscordSnowflake( + options.messageReference.message_id, + 'message reference ID' + ), + }; + } catch (error) { + return { + ok: false, + error: error instanceof Error ? error.message : 'Invalid message reference ID', + }; + } } - const response = await fetch(`https://discord.com/api/v10/channels/${channelId}/messages`, { + const response = await fetch(buildDiscordApiUrl(['channels', validatedChannelId, 'messages']), { method: 'POST', headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}`, @@ -444,10 +471,26 @@ export async function addDiscordReaction( return { ok: false, error: 'DISCORD_BOT_TOKEN is not configured' }; } + let validatedChannelId: string; + let validatedMessageId: string; + try { + validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + validatedMessageId = parseDiscordSnowflake(messageId, 'message ID'); + } catch (error) { + return { ok: false, error: error instanceof Error ? error.message : 'Invalid Discord ID' }; + } + try { - const encodedEmoji = encodeURIComponent(emoji); const response = await fetch( - `https://discord.com/api/v10/channels/${channelId}/messages/${messageId}/reactions/${encodedEmoji}/@me`, + buildDiscordApiUrl([ + 'channels', + validatedChannelId, + 'messages', + validatedMessageId, + 'reactions', + emoji, + '@me', + ]), { method: 'PUT', headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}` }, @@ -479,10 +522,26 @@ export async function removeDiscordReaction( return { ok: false, error: 'DISCORD_BOT_TOKEN is not configured' }; } + let validatedChannelId: string; + let validatedMessageId: string; + try { + validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID'); + validatedMessageId = parseDiscordSnowflake(messageId, 'message ID'); + } catch (error) { + return { ok: false, error: error instanceof Error ? error.message : 'Invalid Discord ID' }; + } + try { - const encodedEmoji = encodeURIComponent(emoji); const response = await fetch( - `https://discord.com/api/v10/channels/${channelId}/messages/${messageId}/reactions/${encodedEmoji}/@me`, + buildDiscordApiUrl([ + 'channels', + validatedChannelId, + 'messages', + validatedMessageId, + 'reactions', + emoji, + '@me', + ]), { method: 'DELETE', headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}` },