From c16d87fabfb42afd90df781ea4b306565d966a02 Mon Sep 17 00:00:00 2001 From: lisiqi1983 <169274571+lisiqi1983@users.noreply.github.com> Date: Wed, 13 May 2026 10:44:47 +0800 Subject: [PATCH] fix allowlist literal matching --- src/allowlist/index.test.ts | 106 ++++++++++++++++++++++++++++++++++++ src/allowlist/index.ts | 29 +++++++++- 2 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 src/allowlist/index.test.ts diff --git a/src/allowlist/index.test.ts b/src/allowlist/index.test.ts new file mode 100644 index 0000000..f35b7dc --- /dev/null +++ b/src/allowlist/index.test.ts @@ -0,0 +1,106 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { isQueryAllowed } from './index' +import { DataSource } from '../types' + +const createDataSource = (rows: unknown[] = []): DataSource => + ({ + source: 'internal', + rpc: { + executeQuery: vi.fn().mockResolvedValue(rows), + }, + }) as unknown as DataSource + +describe('Allowlist Module', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('allows queries when allowlist enforcement is disabled', async () => { + const dataSource = createDataSource() + + await expect( + isQueryAllowed({ + sql: 'SELECT * FROM users', + isEnabled: false, + dataSource, + config: { role: 'user' } as any, + }) + ).resolves.toBe(true) + + expect(dataSource.rpc.executeQuery).not.toHaveBeenCalled() + }) + + it('allows admin requests without loading allowlist entries', async () => { + const dataSource = createDataSource() + + await expect( + isQueryAllowed({ + sql: 'DELETE FROM users WHERE id = 1', + isEnabled: true, + dataSource, + config: { role: 'admin' } as any, + }) + ).resolves.toBe(true) + + expect(dataSource.rpc.executeQuery).not.toHaveBeenCalled() + }) + + it('allows structurally matching queries with different literal values', async () => { + const dataSource = createDataSource([ + { + sql_statement: 'SELECT * FROM users WHERE id = 1', + source: 'internal', + }, + ]) + + await expect( + isQueryAllowed({ + sql: 'SELECT * FROM users WHERE id = 42', + isEnabled: true, + dataSource, + config: { role: 'user' } as any, + }) + ).resolves.toBe(true) + }) + + it('normalizes trailing semicolons before comparing queries', async () => { + const dataSource = createDataSource([ + { + sql_statement: 'SELECT * FROM users WHERE id = 1', + source: 'internal', + }, + ]) + + await expect( + isQueryAllowed({ + sql: 'SELECT * FROM users WHERE id = 1;', + isEnabled: true, + dataSource, + config: { role: 'user' } as any, + }) + ).resolves.toBe(true) + }) + + it('records rejected queries before returning a query-not-allowed error', async () => { + const dataSource = createDataSource([ + { + sql_statement: 'SELECT * FROM users WHERE id = 1', + source: 'internal', + }, + ]) + + await expect( + isQueryAllowed({ + sql: 'SELECT email FROM users WHERE id = 1', + isEnabled: true, + dataSource, + config: { role: 'user' } as any, + }) + ).rejects.toThrow('Query not allowed') + + expect(dataSource.rpc.executeQuery).toHaveBeenCalledWith({ + sql: 'INSERT INTO tmp_allowlist_rejections (sql_statement, source) VALUES (?, ?)', + params: ['SELECT email FROM users WHERE id = 1', 'internal'], + }) + }) +}) diff --git a/src/allowlist/index.ts b/src/allowlist/index.ts index 71c12e3..acc616c 100644 --- a/src/allowlist/index.ts +++ b/src/allowlist/index.ts @@ -79,6 +79,20 @@ export async function isQueryAllowed(opts: { const normalizedQuery = parser.astify(normalizeSQL(sql)) // Compare ASTs while ignoring specific values + const literalNodeTypes = new Set([ + 'number', + 'string', + 'bool', + 'null', + 'single_quote_string', + 'double_quote_string', + ]) + + const isLiteralNode = (node: any): boolean => + node !== null && + typeof node === 'object' && + literalNodeTypes.has(node.type) + const deepCompareAst = (allowedAst: any, queryAst: any): boolean => { if (typeof allowedAst !== typeof queryAst) return false @@ -97,9 +111,18 @@ export async function isQueryAllowed(opts: { if (allowedKeys.length !== queryKeys.length) return false - return allowedKeys.every((key) => - deepCompareAst(allowedAst[key], queryAst[key]) - ) + return allowedKeys.every((key) => { + if ( + key === 'value' && + isLiteralNode(allowedAst) && + isLiteralNode(queryAst) && + allowedAst.type === queryAst.type + ) { + return true + } + + return deepCompareAst(allowedAst[key], queryAst[key]) + }) } // Base case: Primitive value comparison