Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions packages/cubejs-api-gateway/src/ws/local-subscription-store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,33 @@ interface LocalSubscriptionStoreOptions {
heartBeatInterval?: number;
}

export type SubscriptionId = string | number;

export type LocalSubscriptionStoreSubscription = {
message: any,
state: any,
timestamp: Date,
};

export type LocalSubscriptionStoreConnection = {
subscriptions: Map<SubscriptionId, LocalSubscriptionStoreSubscription>,
subscriptions: Map<string, LocalSubscriptionStoreSubscription>,
authContext?: any,
};

export class LocalSubscriptionStore {
protected readonly connections: Map<string, LocalSubscriptionStoreConnection> = new Map();

protected readonly hearBeatInterval: number;
protected readonly heartBeatInterval: number;

public constructor(options: LocalSubscriptionStoreOptions = {}) {
this.hearBeatInterval = options.heartBeatInterval || 60;
this.heartBeatInterval = options.heartBeatInterval || 60;
}

public async getSubscription(connectionId: string, subscriptionId: string) {
const connection = this.getConnectionOrCreate(connectionId);
public async getSubscription(connectionId: string, subscriptionId: string): Promise<LocalSubscriptionStoreSubscription | undefined> {
// only get subscription, do not create connection if it doesn't exist
const connection = this.getConnection(connectionId);
if (!connection) {
return undefined;
}

return connection.subscriptions.get(subscriptionId);
}

Expand All @@ -37,14 +40,22 @@ export class LocalSubscriptionStore {
});
}

public async unsubscribe(connectionId: string, subscriptionId: SubscriptionId) {
const connection = this.getConnectionOrCreate(connectionId);
public async unsubscribe(connectionId: string, subscriptionId: string) {
const connection = this.getConnection(connectionId);
if (!connection) {
return;
}

if (!connection.subscriptions.has(subscriptionId)) {
return;
}

connection.subscriptions.delete(subscriptionId);
}

public getAllSubscriptions() {
const now = Date.now();
const staleThreshold = this.hearBeatInterval * 4 * 1000;
const staleThreshold = this.heartBeatInterval * 4 * 1000;
const result: Array<{ connectionId: string } & LocalSubscriptionStoreSubscription> = [];

for (const [connectionId, connection] of this.connections) {
Expand Down Expand Up @@ -75,17 +86,21 @@ export class LocalSubscriptionStore {
}

protected getConnectionOrCreate(connectionId: string): LocalSubscriptionStoreConnection {
const connect = this.connections.get(connectionId);
const connect = this.getConnection(connectionId);
if (connect) {
return connect;
}

const connection = { subscriptions: new Map() };
const connection: LocalSubscriptionStoreConnection = { subscriptions: new Map<string, LocalSubscriptionStoreSubscription>() };
this.connections.set(connectionId, connection);

return connection;
}

protected getConnection(connectionId: string): LocalSubscriptionStoreConnection | undefined {
return this.connections.get(connectionId);
}

public clear() {
this.connections.clear();
}
Expand Down
2 changes: 1 addition & 1 deletion packages/cubejs-api-gateway/src/ws/message-schema.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { z } from 'zod';

const messageId = z.union([z.string().max(16), z.int()]);
const messageId = z.union([z.string().max(16), z.int()]).transform(String);
const requestId = z.string().max(64).optional();

export const authMessageSchema = z.object({
Expand Down
4 changes: 2 additions & 2 deletions packages/cubejs-api-gateway/src/ws/subscription-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export class SubscriptionServer {
) {
}

protected resultFn(connectionId: string, messageId: string | number | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) {
protected resultFn(connectionId: string, messageId: string | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) {
return async (message, { status } = { status: 200 }) => {
if (logNetworkUsage) {
this.apiGateway.log({ type: 'Outgoing network usage', service: 'api-ws', bytes: calcMessageLength(message), }, { requestId });
Expand Down Expand Up @@ -158,7 +158,7 @@ export class SubscriptionServer {
throw new UserError(`Unsupported method: ${message.method}`);
}

const subscriptionId = String(message.messageId);
const subscriptionId = message.messageId;
const baseRequestId = message.requestId || `${connectionId}-${subscriptionId}`;
const requestId = `${baseRequestId}-span-${uuidv4()}`;

Expand Down
167 changes: 167 additions & 0 deletions packages/cubejs-api-gateway/test/ws/local-subscription-store.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import {
LocalSubscriptionStore,
} from '../../src/ws/local-subscription-store';

describe('LocalSubscriptionStore', () => {
it('stores and retrieves subscription by id', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', 'sub-1', {
message: { method: 'load' },
state: { foo: 'bar' }
});

const subscription = await store.getSubscription('conn-1', 'sub-1');

expect(subscription).toBeDefined();
expect(subscription?.message).toEqual({ method: 'load' });
expect(subscription?.state).toEqual({ foo: 'bar' });
expect(subscription?.timestamp).toBeInstanceOf(Date);
});

it('stores and retrieves subscription by string id', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', '123', {
message: { method: 'load' },
state: { answer: true }
});

const result = await store.getSubscription('conn-1', '123');

expect(result).toBeDefined();
expect(result?.state).toEqual({ answer: true });
});

it('does not create a connection when reading missing subscription', async () => {
const store = new LocalSubscriptionStore();

const missing = await store.getSubscription('unknown-conn', 'sub-1');

expect(missing).toBeUndefined();
// eslint-disable-next-line dot-notation
expect(store['connections'].size).toBe(0);
});

it('does not create a connection when unsubscribing unknown connection', async () => {
const store = new LocalSubscriptionStore();

await store.unsubscribe('unknown-conn', 'sub-1');

// eslint-disable-next-line dot-notation
expect(store['connections'].size).toBe(0);
});

it('unsubscribes existing subscription', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', 'sub-1', {
message: { method: 'load' },
state: {}
});

await store.unsubscribe('conn-1', 'sub-1');

const subscription = await store.getSubscription('conn-1', 'sub-1');
expect(subscription).toBeUndefined();
});

it('returns all active subscriptions with connectionId', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', 'sub-1', {
message: { method: 'load' },
state: { a: 1 }
});
await store.subscribe('conn-2', 'sub-2', {
message: { method: 'subscribe' },
state: { b: 2 }
});

const allSubscriptions = store.getAllSubscriptions();

expect(allSubscriptions).toHaveLength(2);
expect(allSubscriptions).toEqual(expect.arrayContaining([
expect.objectContaining({
connectionId: 'conn-1',
message: { method: 'load' },
state: { a: 1 }
}),
expect.objectContaining({
connectionId: 'conn-2',
message: { method: 'subscribe' },
state: { b: 2 }
})
]));
});

it('removes stale subscriptions during getAllSubscriptions', async () => {
const store = new LocalSubscriptionStore({ heartBeatInterval: 1 });

await store.subscribe('conn-1', 'stale', {
message: { method: 'load' },
state: {}
});
await store.subscribe('conn-1', 'active', {
message: { method: 'load' },
state: {}
});

const staleSubscription = await store.getSubscription('conn-1', 'stale');
expect(staleSubscription).toBeDefined();
if (!staleSubscription) {
throw new Error('Expected stale subscription to exist');
}
staleSubscription.timestamp = new Date(Date.now() - 5000);

const allSubscriptions = store.getAllSubscriptions();

expect(allSubscriptions).toHaveLength(1);
expect(allSubscriptions[0].connectionId).toBe('conn-1');
expect(allSubscriptions[0].message).toEqual({ method: 'load' });

const staleAfterCleanup = await store.getSubscription('conn-1', 'stale');
expect(staleAfterCleanup).toBeUndefined();
});

it('stores and retrieves auth context', async () => {
const store = new LocalSubscriptionStore();

const authContext = { securityContext: { userId: 42 } };
await store.setAuthContext('conn-1', authContext);

await expect(store.getAuthContext('conn-1')).resolves.toEqual(authContext);
});

it('removes connection on disconnect', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', 'sub-1', {
message: { method: 'load' },
state: {}
});

await store.disconnect('conn-1');

// eslint-disable-next-line dot-notation
expect(store['connections'].has('conn-1')).toBe(false);
});

it('clears all connections', async () => {
const store = new LocalSubscriptionStore();

await store.subscribe('conn-1', 'sub-1', {
message: { method: 'load' },
state: {}
});
await store.subscribe('conn-2', 'sub-2', {
message: { method: 'subscribe' },
state: {}
});

store.clear();

// eslint-disable-next-line dot-notation
expect(store['connections'].size).toBe(0);
});
});
13 changes: 9 additions & 4 deletions packages/cubejs-api-gateway/test/ws/subscription-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ describe('SubscriptionServer', () => {
expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 'msg-1');
});

it('should accept unsubscribe with numeric messageId', async () => {
it('should convert numeric unsubscribe id to string', async () => {
const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks();
const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor);

await server.processMessage('conn-1', JSON.stringify({ unsubscribe: 123 }));

expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 123);
const callArgs = mockSubscriptionStore.unsubscribe.mock.calls[0];
expect(typeof callArgs[1]).toBe('string');
expect(callArgs[1]).toBe('123');
});

it('should accept valid load message', async () => {
Expand All @@ -83,7 +85,7 @@ describe('SubscriptionServer', () => {
expect(sentMessages).toContainEqual({ messageProcessedId: '123' });
});

it('should accept messageId as number', async () => {
it('should convert numeric messageId to string', async () => {
const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks();
const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor);

Expand All @@ -95,7 +97,10 @@ describe('SubscriptionServer', () => {
await server.processMessage('conn-1', JSON.stringify(message));

expect(mockApiGateway.load).toHaveBeenCalled();
expect(sentMessages).toContainEqual({ messageProcessedId: 123 });

const processedMsg = sentMessages.find((m) => m.messageProcessedId !== undefined);
expect(typeof processedMsg.messageProcessedId).toBe('string');
expect(processedMsg.messageProcessedId).toBe('123');
});

it('should reject invalid JSON payload', async () => {
Expand Down
Loading