From c5e01aa2d3f77954f95c07938eb368e9ec6c6a2c Mon Sep 17 00:00:00 2001 From: shogun444 Date: Thu, 2 Jul 2026 20:42:24 +0530 Subject: [PATCH] fix(#701): persist thread messages before switching to new chat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add optional cacheMessages method to ThreadStorage to allow storages to cache current messages before a thread switch. The default in-memory storage implements it by saving to the internal messages map. - ThreadStorage: add cacheMessages?(threadId, messages) — optional - _defaultStorage: implement cacheMessages to persist to in-memory Map - switchToNewThread: save current messages before clearing state - processMessage: persist messages after successful LLM streaming - Tests: verify cacheMessages is called in both paths with correct data --- .../src/adapters/_defaultStorage.ts | 3 ++ packages/react-headless/src/adapters/types.ts | 2 + .../store/__tests__/__helpers/makeStore.ts | 1 + .../store/__tests__/createChatStore.test.ts | 53 +++++++++++++++++++ .../src/store/createChatStore.ts | 7 +++ 5 files changed, 66 insertions(+) diff --git a/packages/react-headless/src/adapters/_defaultStorage.ts b/packages/react-headless/src/adapters/_defaultStorage.ts index 85fec9617..8ff010fae 100644 --- a/packages/react-headless/src/adapters/_defaultStorage.ts +++ b/packages/react-headless/src/adapters/_defaultStorage.ts @@ -35,6 +35,9 @@ export function createDefaultInMemoryStorage(): ChatStorage { async getMessages(threadId: string) { return messagesByThread.get(threadId) ?? []; }, + async cacheMessages(threadId: string, messages: Message[]) { + messagesByThread.set(threadId, messages); + }, async updateThread(thread: Thread) { threads = threads.map((t) => (t.id === thread.id ? thread : t)); return thread; diff --git a/packages/react-headless/src/adapters/types.ts b/packages/react-headless/src/adapters/types.ts index 6c8af8828..c5b9783ba 100644 --- a/packages/react-headless/src/adapters/types.ts +++ b/packages/react-headless/src/adapters/types.ts @@ -9,6 +9,8 @@ export interface ThreadStorage { listThreads(cursor?: string): Promise<{ threads: Thread[]; nextCursor?: string }>; createThread(firstMessage: UserMessage): Promise; getMessages(threadId: string): Promise; + /** Optional — cache current messages for a thread. Not all storages support it. */ + cacheMessages?(threadId: string, messages: Message[]): Promise; updateThread(thread: Thread): Promise; deleteThread(id: string): Promise; } diff --git a/packages/react-headless/src/store/__tests__/__helpers/makeStore.ts b/packages/react-headless/src/store/__tests__/__helpers/makeStore.ts index 4f08e1e86..54bfcd2cf 100644 --- a/packages/react-headless/src/store/__tests__/__helpers/makeStore.ts +++ b/packages/react-headless/src/store/__tests__/__helpers/makeStore.ts @@ -25,6 +25,7 @@ export function makeStore(overrides: MakeStoreOverrides = {}) { }), getMessages: vi.fn().mockResolvedValue([]), updateThread: vi.fn(async (t) => t), + cacheMessages: vi.fn().mockResolvedValue(undefined), deleteThread: vi.fn().mockResolvedValue(undefined), ...threadOverrides, }, diff --git a/packages/react-headless/src/store/__tests__/createChatStore.test.ts b/packages/react-headless/src/store/__tests__/createChatStore.test.ts index 1d28dfbab..d755c1d7e 100644 --- a/packages/react-headless/src/store/__tests__/createChatStore.test.ts +++ b/packages/react-headless/src/store/__tests__/createChatStore.test.ts @@ -156,6 +156,41 @@ describe("createChatStore", () => { expect(store.getState().messages).toEqual([]); expect(store.getState().threadError).toBeNull(); }); + + it("persists current messages via cacheMessages before clearing", () => { + const cacheMessages = vi.fn().mockResolvedValue(undefined); + const store = makeStore({ cacheMessages }); + + const msgs = [makeMessage("m1"), makeMessage("m2", "assistant")]; + store.setState({ + selectedThreadId: "t1", + messages: msgs, + }); + + store.getState().switchToNewThread(); + + expect(cacheMessages).toHaveBeenCalledWith("t1", msgs); + }); + + it("does not call cacheMessages when no thread is selected", () => { + const cacheMessages = vi.fn().mockResolvedValue(undefined); + const store = makeStore({ cacheMessages }); + + store.setState({ messages: [makeMessage("m1")] }); + store.getState().switchToNewThread(); + + expect(cacheMessages).not.toHaveBeenCalled(); + }); + + it("does not call cacheMessages when messages are empty", () => { + const cacheMessages = vi.fn().mockResolvedValue(undefined); + const store = makeStore({ cacheMessages }); + + store.setState({ selectedThreadId: "t1", messages: [] }); + store.getState().switchToNewThread(); + + expect(cacheMessages).not.toHaveBeenCalled(); + }); }); describe("createThread", () => { @@ -324,6 +359,24 @@ describe("createChatStore", () => { expect(store.getState().selectedThreadId).toBe("t-auto"); }); + it("persists messages via cacheMessages after streaming completes", async () => { + const cacheMessages = vi.fn().mockResolvedValue(undefined); + const send = vi.fn().mockResolvedValue(new Response("", { status: 200 })); + + const store = makeStore({ + cacheMessages, + send, + streamProtocol: { parse: async function* () {} }, + }); + store.setState({ selectedThreadId: "t1" }); + + await store.getState().processMessage({ role: "user", content: "hello" }); + + const finalMessages = store.getState().messages; + expect(finalMessages.length).toBeGreaterThan(0); + expect(cacheMessages).toHaveBeenCalledWith("t1", finalMessages); + }); + it("no-ops when already running", async () => { const send = vi.fn().mockResolvedValue(new Response("", { status: 200 })); diff --git a/packages/react-headless/src/store/createChatStore.ts b/packages/react-headless/src/store/createChatStore.ts index ef26cfa0c..6e7caae87 100644 --- a/packages/react-headless/src/store/createChatStore.ts +++ b/packages/react-headless/src/store/createChatStore.ts @@ -74,6 +74,10 @@ export const createChatStore = (config: CreateChatStoreConfig) => { switchToNewThread: () => { get().cancelMessage(); + const { selectedThreadId, messages } = get(); + if (selectedThreadId && messages.length > 0) { + threadStorage.cacheMessages?.(selectedThreadId, messages).catch(() => {}); + } set({ selectedThreadId: null, messages: [], @@ -201,6 +205,9 @@ export const createChatStore = (config: CreateChatStoreConfig) => { }), adapter: llm.streamProtocol, }); + + // Persist messages after successful streaming so they survive thread switches. + await threadStorage.cacheMessages?.(threadId, get().messages); } catch (e) { if (!abortController.signal.aborted) { set({ threadError: e instanceof Error ? e : new Error(String(e)) });