Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions apps/web/src/app/discord/webhook/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import {
truncateForDiscord,
} from '@/lib/discord-bot/discord-utils';
import { getDevUserSuffix } from '@/lib/slack-bot/dev-user-info';
import {
parseForwardedGatewayMessageEvent,
type ForwardedGatewayEvent,
} from '@/lib/discord-bot/forwarded-gateway-event';

export const maxDuration = 800;

Expand All @@ -26,24 +30,6 @@ export const maxDuration = 800;
const PROCESSING_EMOJI = '\u23f3'; // hourglass
const COMPLETE_EMOJI = '\u2705'; // white check mark

/**
* Forwarded Gateway event shape (from the Gateway listener)
*/
type ForwardedGatewayEvent = {
type: string;
timestamp: number;
botUserId: string | null;
data: {
id: string;
content: string;
channel_id: string;
guild_id: string;
author: { id: string; username: string; bot?: boolean };
mentions?: Array<{ id: string }>;
message_reference?: { message_id: string };
};
};

/**
* Discord webhook handler.
* Handles:
Expand All @@ -60,10 +46,19 @@ export async function POST(request: NextRequest) {
return new NextResponse('Unauthorized', { status: 401 });
}

const event = JSON.parse(rawBody) as ForwardedGatewayEvent;
if (event.type === 'GATEWAY_MESSAGE_CREATE') {
after(processGatewayMessage(event));
let parsedBody: unknown;
try {
parsedBody = JSON.parse(rawBody);
} catch {
return new NextResponse('Invalid gateway event', { status: 400 });
}

const event = parseForwardedGatewayMessageEvent(parsedBody);
if (!event) {
return new NextResponse('Invalid gateway event', { status: 400 });
}

after(processGatewayMessage(event));
return new NextResponse(null, { status: 200 });
}

Expand Down
76 changes: 76 additions & 0 deletions apps/web/src/lib/discord-bot/discord-channel-context.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
jest.mock('@/lib/config.server', () => ({
DISCORD_BOT_TOKEN: 'bot-token',
}));

jest.mock('@sentry/nextjs', () => ({
captureException: jest.fn(),
}));

import { getDiscordConversationContext } from './discord-channel-context';

describe('getDiscordConversationContext', () => {
beforeEach(() => {
jest.restoreAllMocks();
});

it('does not fetch Discord API data when the channel ID is malformed', async () => {
const fetchSpy = jest.spyOn(globalThis, 'fetch');

const result = await getDiscordConversationContext({
channelId: '../../users/@me',
guildId: '111111111111111111',
userId: '222222222222222222',
messageId: '333333333333333333',
});

expect(fetchSpy).not.toHaveBeenCalled();
expect(result.channel).toBeNull();
expect(result.recentMessages).toEqual([]);
expect(result.errors).toEqual(['Invalid Discord channel ID']);
});

it('uses fixed-origin Discord API URLs for valid context fetches', async () => {
const fetchSpy = jest
.spyOn(globalThis, 'fetch')
.mockResolvedValueOnce(
new Response(JSON.stringify({ id: '111111111111111111', type: 0, name: 'general' }), {
status: 200,
})
)
.mockResolvedValueOnce(
new Response(
JSON.stringify([
{
id: '222222222222222222',
content: 'hello',
timestamp: '2026-06-02T00:00:00.000Z',
author: { id: '333333333333333333', username: 'alice' },
},
]),
{ status: 200 }
)
);

const result = await getDiscordConversationContext(
{
channelId: '111111111111111111',
guildId: '444444444444444444',
userId: '333333333333333333',
messageId: '222222222222222222',
},
{ channelMessages: 1 }
);

expect(result.errors).toEqual([]);
expect(fetchSpy).toHaveBeenNthCalledWith(
1,
'https://discord.com/api/v10/channels/111111111111111111',
{ headers: { Authorization: 'Bot bot-token' } }
);
expect(fetchSpy).toHaveBeenNthCalledWith(
2,
'https://discord.com/api/v10/channels/111111111111111111/messages?limit=1',
{ headers: { Authorization: 'Bot bot-token' } }
);
});
});
56 changes: 52 additions & 4 deletions apps/web/src/lib/discord-bot/discord-channel-context.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'server-only';
import { DISCORD_BOT_TOKEN } from '@/lib/config.server';
import { captureException } from '@sentry/nextjs';
import { buildDiscordApiUrl, parseDiscordSnowflake } from './discord-id';

export type DiscordEventContext = {
channelId: string;
Expand Down Expand Up @@ -44,14 +45,15 @@ type DiscordApiMessage = {
};

async function fetchDiscordApi<T>(
path: string
pathSegments: string[],
query?: Record<string, string | number>
): Promise<{ ok: true; data: T } | { ok: false; error: string }> {
if (!DISCORD_BOT_TOKEN) {
return { ok: false, error: 'DISCORD_BOT_TOKEN is not configured' };
}

try {
const response = await fetch(`https://discord.com/api/v10${path}`, {
const response = await fetch(buildDiscordApiUrl(pathSegments, query), {
headers: { Authorization: `Bot ${DISCORD_BOT_TOKEN}` },
});

Expand All @@ -71,7 +73,14 @@ async function fetchDiscordApi<T>(
async function getChannelInfo(
channelId: string
): Promise<{ ok: true; channel: DiscordChannelInfo } | { ok: false; error: string }> {
const result = await fetchDiscordApi<DiscordApiChannel>(`/channels/${channelId}`);
let validatedChannelId: string;
try {
validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID');
} catch (error) {
return { ok: false, error: error instanceof Error ? error.message : 'Invalid channel ID' };
}

const result = await fetchDiscordApi<DiscordApiChannel>(['channels', validatedChannelId]);
if (!result.ok) return result;

return {
Expand All @@ -89,8 +98,22 @@ async function getChannelMessages(
channelId: string,
limit: number
): Promise<{ ok: true; messages: DiscordMessageForPrompt[] } | { ok: false; error: string }> {
let validatedChannelId: string;
try {
validatedChannelId = parseDiscordSnowflake(channelId, 'channel ID');
} catch (error) {
return { ok: false, error: error instanceof Error ? error.message : 'Invalid channel ID' };
}

if (!Number.isInteger(limit) || limit < 1 || limit > 100) {
return { ok: false, error: 'Invalid Discord channel message limit' };
}

const result = await fetchDiscordApi<DiscordApiMessage[]>(
`/channels/${channelId}/messages?limit=${limit}`
['channels', validatedChannelId, 'messages'],
{
limit,
}
);
if (!result.ok) return result;

Expand All @@ -111,6 +134,31 @@ export async function getDiscordConversationContext(
const channelMessagesLimit = limits?.channelMessages ?? 12;
const errors: string[] = [];

const contextIds = [
{ fieldName: 'guild ID', value: context.guildId },
{ fieldName: 'channel ID', value: context.channelId },
{ fieldName: 'user ID', value: context.userId },
{ fieldName: 'message ID', value: context.messageId },
];

for (const { fieldName, value } of contextIds) {
try {
parseDiscordSnowflake(value, fieldName);
} catch (error) {
errors.push(error instanceof Error ? error.message : `Invalid Discord ${fieldName}`);
}
}

if (errors.length > 0) {
captureException(new Error('Invalid Discord conversation context'), {
level: 'warning',
tags: { source: 'discord_conversation_context' },
extra: { errors },
});

return { channel: null, recentMessages: [], errors };
}

const [channelInfoResult, messagesResult] = await Promise.all([
getChannelInfo(context.channelId),
getChannelMessages(context.channelId, channelMessagesLimit),
Expand Down
35 changes: 35 additions & 0 deletions apps/web/src/lib/discord-bot/discord-id.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { buildDiscordApiUrl, isDiscordSnowflake, parseDiscordSnowflake } from './discord-id';

describe('discord-id', () => {
it('accepts numeric snowflake values', () => {
expect(isDiscordSnowflake('123456789012345678')).toBe(true);
expect(parseDiscordSnowflake('123456789012345678', 'user ID')).toBe('123456789012345678');
});

it.each([
'',
' ',
'abc',
'123/456',
'123?limit=1',
'123#frag',
'%2f',
'..',
'1',
'1234',
'1234567890123456',
'1'.repeat(21),
])('rejects malformed snowflake value %p', value => {
expect(isDiscordSnowflake(value)).toBe(false);
expect(() => parseDiscordSnowflake(value, 'user ID')).toThrow('Invalid Discord user ID');
});

it('builds fixed-origin Discord API URLs with encoded path segments', () => {
expect(buildDiscordApiUrl(['channels', '123', 'messages'], { limit: 12 })).toBe(
'https://discord.com/api/v10/channels/123/messages?limit=12'
);
expect(buildDiscordApiUrl(['reactions', '✅', '@me'])).toBe(
'https://discord.com/api/v10/reactions/%E2%9C%85/%40me'
);
});
});
32 changes: 32 additions & 0 deletions apps/web/src/lib/discord-bot/discord-id.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
const DISCORD_API_BASE_URL = 'https://discord.com/api/v10/';
const DISCORD_SNOWFLAKE_PATTERN = /^\d{17,20}$/;

export function isDiscordSnowflake(value: string): boolean {
return DISCORD_SNOWFLAKE_PATTERN.test(value);
}

export function parseDiscordSnowflake(value: string, fieldName: string): string {
if (isDiscordSnowflake(value)) {
return value;
}

throw new Error(`Invalid Discord ${fieldName}`);
}

export function buildDiscordApiUrl(
pathSegments: string[],
query?: Record<string, string | number>
): string {
const url = new URL(
pathSegments.map(segment => encodeURIComponent(segment)).join('/'),
DISCORD_API_BASE_URL
);

if (query) {
for (const [key, value] of Object.entries(query)) {
url.searchParams.set(key, String(value));
}
}

return url.toString();
}
49 changes: 49 additions & 0 deletions apps/web/src/lib/discord-bot/discord-utils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
jest.mock('@/lib/config.server', () => ({
DISCORD_BOT_TOKEN: 'bot-token',
}));

import {
buildDiscordMessageLink,
replaceDiscordUserMentionsWithNames,
stripDiscordBotMention,
} from './discord-utils';

describe('discord-utils', () => {
beforeEach(() => {
jest.restoreAllMocks();
});

it('does not strip mentions when the bot ID is malformed', () => {
expect(stripDiscordBotMention('<@bot/1> hello', 'bot/1')).toBe('<@bot/1> hello');
});

it('rejects malformed message link IDs', () => {
expect(() => buildDiscordMessageLink('111111111111111111', '2/../3', '4')).toThrow(
'Invalid Discord channel ID'
);
});

it('does not fetch members when the guild ID is malformed', async () => {
const fetchSpy = jest.spyOn(globalThis, 'fetch');

await expect(
replaceDiscordUserMentionsWithNames('<@123456789012345678>', 'guild/1')
).resolves.toBe('<@123456789012345678>');
expect(fetchSpy).not.toHaveBeenCalled();
});

it('fetches valid mention IDs through the fixed Discord API origin', async () => {
const fetchSpy = jest
.spyOn(globalThis, 'fetch')
.mockResolvedValue(new Response(JSON.stringify({ nick: 'Alice' }), { status: 200 }));

await expect(
replaceDiscordUserMentionsWithNames('<@123456789012345678>', '234567890123456789')
).resolves.toBe('@Alice');

expect(fetchSpy).toHaveBeenCalledWith(
'https://discord.com/api/v10/guilds/234567890123456789/members/123456789012345678',
{ headers: { Authorization: 'Bot bot-token' } }
);
});
});
Loading