Skip to content

Commit 0f6b08d

Browse files
committed
fix(reddit): add input validation and HTTP error guards
- Add validateEnum/validatePathSegment to prevent URL path traversal - Add !response.ok guards to send_message and reply tools - Centralize subreddit validation in normalizeSubreddit
1 parent a5dbd96 commit 0f6b08d

File tree

7 files changed

+68
-2
lines changed

7 files changed

+68
-2
lines changed

apps/sim/tools/reddit/get_comments.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { validatePathSegment } from '@/lib/core/security/input-validation'
12
import type { RedditCommentsParams, RedditCommentsResponse } from '@/tools/reddit/types'
23
import { normalizeSubreddit } from '@/tools/reddit/utils'
34
import type { ToolConfig } from '@/tools/types'
@@ -114,8 +115,15 @@ export const getCommentsTool: ToolConfig<RedditCommentsParams, RedditCommentsRes
114115

115116
if (params.comment) urlParams.append('comment', params.comment)
116117

118+
// Validate postId to prevent path traversal
119+
const postId = params.postId.trim()
120+
const postIdValidation = validatePathSegment(postId, { paramName: 'postId' })
121+
if (!postIdValidation.isValid) {
122+
throw new Error(postIdValidation.error)
123+
}
124+
117125
// Build URL using OAuth endpoint
118-
return `https://oauth.reddit.com/r/${subreddit}/comments/${params.postId.trim()}?${urlParams.toString()}`
126+
return `https://oauth.reddit.com/r/${subreddit}/comments/${postId}?${urlParams.toString()}`
119127
},
120128
method: 'GET',
121129
headers: (params: RedditCommentsParams) => {

apps/sim/tools/reddit/get_messages.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
import { validateEnum } from '@/lib/core/security/input-validation'
12
import type { RedditGetMessagesParams, RedditMessagesResponse } from '@/tools/reddit/types'
23
import type { ToolConfig } from '@/tools/types'
34

5+
const ALLOWED_MESSAGE_FOLDERS = [
6+
'inbox',
7+
'unread',
8+
'sent',
9+
'messages',
10+
'comments',
11+
'selfreply',
12+
'mentions',
13+
] as const
14+
415
export const getMessagesTool: ToolConfig<RedditGetMessagesParams, RedditMessagesResponse> = {
516
id: 'reddit_get_messages',
617
name: 'Get Reddit Messages',
@@ -67,6 +78,10 @@ export const getMessagesTool: ToolConfig<RedditGetMessagesParams, RedditMessages
6778
request: {
6879
url: (params: RedditGetMessagesParams) => {
6980
const where = params.where || 'inbox'
81+
const validation = validateEnum(where, ALLOWED_MESSAGE_FOLDERS, 'where')
82+
if (!validation.isValid) {
83+
throw new Error(validation.error)
84+
}
7085
const limit = Math.min(Math.max(1, params.limit ?? 25), 100)
7186

7287
const urlParams = new URLSearchParams({

apps/sim/tools/reddit/get_posts.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import { validateEnum } from '@/lib/core/security/input-validation'
12
import type { RedditPostsParams, RedditPostsResponse } from '@/tools/reddit/types'
23
import { normalizeSubreddit } from '@/tools/reddit/utils'
34
import type { ToolConfig } from '@/tools/types'
45

6+
const ALLOWED_SORT_OPTIONS = ['hot', 'new', 'top', 'controversial', 'rising'] as const
7+
58
export const getPostsTool: ToolConfig<RedditPostsParams, RedditPostsResponse> = {
69
id: 'reddit_get_posts',
710
name: 'Get Reddit Posts',
@@ -88,6 +91,10 @@ export const getPostsTool: ToolConfig<RedditPostsParams, RedditPostsResponse> =
8891
url: (params: RedditPostsParams) => {
8992
const subreddit = normalizeSubreddit(params.subreddit)
9093
const sort = params.sort || 'hot'
94+
const sortValidation = validateEnum(sort, ALLOWED_SORT_OPTIONS, 'sort')
95+
if (!sortValidation.isValid) {
96+
throw new Error(sortValidation.error)
97+
}
9198
const limit = Math.min(Math.max(1, params.limit ?? 10), 100)
9299

93100
// Build URL with appropriate parameters using OAuth endpoint

apps/sim/tools/reddit/get_user.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { validatePathSegment } from '@/lib/core/security/input-validation'
12
import type { RedditGetUserParams, RedditUserResponse } from '@/tools/reddit/types'
23
import type { ToolConfig } from '@/tools/types'
34

@@ -30,6 +31,10 @@ export const getUserTool: ToolConfig<RedditGetUserParams, RedditUserResponse> =
3031
request: {
3132
url: (params: RedditGetUserParams) => {
3233
const username = params.username.trim().replace(/^u\//, '')
34+
const validation = validatePathSegment(username, { paramName: 'username' })
35+
if (!validation.isValid) {
36+
throw new Error(validation.error)
37+
}
3338
return `https://oauth.reddit.com/user/${username}/about?raw_json=1`
3439
},
3540
method: 'GET',

apps/sim/tools/reddit/reply.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ export const replyTool: ToolConfig<RedditReplyParams, RedditWriteResponse> = {
7171
transformResponse: async (response: Response) => {
7272
const data = await response.json()
7373

74+
if (!response.ok) {
75+
const errorMsg = data?.message || `HTTP error ${response.status}`
76+
return {
77+
success: false,
78+
output: {
79+
success: false,
80+
message: `Failed to post reply: ${errorMsg}`,
81+
},
82+
}
83+
}
84+
7485
// Reddit API returns errors in json.errors array
7586
if (data.json?.errors && data.json.errors.length > 0) {
7687
const errors = data.json.errors.map((err: any) => err.join(': ')).join(', ')

apps/sim/tools/reddit/send_message.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@ export const sendMessageTool: ToolConfig<RedditSendMessageParams, RedditWriteRes
7878
transformResponse: async (response: Response) => {
7979
const data = await response.json()
8080

81+
if (!response.ok) {
82+
const errorMsg = data?.message || `HTTP error ${response.status}`
83+
return {
84+
success: false,
85+
output: {
86+
success: false,
87+
message: `Failed to send message: ${errorMsg}`,
88+
},
89+
}
90+
}
91+
8192
if (data.json?.errors && data.json.errors.length > 0) {
8293
const errors = data.json.errors.map((err: any) => err.join(': ')).join(', ')
8394
return {

apps/sim/tools/reddit/utils.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
import { validatePathSegment } from '@/lib/core/security/input-validation'
2+
13
const SUBREDDIT_PREFIX = /^r\//
24

35
/**
46
* Normalizes a subreddit name by removing the 'r/' prefix if present and trimming whitespace.
7+
* Validates the result to prevent path traversal attacks.
58
* @param subreddit - The subreddit name to normalize
69
* @returns The normalized subreddit name without the 'r/' prefix
10+
* @throws Error if the subreddit name contains invalid characters
711
*/
812
export function normalizeSubreddit(subreddit: string): string {
9-
return subreddit.trim().replace(SUBREDDIT_PREFIX, '')
13+
const normalized = subreddit.trim().replace(SUBREDDIT_PREFIX, '')
14+
const validation = validatePathSegment(normalized, { paramName: 'subreddit' })
15+
if (!validation.isValid) {
16+
throw new Error(validation.error)
17+
}
18+
return normalized
1019
}

0 commit comments

Comments
 (0)