Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/allowlist/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isQueryAllowed } from './index'
import { DataSource } from '../types'

const createDataSource = (rows: unknown[] = []): DataSource =>
({
source: 'internal',
rpc: {
executeQuery: vi.fn().mockResolvedValue(rows),
},
}) as unknown as DataSource

describe('Allowlist Module', () => {
beforeEach(() => {
vi.clearAllMocks()
})

it('allows queries when allowlist enforcement is disabled', async () => {
const dataSource = createDataSource()

await expect(
isQueryAllowed({
sql: 'SELECT * FROM users',
isEnabled: false,
dataSource,
config: { role: 'user' } as any,
})
).resolves.toBe(true)

expect(dataSource.rpc.executeQuery).not.toHaveBeenCalled()
})

it('allows admin requests without loading allowlist entries', async () => {
const dataSource = createDataSource()

await expect(
isQueryAllowed({
sql: 'DELETE FROM users WHERE id = 1',
isEnabled: true,
dataSource,
config: { role: 'admin' } as any,
})
).resolves.toBe(true)

expect(dataSource.rpc.executeQuery).not.toHaveBeenCalled()
})

it('allows structurally matching queries with different literal values', async () => {
const dataSource = createDataSource([
{
sql_statement: 'SELECT * FROM users WHERE id = 1',
source: 'internal',
},
])

await expect(
isQueryAllowed({
sql: 'SELECT * FROM users WHERE id = 42',
isEnabled: true,
dataSource,
config: { role: 'user' } as any,
})
).resolves.toBe(true)
})

it('normalizes trailing semicolons before comparing queries', async () => {
const dataSource = createDataSource([
{
sql_statement: 'SELECT * FROM users WHERE id = 1',
source: 'internal',
},
])

await expect(
isQueryAllowed({
sql: 'SELECT * FROM users WHERE id = 1;',
isEnabled: true,
dataSource,
config: { role: 'user' } as any,
})
).resolves.toBe(true)
})

it('records rejected queries before returning a query-not-allowed error', async () => {
const dataSource = createDataSource([
{
sql_statement: 'SELECT * FROM users WHERE id = 1',
source: 'internal',
},
])

await expect(
isQueryAllowed({
sql: 'SELECT email FROM users WHERE id = 1',
isEnabled: true,
dataSource,
config: { role: 'user' } as any,
})
).rejects.toThrow('Query not allowed')

expect(dataSource.rpc.executeQuery).toHaveBeenCalledWith({
sql: 'INSERT INTO tmp_allowlist_rejections (sql_statement, source) VALUES (?, ?)',
params: ['SELECT email FROM users WHERE id = 1', 'internal'],
})
})
})
29 changes: 26 additions & 3 deletions src/allowlist/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ export async function isQueryAllowed(opts: {
const normalizedQuery = parser.astify(normalizeSQL(sql))

// Compare ASTs while ignoring specific values
const literalNodeTypes = new Set([
'number',
'string',
'bool',
'null',
'single_quote_string',
'double_quote_string',
])

const isLiteralNode = (node: any): boolean =>
node !== null &&
typeof node === 'object' &&
literalNodeTypes.has(node.type)

const deepCompareAst = (allowedAst: any, queryAst: any): boolean => {
if (typeof allowedAst !== typeof queryAst) return false

Expand All @@ -97,9 +111,18 @@ export async function isQueryAllowed(opts: {

if (allowedKeys.length !== queryKeys.length) return false

return allowedKeys.every((key) =>
deepCompareAst(allowedAst[key], queryAst[key])
)
return allowedKeys.every((key) => {
if (
key === 'value' &&
isLiteralNode(allowedAst) &&
isLiteralNode(queryAst) &&
allowedAst.type === queryAst.type
) {
return true
}

return deepCompareAst(allowedAst[key], queryAst[key])
})
}

// Base case: Primitive value comparison
Expand Down