From fc12e7c97b4943b2544b5390ffdc97b256ed9167 Mon Sep 17 00:00:00 2001 From: Maxwell Calkin Date: Sun, 8 Mar 2026 13:03:48 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20respect=20AbortSignal=20in=20run()=20?= =?UTF-8?q?=E2=80=94=20throw=20AbortError=20on=20pre-aborted=20and=20mid-p?= =?UTF-8?q?olling=20abort?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #370. The run() method previously ignored AbortSignal in three ways: 1. If the signal was already aborted before calling run(), the prediction started anyway instead of throwing immediately. 2. During polling, the signal was not passed to HTTP requests or the sleep timer, so abort could not interrupt in-flight requests or waits. 3. After cancellation, run() returned undefined instead of throwing an AbortError, so callers had no way to distinguish abort from failure. Changes: - Check signal.aborted before starting the prediction and throw AbortError - Pass signal through to predictions.create() so the initial HTTP request can be aborted by the signal - Pass signal to wait() and forward it to predictions.get() polling calls - Make sleep() in wait() abort-aware: listen for the abort event to resolve immediately instead of waiting the full interval - Throw AbortError after canceling the prediction, matching standard AbortSignal behavior (e.g. fetch) - Update TypeScript types to include signal option on wait() - Update existing abort test to expect thrown error, add pre-aborted test Co-Authored-By: Claude Opus 4.6 --- index.d.ts | 1 + index.js | 60 +++++++++++++++++++++++++++++++++++++++++++++++---- index.test.ts | 38 +++++++++++++++++++++++--------- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/index.d.ts b/index.d.ts index d4e0784..1c184b2 100644 --- a/index.d.ts +++ b/index.d.ts @@ -232,6 +232,7 @@ declare module "replicate" { prediction: Prediction, options?: { interval?: number; + signal?: AbortSignal; }, stop?: (prediction: Prediction) => Promise ): Promise; diff --git a/index.js b/index.js index 3fc14dd..65c2515 100644 --- a/index.js +++ b/index.js @@ -145,6 +145,13 @@ class Replicate { async run(ref, options, progress) { const { wait = { mode: "block" }, signal, ...data } = options; + // Check if the signal is already aborted before starting + if (signal && signal.aborted) { + const error = new Error("Prediction canceled"); + error.name = "AbortError"; + throw error; + } + const identifier = ModelVersionIdentifier.parse(ref); let prediction; @@ -153,12 +160,14 @@ class Replicate { ...data, version: identifier.version, wait: wait.mode === "block" ? wait.timeout ?? true : false, + signal, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, wait: wait.mode === "block" ? wait.timeout ?? true : false, + signal, }); } else { throw new Error("Invalid model version identifier"); @@ -173,7 +182,10 @@ class Replicate { if (!isDone) { prediction = await this.wait( prediction, - { interval: wait.mode === "poll" ? wait.interval : undefined }, + { + interval: wait.mode === "poll" ? wait.interval : undefined, + signal, + }, async (updatedPrediction) => { // Call progress callback with the updated prediction object if (progress) { @@ -192,6 +204,15 @@ class Replicate { if (signal && signal.aborted) { prediction = await this.predictions.cancel(prediction.id); + + // Call progress callback with the canceled prediction object + if (progress) { + progress(prediction); + } + + const error = new Error("Prediction canceled"); + error.name = "AbortError"; + throw error; } // Call progress callback with the completed prediction object @@ -411,12 +432,32 @@ class Replicate { return prediction; } + const signal = options && options.signal; + // eslint-disable-next-line no-promise-executor-return - const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); + const sleep = (ms) => + new Promise((resolve, reject) => { + if (signal && signal.aborted) { + return resolve(); + } + + const timer = setTimeout(resolve, ms); + + if (signal) { + signal.addEventListener( + "abort", + () => { + clearTimeout(timer); + resolve(); + }, + { once: true } + ); + } + }); const interval = (options && options.interval) || 500; - let updatedPrediction = await this.predictions.get(id); + let updatedPrediction = await this.predictions.get(id, { signal }); while ( updatedPrediction.status !== "succeeded" && @@ -428,8 +469,19 @@ class Replicate { break; } + if (signal && signal.aborted) { + break; + } + await sleep(interval); - updatedPrediction = await this.predictions.get(prediction.id); + + if (signal && signal.aborted) { + break; + } + + updatedPrediction = await this.predictions.get(prediction.id, { + signal, + }); /* eslint-enable no-await-in-loop */ } diff --git a/index.test.ts b/index.test.ts index 96f50db..817a47a 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1598,20 +1598,22 @@ describe("Replicate client", () => { }); const onProgress = jest.fn(); - const output = await client.run( - "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - { - wait: { mode: "poll" }, - input: { text: "Hello, world!" }, - signal, - }, - onProgress - ); + + await expect( + client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + wait: { mode: "poll" }, + input: { text: "Hello, world!" }, + signal, + }, + onProgress + ) + ).rejects.toThrow("Prediction canceled"); expect(body).toBeDefined(); expect(body?.["signal"]).toBeUndefined(); expect(signal.aborted).toBe(true); - expect(output).toBeUndefined(); expect(onProgress).toHaveBeenNthCalledWith( 1, @@ -1635,6 +1637,22 @@ describe("Replicate client", () => { scope.done(); }); + test("Throws immediately when abort signal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + const { signal } = controller; + + await expect( + client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + signal, + } + ) + ).rejects.toThrow("Prediction canceled"); + }); + test("returns FileOutput for URLs when useFileOutput is true", async () => { client = new Replicate({ auth: "foo", useFileOutput: true });