Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import { Test, TestingModule } from '@nestjs/testing';
import { InternalChecksController } from './internal-checks.controller';
import { HybridAuthGuard } from '../../auth/hybrid-auth.guard';
import { PermissionGuard } from '../../auth/permission.guard';
import { ServiceTokenOnlyGuard } from '../../auth/service-token-only.guard';
import { ConnectionCheckRunnerService } from '../services/connection-check-runner.service';

jest.mock('@db', () => ({ db: {} }));
jest.mock('../../auth/auth.server', () => ({
auth: { api: { getSession: jest.fn() } },
}));
jest.mock('@trycompai/auth', () => ({
statement: { integration: ['create', 'read', 'update', 'delete'] },
BUILT_IN_ROLE_PERMISSIONS: {},
}));

describe('InternalChecksController', () => {
let controller: InternalChecksController;
const mockRunner = { runChecks: jest.fn() };
const mockGuard = { canActivate: jest.fn().mockReturnValue(true) };

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
controllers: [InternalChecksController],
providers: [
{ provide: ConnectionCheckRunnerService, useValue: mockRunner },
],
})
.overrideGuard(HybridAuthGuard)
.useValue(mockGuard)
.overrideGuard(ServiceTokenOnlyGuard)
.useValue(mockGuard)
.overrideGuard(PermissionGuard)
.useValue(mockGuard)
.compile();

controller = module.get(InternalChecksController);
jest.clearAllMocks();
});

it('delegates to the runner with the connection, org and checkId', async () => {
const runResult = { results: [], totalFindings: 0, totalPassing: 0 };
mockRunner.runChecks.mockResolvedValue(runResult);

const result = await controller.runConnectionChecks('conn_1', 'org_1', {
checkId: 'aws-s3-public-access',
});

expect(mockRunner.runChecks).toHaveBeenCalledWith({
connectionId: 'conn_1',
organizationId: 'org_1',
checkId: 'aws-s3-public-access',
});
expect(result).toBe(runResult);
});

it('passes checkId undefined when omitted (run all)', async () => {
mockRunner.runChecks.mockResolvedValue({});
await controller.runConnectionChecks('conn_1', 'org_1', {});
expect(mockRunner.runChecks).toHaveBeenCalledWith({
connectionId: 'conn_1',
organizationId: 'org_1',
checkId: undefined,
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { Body, Controller, Param, Post, UseGuards } from '@nestjs/common';
import {
ApiBody,
ApiOperation,
ApiPropertyOptional,
ApiTags,
} from '@nestjs/swagger';
import { SkipThrottle } from '@nestjs/throttler';
import { IsOptional, IsString } from 'class-validator';
import { HybridAuthGuard } from '../../auth/hybrid-auth.guard';
import { PermissionGuard } from '../../auth/permission.guard';
import { ServiceTokenOnlyGuard } from '../../auth/service-token-only.guard';
import { RequirePermission } from '../../auth/require-permission.decorator';
import { OrganizationId } from '../../auth/auth-context.decorator';
import {
ConnectionCheckRunnerService,
type RunAllChecksResult,
} from '../services/connection-check-runner.service';

// Internal payload. Service-token only — never called by the UI/customers.
class RunConnectionChecksOnServerDto {
@ApiPropertyOptional({
description:
"Run a single check. Omit to run all of the connection's checks.",
})
@IsOptional()
@IsString()
checkId?: string;
}

/**
* Internal, service-token-only endpoint that runs a connection's checks ON OUR
* SERVER and returns the raw result (no persistence). Used exclusively by the
* AWS Trigger tasks so AWS S3 calls egress our VPC instead of Trigger.dev's
* (whose endpoint policy blocks our cross-account reads). All other providers
* keep executing inside Trigger.dev unchanged.
*/
@Controller({ path: 'integrations/internal', version: '1' })
@ApiTags('Integrations')
export class InternalChecksController {
constructor(private readonly runner: ConnectionCheckRunnerService) {}

@Post('run-connection-checks/:connectionId')
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
// Called by the AWS Trigger tasks in bursts (the 6 AM schedule fans out across
// every AWS connection/check). Exempt from the global rate limiter so the burst
// doesn't hit 429s and re-fail the very checks this path exists to fix.
@SkipThrottle()
@UseGuards(HybridAuthGuard, ServiceTokenOnlyGuard, PermissionGuard)
@RequirePermission('integration', 'update')
@ApiOperation({
summary: "Run a connection's checks on the API server (internal only)",
})
@ApiBody({ type: RunConnectionChecksOnServerDto })
async runConnectionChecks(
@Param('connectionId') connectionId: string,
@OrganizationId() organizationId: string,
@Body() body: RunConnectionChecksOnServerDto,
): Promise<RunAllChecksResult> {
return this.runner.runChecks({
connectionId,
organizationId,
checkId: body.checkId,
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ConnectionsController } from './controllers/connections.controller';
import { AdminIntegrationsController } from './controllers/admin-integrations.controller';
import { DynamicIntegrationsController } from './controllers/dynamic-integrations.controller';
import { ChecksController } from './controllers/checks.controller';
import { InternalChecksController } from './controllers/internal-checks.controller';
import { VariablesController } from './controllers/variables.controller';
import { TaskIntegrationsController } from './controllers/task-integrations.controller';
import { WebhookController } from './controllers/webhook.controller';
Expand All @@ -20,6 +21,7 @@ import { ConnectionAuthTeardownService } from './services/connection-auth-teardo
import { OAuthTokenRevocationService } from './services/oauth-token-revocation.service';
import { DynamicManifestLoaderService } from './services/dynamic-manifest-loader.service';
import { TaskIntegrationChecksService } from './services/task-integration-checks.service';
import { ConnectionCheckRunnerService } from './services/connection-check-runner.service';
import { ProviderRepository } from './repositories/provider.repository';
import { ConnectionRepository } from './repositories/connection.repository';
import { CredentialRepository } from './repositories/credential.repository';
Expand All @@ -42,6 +44,7 @@ import { GenericDeviceSyncService } from './services/generic-device-sync.service
AdminIntegrationsController,
DynamicIntegrationsController,
ChecksController,
InternalChecksController,
VariablesController,
TaskIntegrationsController,
WebhookController,
Expand All @@ -58,6 +61,7 @@ import { GenericDeviceSyncService } from './services/generic-device-sync.service
ConnectionAuthTeardownService,
DynamicManifestLoaderService,
TaskIntegrationChecksService,
ConnectionCheckRunnerService,
IntegrationSyncLoggerService,
GenericEmployeeSyncService,
GenericDeviceSyncService,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import { Test, TestingModule } from '@nestjs/testing';
import { BadRequestException, NotFoundException } from '@nestjs/common';
import { ConnectionCheckRunnerService } from './connection-check-runner.service';
import { ConnectionRepository } from '../repositories/connection.repository';
import { ProviderRepository } from '../repositories/provider.repository';
import { CredentialVaultService } from './credential-vault.service';
import { OAuthCredentialsService } from './oauth-credentials.service';

jest.mock('@db', () => ({ db: {} }));

jest.mock('@trycompai/integration-platform', () => ({
getManifest: jest.fn(),
runAllChecks: jest.fn(),
}));

import { getManifest, runAllChecks } from '@trycompai/integration-platform';

const mockedGetManifest = getManifest as jest.Mock;
const mockedRunAllChecks = runAllChecks as jest.Mock;

const AWS_MANIFEST = {
id: 'aws',
name: 'AWS',
auth: { type: 'custom' },
checks: [{ id: 'aws-s3-public-access', name: 'S3 public access' }],
};

const RUN_RESULT = {
results: [{ checkId: 'aws-s3-public-access', status: 'success', result: {} }],
totalFindings: 0,
totalPassing: 3,
};

describe('ConnectionCheckRunnerService', () => {
let service: ConnectionCheckRunnerService;

const mockConnectionRepository = { findById: jest.fn() };
const mockProviderRepository = { findById: jest.fn() };
const mockCredentialVaultService = {
getDecryptedCredentials: jest.fn(),
getValidAccessToken: jest.fn(),
refreshOAuthTokens: jest.fn(),
};
const mockOAuthCredentialsService = { getCredentials: jest.fn() };

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
ConnectionCheckRunnerService,
{ provide: ConnectionRepository, useValue: mockConnectionRepository },
{ provide: ProviderRepository, useValue: mockProviderRepository },
{
provide: CredentialVaultService,
useValue: mockCredentialVaultService,
},
{
provide: OAuthCredentialsService,
useValue: mockOAuthCredentialsService,
},
],
}).compile();

service = module.get(ConnectionCheckRunnerService);
jest.clearAllMocks();

mockConnectionRepository.findById.mockResolvedValue({
id: 'conn_1',
organizationId: 'org_1',
providerId: 'prov_aws',
status: 'active',
variables: {},
});
mockProviderRepository.findById.mockResolvedValue({
id: 'prov_aws',
slug: 'aws',
});
mockedGetManifest.mockReturnValue(AWS_MANIFEST);
mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue({
roleArn: 'arn:aws:iam::111111111111:role/x',
externalId: 'ext',
});
mockedRunAllChecks.mockResolvedValue(RUN_RESULT);
});

it('runs the checks on the server and returns the raw result (no persistence)', async () => {
const result = await service.runChecks({
connectionId: 'conn_1',
organizationId: 'org_1',
checkId: 'aws-s3-public-access',
});

expect(mockedRunAllChecks).toHaveBeenCalledWith(
expect.objectContaining({
connectionId: 'conn_1',
organizationId: 'org_1',
checkId: 'aws-s3-public-access',
}),
);
expect(result).toBe(RUN_RESULT);
});

it('runs ALL checks when no checkId is given (auto-run path)', async () => {
await service.runChecks({
connectionId: 'conn_1',
organizationId: 'org_1',
});
expect(mockedRunAllChecks).toHaveBeenCalledWith(
expect.objectContaining({ checkId: undefined }),
);
});

it('throws NotFound for a connection in another org (no cross-tenant run)', async () => {
mockConnectionRepository.findById.mockResolvedValue({
id: 'conn_1',
organizationId: 'org_OTHER',
providerId: 'prov_aws',
status: 'active',
});
await expect(
service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }),
).rejects.toBeInstanceOf(NotFoundException);
expect(mockedRunAllChecks).not.toHaveBeenCalled();
});

it('throws BadRequest for an inactive connection', async () => {
mockConnectionRepository.findById.mockResolvedValue({
id: 'conn_1',
organizationId: 'org_1',
providerId: 'prov_aws',
status: 'paused',
});
await expect(
service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }),
).rejects.toBeInstanceOf(BadRequestException);
expect(mockedRunAllChecks).not.toHaveBeenCalled();
});

it('throws BadRequest when credentials are missing', async () => {
mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue(null);
await expect(
service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }),
).rejects.toBeInstanceOf(BadRequestException);
expect(mockedRunAllChecks).not.toHaveBeenCalled();
});

it('validates by auth type — rejects empty custom credentials (matches in-app run)', async () => {
// AWS uses custom auth; empty creds must be rejected up front, not executed.
mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue({});
await expect(
service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }),
).rejects.toBeInstanceOf(BadRequestException);
expect(mockedRunAllChecks).not.toHaveBeenCalled();
});
});
Loading
Loading