diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e01fae..ea9e5d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ ## Unreleased +### Fix transient memory amplification in `batch_get_documents_tool` (#4) + +- Deduplicate intra-batch URLs by normalized key: multiple input URLs that resolve to the same page (e.g. cache-busting query params) now share a single HTTP request and a single buffered body instead of each triggering a separate fetch +- Add streaming response body size cap via `readBodyWithLimit`: aborts reads that exceed 2 MB before the full body is buffered; also rejects on `Content-Length` before reading begins +- Apply the same body size cap to `get_document_tool` + ### Fix unbounded cache growth in `DocCache` (#2) - Added a 512-entry LRU eviction limit to prevent unbounded Map growth diff --git a/src/tools/batch-get-documents-tool/BatchGetDocumentsTool.ts b/src/tools/batch-get-documents-tool/BatchGetDocumentsTool.ts index c7ec529..a08123a 100644 --- a/src/tools/batch-get-documents-tool/BatchGetDocumentsTool.ts +++ b/src/tools/batch-get-documents-tool/BatchGetDocumentsTool.ts @@ -2,7 +2,12 @@ // Licensed under the MIT License. import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { docCache } from '../../utils/docCache.js'; +import { + docCache, + normalizeCacheKey, + MAX_ENTRY_BYTES, + readBodyWithLimit +} from '../../utils/docCache.js'; import type { HttpRequest } from '../../utils/types.js'; import { BaseTool } from '../BaseTool.js'; import { @@ -60,13 +65,20 @@ export class BatchGetDocumentsTool extends BaseTool< }; } - const results = await Promise.allSettled( - input.urls.map(async (url): Promise => { - const cached = docCache.get(url); - if (cached !== null) { - return { url, content: cached }; - } + // One in-flight fetch per unique normalized URL. Multiple input URLs that + // normalize to the same key (e.g. cache-busting query params) share a + // single HTTP request and a single buffered body. 
+ const inflightByKey = new Map<string, Promise<string>>(); + + const fetchOne = (url: string): Promise<string> => { + const cached = docCache.get(url); + if (cached !== null) return Promise.resolve(cached); + + const key = normalizeCacheKey(url); + const existing = inflightByKey.get(key); + if (existing) return existing; + const promise = (async () => { const response = await this.httpRequest(url, { headers: { Accept: 'text/markdown, text/plain;q=0.9, */*;q=0.8' } }); @@ -75,8 +87,18 @@ export class BatchGetDocumentsTool extends BaseTool< throw new Error(`${response.status} ${response.statusText}`); } - const content = await response.text(); + const content = await readBodyWithLimit(response, MAX_ENTRY_BYTES); docCache.set(url, content); + return content; + })(); + + inflightByKey.set(key, promise); + return promise; + }; + + const results = await Promise.allSettled( + input.urls.map(async (url): Promise<{ url: string; content: string }> => { + const content = await fetchOne(url); return { url, content }; }) ); diff --git a/src/tools/get-document-tool/GetDocumentTool.ts b/src/tools/get-document-tool/GetDocumentTool.ts index 07a4d26..c723144 100644 --- a/src/tools/get-document-tool/GetDocumentTool.ts +++ b/src/tools/get-document-tool/GetDocumentTool.ts @@ -2,7 +2,11 @@ // Licensed under the MIT License.
import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { docCache } from '../../utils/docCache.js'; +import { + docCache, + MAX_ENTRY_BYTES, + readBodyWithLimit +} from '../../utils/docCache.js'; import type { HttpRequest } from '../../utils/types.js'; import { BaseTool } from '../BaseTool.js'; import { @@ -73,7 +77,7 @@ export class GetDocumentTool extends BaseTool { }; } - const content = await response.text(); + const content = await readBodyWithLimit(response, MAX_ENTRY_BYTES); docCache.set(input.url, content); return { content: [{ type: 'text', text: content }], isError: false }; diff --git a/src/utils/docCache.ts b/src/utils/docCache.ts index ceefef9..b186735 100644 --- a/src/utils/docCache.ts +++ b/src/utils/docCache.ts @@ -8,7 +8,7 @@ const DEFAULT_TTL_MS = parseInt( // Cache limits const MAX_ENTRIES = 512; -const MAX_ENTRY_BYTES = 2 * 1024 * 1024; // 2 MB per entry +export const MAX_ENTRY_BYTES = 2 * 1024 * 1024; // 2 MB per entry const MAX_TOTAL_BYTES = 50 * 1024 * 1024; // 50 MB total interface CacheEntry { @@ -104,3 +104,51 @@ class DocCache { } export const docCache = new DocCache(); + +/** + * Read a Response body up to `maxBytes`, aborting early if the limit is + * exceeded. Checks Content-Length first when present so no bytes are + * buffered for obviously-oversized responses. 
+ */ +export async function readBodyWithLimit( + response: Response, + maxBytes: number +): Promise<string> { + const contentLength = response.headers.get('content-length'); + if (contentLength) { + const cl = parseInt(contentLength, 10); + if (Number.isFinite(cl) && cl > maxBytes) { + throw new Error( + `Response too large: Content-Length ${cl} exceeds limit of ${maxBytes} bytes` + ); + } + } + + if (!response.body) { + const text = await response.text(); + if (Buffer.byteLength(text, 'utf8') > maxBytes) { + throw new Error('Response too large'); + } + return text; + } + + const chunks: Buffer[] = []; + let totalBytes = 0; + const reader = response.body.getReader(); + + try { + for (;;) { + const { done, value } = await reader.read(); + if (done) break; + totalBytes += value.byteLength; + if (totalBytes > maxBytes) { + throw new Error('Response too large'); + } + chunks.push(Buffer.from(value)); + } + } finally { + reader.releaseLock(); + } + + return Buffer.concat(chunks).toString('utf8'); +} diff --git a/test/tools/batch-get-documents-tool/BatchGetDocumentsTool.test.ts b/test/tools/batch-get-documents-tool/BatchGetDocumentsTool.test.ts new file mode 100644 index 0000000..cc3ae66 --- /dev/null +++ b/test/tools/batch-get-documents-tool/BatchGetDocumentsTool.test.ts @@ -0,0 +1,147 @@ +// Copyright (c) Mapbox, Inc. +// Licensed under the MIT License.
+ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { BatchGetDocumentsTool } from '../../../src/tools/batch-get-documents-tool/BatchGetDocumentsTool.js'; +import { docCache } from '../../../src/utils/docCache.js'; + +beforeEach(() => { + docCache.clear(); +}); + +function makeResponse(body: string, status = 200): Response { + return new Response(body, { + status, + headers: { + 'content-type': 'text/plain', + 'content-length': String(Buffer.byteLength(body, 'utf8')) + } + }); +} + +describe('BatchGetDocumentsTool', () => { + describe('intra-batch URL deduplication', () => { + it('issues only one HTTP request for multiple URLs with the same normalized key', async () => { + const httpRequest = vi + .fn() + .mockResolvedValue(makeResponse('page content')); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const urls = [ + 'https://docs.mapbox.com/page?bust=1', + 'https://docs.mapbox.com/page?bust=2', + 'https://docs.mapbox.com/page?bust=3' + ]; + + const result = await tool.run({ urls }); + + expect(httpRequest).toHaveBeenCalledTimes(1); + expect(result.isError).toBe(false); + + const output = JSON.parse((result.content[0] as { text: string }).text); + expect(output).toHaveLength(3); + expect( + output.every((r: { content: string }) => r.content === 'page content') + ).toBe(true); + }); + + it('issues one request per distinct normalized URL', async () => { + const httpRequest = vi + .fn() + .mockResolvedValueOnce(makeResponse('page A')) + .mockResolvedValueOnce(makeResponse('page B')); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const urls = [ + 'https://docs.mapbox.com/a?x=1', + 'https://docs.mapbox.com/a?x=2', + 'https://docs.mapbox.com/b?x=1' + ]; + + const result = await tool.run({ urls }); + + expect(httpRequest).toHaveBeenCalledTimes(2); + expect(result.isError).toBe(false); + + const output = JSON.parse((result.content[0] as { text: string }).text); + expect(output[0].content).toBe('page A'); + 
expect(output[1].content).toBe('page A'); + expect(output[2].content).toBe('page B'); + }); + + it('uses cached content and skips fetch for already-cached normalized URL', async () => { + docCache.set('https://docs.mapbox.com/page', 'cached content'); + const httpRequest = vi.fn(); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const result = await tool.run({ + urls: [ + 'https://docs.mapbox.com/page?bust=1', + 'https://docs.mapbox.com/page?bust=2' + ] + }); + + expect(httpRequest).not.toHaveBeenCalled(); + const output = JSON.parse((result.content[0] as { text: string }).text); + expect( + output.every((r: { content: string }) => r.content === 'cached content') + ).toBe(true); + }); + }); + + describe('response body size limit', () => { + it('returns an error for a URL whose Content-Length exceeds the limit', async () => { + const oversizeHeaders = new Headers({ + 'content-type': 'text/plain', + 'content-length': String(3 * 1024 * 1024) // 3 MB > 2 MB limit + }); + const httpRequest = vi + .fn() + .mockResolvedValue( + new Response('x', { status: 200, headers: oversizeHeaders }) + ); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const result = await tool.run({ + urls: ['https://docs.mapbox.com/page'] + }); + + expect(result.isError).toBe(false); // batch doesn't fail entirely + const output = JSON.parse((result.content[0] as { text: string }).text); + expect(output[0].error).toMatch(/too large/i); + }); + }); + + describe('invalid URLs', () => { + it('rejects non-mapbox URLs', async () => { + const httpRequest = vi.fn(); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const result = await tool.run({ + urls: ['https://evil.com/page'] + }); + + expect(result.isError).toBe(true); + expect(httpRequest).not.toHaveBeenCalled(); + }); + }); + + describe('HTTP errors', () => { + it('returns per-URL error on non-ok response', async () => { + const httpRequest = vi + .fn() + .mockResolvedValue( + new Response('Not Found', { 
status: 404, statusText: 'Not Found' }) + ); + const tool = new BatchGetDocumentsTool({ httpRequest }); + + const result = await tool.run({ + urls: ['https://docs.mapbox.com/missing'] + }); + + expect(result.isError).toBe(false); + const output = JSON.parse((result.content[0] as { text: string }).text); + expect(output[0].error).toBe('404 Not Found'); + }); + }); +});