From b9e61f4593e4e7551fcac8320c28c886ddeb944e Mon Sep 17 00:00:00 2001 From: Rene Heijdens <101724050+H31nz3l@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:51:04 +0100 Subject: [PATCH] fiix: Improve subscription ID handling and connection management, thanks @H31nz3l (#10485) --------- Co-authored-by: Dmitry Patsura --- .../src/ws/local-subscription-store.ts | 39 ++-- .../src/ws/message-schema.ts | 2 +- .../src/ws/subscription-server.ts | 4 +- .../test/ws/local-subscription-store.test.ts | 167 ++++++++++++++++++ .../test/ws/subscription-server.test.ts | 13 +- 5 files changed, 206 insertions(+), 19 deletions(-) create mode 100644 packages/cubejs-api-gateway/test/ws/local-subscription-store.test.ts diff --git a/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts b/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts index 695e572b3ab2a..ebc8b0f1d16b7 100644 --- a/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts +++ b/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts @@ -2,8 +2,6 @@ interface LocalSubscriptionStoreOptions { heartBeatInterval?: number; } -export type SubscriptionId = string | number; - export type LocalSubscriptionStoreSubscription = { message: any, state: any, @@ -11,21 +9,26 @@ export type LocalSubscriptionStoreSubscription = { }; export type LocalSubscriptionStoreConnection = { - subscriptions: Map, + subscriptions: Map, authContext?: any, }; export class LocalSubscriptionStore { protected readonly connections: Map = new Map(); - protected readonly hearBeatInterval: number; + protected readonly heartBeatInterval: number; public constructor(options: LocalSubscriptionStoreOptions = {}) { - this.hearBeatInterval = options.heartBeatInterval || 60; + this.heartBeatInterval = options.heartBeatInterval || 60; } - public async getSubscription(connectionId: string, subscriptionId: string) { - const connection = this.getConnectionOrCreate(connectionId); + public async getSubscription(connectionId: string, subscriptionId: string): Promise { + // only get subscription, do not create connection if it doesn't exist + const connection = this.getConnection(connectionId); + if (!connection) { + return undefined; + } + return connection.subscriptions.get(subscriptionId); } @@ -37,14 +40,22 @@ export class LocalSubscriptionStore { }); } - public async unsubscribe(connectionId: string, subscriptionId: SubscriptionId) { - const connection = this.getConnectionOrCreate(connectionId); + public async unsubscribe(connectionId: string, subscriptionId: string) { + const connection = this.getConnection(connectionId); + if (!connection) { + return; + } + + if (!connection.subscriptions.has(subscriptionId)) { + return; + } + connection.subscriptions.delete(subscriptionId); } public getAllSubscriptions() { const now = Date.now(); - const staleThreshold = this.hearBeatInterval * 4 * 1000; + const staleThreshold = this.heartBeatInterval * 4 * 1000; const result: Array<{ connectionId: string } & LocalSubscriptionStoreSubscription> = []; for (const [connectionId, connection] of this.connections) { @@ -75,17 +86,21 @@ export class LocalSubscriptionStore { } protected getConnectionOrCreate(connectionId: string): LocalSubscriptionStoreConnection { - const connect = this.connections.get(connectionId); + const connect = this.getConnection(connectionId); if (connect) { return connect; } - const connection = { subscriptions: new Map() }; + const connection: LocalSubscriptionStoreConnection = { subscriptions: new Map() }; this.connections.set(connectionId, connection); return connection; } + protected getConnection(connectionId: string): LocalSubscriptionStoreConnection | undefined { + return this.connections.get(connectionId); + } + public clear() { this.connections.clear(); } diff --git a/packages/cubejs-api-gateway/src/ws/message-schema.ts b/packages/cubejs-api-gateway/src/ws/message-schema.ts index 5730e75672f2c..16469f1e0bc25 100644 --- a/packages/cubejs-api-gateway/src/ws/message-schema.ts +++ b/packages/cubejs-api-gateway/src/ws/message-schema.ts @@ -1,6 +1,6 @@ import { z } from 'zod'; -const messageId = z.union([z.string().max(16), z.int()]); +const messageId = z.union([z.string().max(16), z.int()]).transform(String); const requestId = z.string().max(64).optional(); export const authMessageSchema = z.object({ diff --git a/packages/cubejs-api-gateway/src/ws/subscription-server.ts b/packages/cubejs-api-gateway/src/ws/subscription-server.ts index 8555830e62171..781ce28cfc27c 100644 --- a/packages/cubejs-api-gateway/src/ws/subscription-server.ts +++ b/packages/cubejs-api-gateway/src/ws/subscription-server.ts @@ -38,7 +38,7 @@ export class SubscriptionServer { ) { } - protected resultFn(connectionId: string, messageId: string | number | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) { + protected resultFn(connectionId: string, messageId: string | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) { return async (message, { status } = { status: 200 }) => { if (logNetworkUsage) { this.apiGateway.log({ type: 'Outgoing network usage', service: 'api-ws', bytes: calcMessageLength(message), }, { requestId }); @@ -158,7 +158,7 @@ export class SubscriptionServer { throw new UserError(`Unsupported method: ${message.method}`); } - const subscriptionId = String(message.messageId); + const subscriptionId = message.messageId; const baseRequestId = message.requestId || `${connectionId}-${subscriptionId}`; const requestId = `${baseRequestId}-span-${uuidv4()}`; diff --git a/packages/cubejs-api-gateway/test/ws/local-subscription-store.test.ts b/packages/cubejs-api-gateway/test/ws/local-subscription-store.test.ts new file mode 100644 index 0000000000000..2d45e29828953 --- /dev/null +++ b/packages/cubejs-api-gateway/test/ws/local-subscription-store.test.ts @@ -0,0 +1,167 @@ +import { + LocalSubscriptionStore, +} from '../../src/ws/local-subscription-store'; + +describe('LocalSubscriptionStore', () => { + it('stores and retrieves subscription by id', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', 'sub-1', { + message: { method: 'load' }, + state: { foo: 'bar' } + }); + + const subscription = await store.getSubscription('conn-1', 'sub-1'); + + expect(subscription).toBeDefined(); + expect(subscription?.message).toEqual({ method: 'load' }); + expect(subscription?.state).toEqual({ foo: 'bar' }); + expect(subscription?.timestamp).toBeInstanceOf(Date); + }); + + it('stores and retrieves subscription by string id', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', '123', { + message: { method: 'load' }, + state: { answer: true } + }); + + const result = await store.getSubscription('conn-1', '123'); + + expect(result).toBeDefined(); + expect(result?.state).toEqual({ answer: true }); + }); + + it('does not create a connection when reading missing subscription', async () => { + const store = new LocalSubscriptionStore(); + + const missing = await store.getSubscription('unknown-conn', 'sub-1'); + + expect(missing).toBeUndefined(); + // eslint-disable-next-line dot-notation + expect(store['connections'].size).toBe(0); + }); + + it('does not create a connection when unsubscribing unknown connection', async () => { + const store = new LocalSubscriptionStore(); + + await store.unsubscribe('unknown-conn', 'sub-1'); + + // eslint-disable-next-line dot-notation + expect(store['connections'].size).toBe(0); + }); + + it('unsubscribes existing subscription', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', 'sub-1', { + message: { method: 'load' }, + state: {} + }); + + await store.unsubscribe('conn-1', 'sub-1'); + + const subscription = await store.getSubscription('conn-1', 'sub-1'); + expect(subscription).toBeUndefined(); + }); + + it('returns all active subscriptions with connectionId', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', 'sub-1', { + message: { method: 'load' }, + state: { a: 1 } + }); + await store.subscribe('conn-2', 'sub-2', { + message: { method: 'subscribe' }, + state: { b: 2 } + }); + + const allSubscriptions = store.getAllSubscriptions(); + + expect(allSubscriptions).toHaveLength(2); + expect(allSubscriptions).toEqual(expect.arrayContaining([ + expect.objectContaining({ + connectionId: 'conn-1', + message: { method: 'load' }, + state: { a: 1 } + }), + expect.objectContaining({ + connectionId: 'conn-2', + message: { method: 'subscribe' }, + state: { b: 2 } + }) + ])); + }); + + it('removes stale subscriptions during getAllSubscriptions', async () => { + const store = new LocalSubscriptionStore({ heartBeatInterval: 1 }); + + await store.subscribe('conn-1', 'stale', { + message: { method: 'load' }, + state: {} + }); + await store.subscribe('conn-1', 'active', { + message: { method: 'load' }, + state: {} + }); + + const staleSubscription = await store.getSubscription('conn-1', 'stale'); + expect(staleSubscription).toBeDefined(); + if (!staleSubscription) { + throw new Error('Expected stale subscription to exist'); + } + staleSubscription.timestamp = new Date(Date.now() - 5000); + + const allSubscriptions = store.getAllSubscriptions(); + + expect(allSubscriptions).toHaveLength(1); + expect(allSubscriptions[0].connectionId).toBe('conn-1'); + expect(allSubscriptions[0].message).toEqual({ method: 'load' }); + + const staleAfterCleanup = await store.getSubscription('conn-1', 'stale'); + expect(staleAfterCleanup).toBeUndefined(); + }); + + it('stores and retrieves auth context', async () => { + const store = new LocalSubscriptionStore(); + + const authContext = { securityContext: { userId: 42 } }; + await store.setAuthContext('conn-1', authContext); + + await expect(store.getAuthContext('conn-1')).resolves.toEqual(authContext); + }); + + it('removes connection on disconnect', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', 'sub-1', { + message: { method: 'load' }, + state: {} + }); + + await store.disconnect('conn-1'); + + // eslint-disable-next-line dot-notation + expect(store['connections'].has('conn-1')).toBe(false); + }); + + it('clears all connections', async () => { + const store = new LocalSubscriptionStore(); + + await store.subscribe('conn-1', 'sub-1', { + message: { method: 'load' }, + state: {} + }); + await store.subscribe('conn-2', 'sub-2', { + message: { method: 'subscribe' }, + state: {} + }); + + store.clear(); + + // eslint-disable-next-line dot-notation + expect(store['connections'].size).toBe(0); + }); +}); diff --git a/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts b/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts index 4abc689c25eb8..e4582d94130d4 100644 --- a/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts +++ b/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts @@ -59,13 +59,15 @@ describe('SubscriptionServer', () => { expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 'msg-1'); }); - it('should accept unsubscribe with numeric messageId', async () => { + it('should convert numeric unsubscribe id to string', async () => { const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); await server.processMessage('conn-1', JSON.stringify({ unsubscribe: 123 })); - expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 123); + const callArgs = mockSubscriptionStore.unsubscribe.mock.calls[0]; + expect(typeof callArgs[1]).toBe('string'); + expect(callArgs[1]).toBe('123'); }); it('should accept valid load message', async () => { @@ -83,7 +85,7 @@ describe('SubscriptionServer', () => { expect(sentMessages).toContainEqual({ messageProcessedId: '123' }); }); - it('should accept messageId as number', async () => { + it('should convert numeric messageId to string', async () => { const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); @@ -95,7 +97,10 @@ describe('SubscriptionServer', () => { await server.processMessage('conn-1', JSON.stringify(message)); expect(mockApiGateway.load).toHaveBeenCalled(); - expect(sentMessages).toContainEqual({ messageProcessedId: 123 }); + + const processedMsg = sentMessages.find((m) => m.messageProcessedId !== undefined); + expect(typeof processedMsg.messageProcessedId).toBe('string'); + expect(processedMsg.messageProcessedId).toBe('123'); }); it('should reject invalid JSON payload', async () => {