diff --git a/apps/api/src/frameworks/framework-versioning/framework-diff.spec.ts b/apps/api/src/frameworks/framework-versioning/framework-diff.spec.ts index 053e9fa257..fa378b3115 100644 --- a/apps/api/src/frameworks/framework-versioning/framework-diff.spec.ts +++ b/apps/api/src/frameworks/framework-versioning/framework-diff.spec.ts @@ -65,6 +65,29 @@ describe('diffManifests', () => { expect(diff.requirementMapEdges.removed).toContainEqual({ controlTemplateId: 'c1', requirementTemplateId: 'r1' }); }); + it('reports no framework-metadata change for identical manifests', () => { + const m = emptyManifest(); + expect(diffManifests(m, m).framework.changed).toBe(false); + }); + + it('detects a framework name change (FRAME-9)', () => { + const from = emptyManifest(); + const to = { ...emptyManifest(), framework: { ...from.framework, name: 'New Name' } }; + const diff = diffManifests(from, to); + expect(diff.framework.changed).toBe(true); + expect(diff.framework.name).toEqual({ from: 'n', to: 'New Name' }); + expect(diff.framework.description).toBeUndefined(); + }); + + it('detects a framework description change (FRAME-9)', () => { + const from = { ...emptyManifest(), framework: { id: 'f', name: 'n', catalogVersion: '1', description: 'old' } }; + const to = { ...emptyManifest(), framework: { id: 'f', name: 'n', catalogVersion: '1', description: 'new' } }; + const diff = diffManifests(from, to); + expect(diff.framework.changed).toBe(true); + expect(diff.framework.description).toEqual({ from: 'old', to: 'new' }); + expect(diff.framework.name).toBeUndefined(); + }); + it('drops phantom edges that reference entities missing from the manifest', () => { // Older snapshots sometimes stored cross-framework requirement IDs in // control.requirementIds. Those IDs are not in manifest.requirements, so diff --git a/apps/api/src/frameworks/framework-versioning/framework-diff.ts b/apps/api/src/frameworks/framework-versioning/framework-diff.ts index 0d8113789f..e947160862 100644 --- a/apps/api/src/frameworks/framework-versioning/framework-diff.ts +++ b/apps/api/src/frameworks/framework-versioning/framework-diff.ts @@ -38,7 +38,19 @@ export interface ControlDocumentTypeEdge { formType: string; } +/** + * Changes to the framework's own metadata (name / description). These don't + * live in any entity list, so without this the diff treats a name- or + * description-only edit as "no changes" and the Publish button stays disabled. + */ +export interface FrameworkMetaDiff { + changed: boolean; + name?: { from: string; to: string }; + description?: { from: string | null; to: string | null }; +} + export interface ManifestDiff { + framework: FrameworkMetaDiff; controls: EntityDiff; requirements: EntityDiff; policies: EntityDiff; @@ -85,6 +97,7 @@ export function diffManifests(fromRaw: FrameworkManifest, toRaw: FrameworkManife const from = sanitizeManifestEdges(fromRaw); const to = sanitizeManifestEdges(toRaw); return { + framework: diffFrameworkMeta(from.framework, to.framework), controls: diffEntities(from.controls, to.controls, controlEqual), requirements: diffEntities(from.requirements, to.requirements, requirementEqual), policies: diffEntities(from.policies, to.policies, policyEqual), @@ -168,6 +181,23 @@ function edgesFromControls( return controls.flatMap(extract); } +function diffFrameworkMeta( + from: FrameworkManifest['framework'], + to: FrameworkManifest['framework'], +): FrameworkMetaDiff { + const nameChanged = from.name !== to.name; + const fromDescription = from.description ?? null; + const toDescription = to.description ?? null; + const descriptionChanged = fromDescription !== toDescription; + return { + changed: nameChanged || descriptionChanged, + ...(nameChanged ? { name: { from: from.name, to: to.name } } : {}), + ...(descriptionChanged + ? { description: { from: fromDescription, to: toDescription } } + : {}), + }; +} + function controlEqual(a: ManifestControl, b: ManifestControl): boolean { return a.name === b.name && a.description === b.description && (a.controlFamily ?? null) === (b.controlFamily ?? null); } diff --git a/apps/api/src/integration-platform/controllers/internal-checks.controller.spec.ts b/apps/api/src/integration-platform/controllers/internal-checks.controller.spec.ts new file mode 100644 index 0000000000..4a704dabf0 --- /dev/null +++ b/apps/api/src/integration-platform/controllers/internal-checks.controller.spec.ts @@ -0,0 +1,66 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { InternalChecksController } from './internal-checks.controller'; +import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../../auth/permission.guard'; +import { ServiceTokenOnlyGuard } from '../../auth/service-token-only.guard'; +import { ConnectionCheckRunnerService } from '../services/connection-check-runner.service'; + +jest.mock('@db', () => ({ db: {} })); +jest.mock('../../auth/auth.server', () => ({ + auth: { api: { getSession: jest.fn() } }, +})); +jest.mock('@trycompai/auth', () => ({ + statement: { integration: ['create', 'read', 'update', 'delete'] }, + BUILT_IN_ROLE_PERMISSIONS: {}, +})); + +describe('InternalChecksController', () => { + let controller: InternalChecksController; + const mockRunner = { runChecks: jest.fn() }; + const mockGuard = { canActivate: jest.fn().mockReturnValue(true) }; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + controllers: [InternalChecksController], + providers: [ + { provide: ConnectionCheckRunnerService, useValue: mockRunner }, + ], + }) + .overrideGuard(HybridAuthGuard) + .useValue(mockGuard) + .overrideGuard(ServiceTokenOnlyGuard) + .useValue(mockGuard) + .overrideGuard(PermissionGuard) + .useValue(mockGuard) + .compile(); + + controller = module.get(InternalChecksController); + jest.clearAllMocks(); + }); + + it('delegates to the runner with the connection, org and checkId', async () => { + const runResult = { results: [], totalFindings: 0, totalPassing: 0 }; + mockRunner.runChecks.mockResolvedValue(runResult); + + const result = await controller.runConnectionChecks('conn_1', 'org_1', { + checkId: 'aws-s3-public-access', + }); + + expect(mockRunner.runChecks).toHaveBeenCalledWith({ + connectionId: 'conn_1', + organizationId: 'org_1', + checkId: 'aws-s3-public-access', + }); + expect(result).toBe(runResult); + }); + + it('passes checkId undefined when omitted (run all)', async () => { + mockRunner.runChecks.mockResolvedValue({}); + await controller.runConnectionChecks('conn_1', 'org_1', {}); + expect(mockRunner.runChecks).toHaveBeenCalledWith({ + connectionId: 'conn_1', + organizationId: 'org_1', + checkId: undefined, + }); + }); +}); diff --git a/apps/api/src/integration-platform/controllers/internal-checks.controller.ts b/apps/api/src/integration-platform/controllers/internal-checks.controller.ts new file mode 100644 index 0000000000..c2dc399a43 --- /dev/null +++ b/apps/api/src/integration-platform/controllers/internal-checks.controller.ts @@ -0,0 +1,65 @@ +import { Body, Controller, Param, Post, UseGuards } from '@nestjs/common'; +import { + ApiBody, + ApiOperation, + ApiPropertyOptional, + ApiTags, +} from '@nestjs/swagger'; +import { SkipThrottle } from '@nestjs/throttler'; +import { IsOptional, IsString } from 'class-validator'; +import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../../auth/permission.guard'; +import { ServiceTokenOnlyGuard } from '../../auth/service-token-only.guard'; +import { RequirePermission } from '../../auth/require-permission.decorator'; +import { OrganizationId } from '../../auth/auth-context.decorator'; +import { + ConnectionCheckRunnerService, + type RunAllChecksResult, +} from '../services/connection-check-runner.service'; + +// Internal payload. Service-token only — never called by the UI/customers. +class RunConnectionChecksOnServerDto { + @ApiPropertyOptional({ + description: + "Run a single check. Omit to run all of the connection's checks.", + }) + @IsOptional() + @IsString() + checkId?: string; +} + +/** + * Internal, service-token-only endpoint that runs a connection's checks ON OUR + * SERVER and returns the raw result (no persistence). Used exclusively by the + * AWS Trigger tasks so AWS S3 calls egress our VPC instead of Trigger.dev's + * (whose endpoint policy blocks our cross-account reads). All other providers + * keep executing inside Trigger.dev unchanged. + */ +@Controller({ path: 'integrations/internal', version: '1' }) +@ApiTags('Integrations') +export class InternalChecksController { + constructor(private readonly runner: ConnectionCheckRunnerService) {} + + @Post('run-connection-checks/:connectionId') + // Called by the AWS Trigger tasks in bursts (the 6 AM schedule fans out across + // every AWS connection/check). Exempt from the global rate limiter so the burst + // doesn't hit 429s and re-fail the very checks this path exists to fix. + @SkipThrottle() + @UseGuards(HybridAuthGuard, ServiceTokenOnlyGuard, PermissionGuard) + @RequirePermission('integration', 'update') + @ApiOperation({ + summary: "Run a connection's checks on the API server (internal only)", + }) + @ApiBody({ type: RunConnectionChecksOnServerDto }) + async runConnectionChecks( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + @Body() body: RunConnectionChecksOnServerDto, + ): Promise { + return this.runner.runChecks({ + connectionId, + organizationId, + checkId: body.checkId, + }); + } +} diff --git a/apps/api/src/integration-platform/integration-platform.module.ts b/apps/api/src/integration-platform/integration-platform.module.ts index 4ed3260faa..87c029e79d 100644 --- a/apps/api/src/integration-platform/integration-platform.module.ts +++ b/apps/api/src/integration-platform/integration-platform.module.ts @@ -7,6 +7,7 @@ import { ConnectionsController } from './controllers/connections.controller'; import { AdminIntegrationsController } from './controllers/admin-integrations.controller'; import { DynamicIntegrationsController } from './controllers/dynamic-integrations.controller'; import { ChecksController } from './controllers/checks.controller'; +import { InternalChecksController } from './controllers/internal-checks.controller'; import { VariablesController } from './controllers/variables.controller'; import { TaskIntegrationsController } from './controllers/task-integrations.controller'; import { WebhookController } from './controllers/webhook.controller'; @@ -20,6 +21,7 @@ import { ConnectionAuthTeardownService } from './services/connection-auth-teardo import { OAuthTokenRevocationService } from './services/oauth-token-revocation.service'; import { DynamicManifestLoaderService } from './services/dynamic-manifest-loader.service'; import { TaskIntegrationChecksService } from './services/task-integration-checks.service'; +import { ConnectionCheckRunnerService } from './services/connection-check-runner.service'; import { ProviderRepository } from './repositories/provider.repository'; import { ConnectionRepository } from './repositories/connection.repository'; import { CredentialRepository } from './repositories/credential.repository'; @@ -42,6 +44,7 @@ import { GenericDeviceSyncService } from './services/generic-device-sync.service AdminIntegrationsController, DynamicIntegrationsController, ChecksController, + InternalChecksController, VariablesController, TaskIntegrationsController, WebhookController, @@ -58,6 +61,7 @@ import { GenericDeviceSyncService } from './services/generic-device-sync.service ConnectionAuthTeardownService, DynamicManifestLoaderService, TaskIntegrationChecksService, + ConnectionCheckRunnerService, IntegrationSyncLoggerService, GenericEmployeeSyncService, GenericDeviceSyncService, diff --git a/apps/api/src/integration-platform/services/connection-check-runner.service.spec.ts b/apps/api/src/integration-platform/services/connection-check-runner.service.spec.ts new file mode 100644 index 0000000000..40749cc5ab --- /dev/null +++ b/apps/api/src/integration-platform/services/connection-check-runner.service.spec.ts @@ -0,0 +1,154 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { BadRequestException, NotFoundException } from '@nestjs/common'; +import { ConnectionCheckRunnerService } from './connection-check-runner.service'; +import { ConnectionRepository } from '../repositories/connection.repository'; +import { ProviderRepository } from '../repositories/provider.repository'; +import { CredentialVaultService } from './credential-vault.service'; +import { OAuthCredentialsService } from './oauth-credentials.service'; + +jest.mock('@db', () => ({ db: {} })); + +jest.mock('@trycompai/integration-platform', () => ({ + getManifest: jest.fn(), + runAllChecks: jest.fn(), +})); + +import { getManifest, runAllChecks } from '@trycompai/integration-platform'; + +const mockedGetManifest = getManifest as jest.Mock; +const mockedRunAllChecks = runAllChecks as jest.Mock; + +const AWS_MANIFEST = { + id: 'aws', + name: 'AWS', + auth: { type: 'custom' }, + checks: [{ id: 'aws-s3-public-access', name: 'S3 public access' }], +}; + +const RUN_RESULT = { + results: [{ checkId: 'aws-s3-public-access', status: 'success', result: {} }], + totalFindings: 0, + totalPassing: 3, +}; + +describe('ConnectionCheckRunnerService', () => { + let service: ConnectionCheckRunnerService; + + const mockConnectionRepository = { findById: jest.fn() }; + const mockProviderRepository = { findById: jest.fn() }; + const mockCredentialVaultService = { + getDecryptedCredentials: jest.fn(), + getValidAccessToken: jest.fn(), + refreshOAuthTokens: jest.fn(), + }; + const mockOAuthCredentialsService = { getCredentials: jest.fn() }; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + ConnectionCheckRunnerService, + { provide: ConnectionRepository, useValue: mockConnectionRepository }, + { provide: ProviderRepository, useValue: mockProviderRepository }, + { + provide: CredentialVaultService, + useValue: mockCredentialVaultService, + }, + { + provide: OAuthCredentialsService, + useValue: mockOAuthCredentialsService, + }, + ], + }).compile(); + + service = module.get(ConnectionCheckRunnerService); + jest.clearAllMocks(); + + mockConnectionRepository.findById.mockResolvedValue({ + id: 'conn_1', + organizationId: 'org_1', + providerId: 'prov_aws', + status: 'active', + variables: {}, + }); + mockProviderRepository.findById.mockResolvedValue({ + id: 'prov_aws', + slug: 'aws', + }); + mockedGetManifest.mockReturnValue(AWS_MANIFEST); + mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue({ + roleArn: 'arn:aws:iam::111111111111:role/x', + externalId: 'ext', + }); + mockedRunAllChecks.mockResolvedValue(RUN_RESULT); + }); + + it('runs the checks on the server and returns the raw result (no persistence)', async () => { + const result = await service.runChecks({ + connectionId: 'conn_1', + organizationId: 'org_1', + checkId: 'aws-s3-public-access', + }); + + expect(mockedRunAllChecks).toHaveBeenCalledWith( + expect.objectContaining({ + connectionId: 'conn_1', + organizationId: 'org_1', + checkId: 'aws-s3-public-access', + }), + ); + expect(result).toBe(RUN_RESULT); + }); + + it('runs ALL checks when no checkId is given (auto-run path)', async () => { + await service.runChecks({ + connectionId: 'conn_1', + organizationId: 'org_1', + }); + expect(mockedRunAllChecks).toHaveBeenCalledWith( + expect.objectContaining({ checkId: undefined }), + ); + }); + + it('throws NotFound for a connection in another org (no cross-tenant run)', async () => { + mockConnectionRepository.findById.mockResolvedValue({ + id: 'conn_1', + organizationId: 'org_OTHER', + providerId: 'prov_aws', + status: 'active', + }); + await expect( + service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }), + ).rejects.toBeInstanceOf(NotFoundException); + expect(mockedRunAllChecks).not.toHaveBeenCalled(); + }); + + it('throws BadRequest for an inactive connection', async () => { + mockConnectionRepository.findById.mockResolvedValue({ + id: 'conn_1', + organizationId: 'org_1', + providerId: 'prov_aws', + status: 'paused', + }); + await expect( + service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }), + ).rejects.toBeInstanceOf(BadRequestException); + expect(mockedRunAllChecks).not.toHaveBeenCalled(); + }); + + it('throws BadRequest when credentials are missing', async () => { + mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue(null); + await expect( + service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }), + ).rejects.toBeInstanceOf(BadRequestException); + expect(mockedRunAllChecks).not.toHaveBeenCalled(); + }); + + it('validates by auth type — rejects empty custom credentials (matches in-app run)', async () => { + // AWS uses custom auth; empty creds must be rejected up front, not executed. + mockCredentialVaultService.getDecryptedCredentials.mockResolvedValue({}); + await expect( + service.runChecks({ connectionId: 'conn_1', organizationId: 'org_1' }), + ).rejects.toBeInstanceOf(BadRequestException); + expect(mockedRunAllChecks).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/integration-platform/services/connection-check-runner.service.ts b/apps/api/src/integration-platform/services/connection-check-runner.service.ts new file mode 100644 index 0000000000..123cb4f7ab --- /dev/null +++ b/apps/api/src/integration-platform/services/connection-check-runner.service.ts @@ -0,0 +1,179 @@ +import { + BadRequestException, + Injectable, + Logger, + NotFoundException, +} from '@nestjs/common'; +import { getManifest, runAllChecks } from '@trycompai/integration-platform'; +import { ConnectionRepository } from '../repositories/connection.repository'; +import { ProviderRepository } from '../repositories/provider.repository'; +import { CredentialVaultService } from './credential-vault.service'; +import { OAuthCredentialsService } from './oauth-credentials.service'; +import { getStringValue } from '../utils/credential-utils'; + +export type RunAllChecksResult = Awaited>; + +/** + * Runs integration checks for a connection ON OUR SERVER (the API/ECS process) + * and returns the raw result WITHOUT persisting anything. + * + * Why this exists: AWS checks make S3 (and other) API calls that egress the + * runtime's network. In the Trigger.dev runtime those calls exit Trigger.dev's + * VPC, whose S3 endpoint policy blocks our cross-account audit reads + * ("no VPC endpoint policy allows ..."). Running them here egresses OUR VPC, + * whose endpoint allows the read — identical to the in-app manual "Run". + * + * Only the AWS Trigger tasks call this; GCP/Azure/dynamic/legacy integrations + * keep executing in Trigger.dev unchanged. Persistence + task status + emails + * stay in the caller, so AWS results are recorded exactly like every other + * provider's. + */ +@Injectable() +export class ConnectionCheckRunnerService { + private readonly logger = new Logger(ConnectionCheckRunnerService.name); + + constructor( + private readonly connectionRepository: ConnectionRepository, + private readonly providerRepository: ProviderRepository, + private readonly credentialVaultService: CredentialVaultService, + private readonly oauthCredentialsService: OAuthCredentialsService, + ) {} + + /** + * Run a connection's checks and return the raw `runAllChecks` result. + * Pass `checkId` to run a single check; omit it to run all of the + * connection's checks. Does NOT write to the database. + */ + async runChecks(params: { + connectionId: string; + organizationId: string; + checkId?: string; + }): Promise { + const { connectionId, organizationId, checkId } = params; + + const connection = await this.connectionRepository.findById(connectionId); + if (!connection || connection.organizationId !== organizationId) { + throw new NotFoundException('Connection not found'); + } + if (connection.status !== 'active') { + throw new BadRequestException( + `Connection is not active (status: ${connection.status})`, + ); + } + + const provider = await this.providerRepository.findById( + connection.providerId, + ); + if (!provider) { + throw new NotFoundException('Provider not found'); + } + + const manifest = getManifest(provider.slug); + if (!manifest) { + throw new NotFoundException(`Manifest for ${provider.slug} not found`); + } + if (!manifest.checks || manifest.checks.length === 0) { + throw new BadRequestException(`No checks defined for ${provider.slug}`); + } + + const credentials = + await this.credentialVaultService.getDecryptedCredentials(connectionId); + if (!credentials) { + throw new BadRequestException('No credentials found for connection'); + } + + // Validate credentials by auth type, matching the in-app run paths + // (checks.controller / task-integrations.controller) so a server-run rejects + // malformed credentials up front with a clear error instead of executing the + // check on bad input and producing an inconsistent outcome. + if (manifest.auth.type === 'oauth2' && !credentials.access_token) { + throw new BadRequestException( + 'No valid OAuth credentials found. Please reconnect.', + ); + } + if (manifest.auth.type === 'api_key') { + const apiKeyField = manifest.auth.config.name; + if (!credentials[apiKeyField] && !credentials.api_key) { + throw new BadRequestException( + 'API key not found. Please reconnect the integration.', + ); + } + } + if (manifest.auth.type === 'basic') { + const usernameField = manifest.auth.config.usernameField || 'username'; + const passwordField = manifest.auth.config.passwordField || 'password'; + if (!credentials[usernameField] || !credentials[passwordField]) { + throw new BadRequestException( + 'Username and password required. Please reconnect the integration.', + ); + } + } + if ( + manifest.auth.type === 'custom' && + Object.keys(credentials).length === 0 + ) { + throw new BadRequestException( + 'No valid credentials found for custom integration', + ); + } + + const variables = + (connection.variables as Record< + string, + string | number | boolean | string[] | undefined + >) || {}; + + // Build the OAuth refresh callback for providers that support it. AWS is + // not oauth2, so this is a no-op for the AWS path that actually uses this. + let accessToken = getStringValue(credentials.access_token); + let onTokenRefresh: (() => Promise) | undefined; + if (manifest.auth.type === 'oauth2') { + const oauthConfig = manifest.auth.config; + if (oauthConfig.supportsRefreshToken !== false) { + const oauthCredentials = + await this.oauthCredentialsService.getCredentials( + provider.slug, + organizationId, + ); + if (oauthCredentials) { + const refreshConfig = { + tokenUrl: oauthConfig.tokenUrl, + refreshUrl: oauthConfig.refreshUrl, + clientId: oauthCredentials.clientId, + clientSecret: oauthCredentials.clientSecret, + clientAuthMethod: oauthConfig.clientAuthMethod, + scope: oauthCredentials.scopes.join(' '), + tokenParams: oauthConfig.tokenParams, + }; + const validAccessToken = + await this.credentialVaultService.getValidAccessToken( + connectionId, + refreshConfig, + ); + if (validAccessToken) accessToken = validAccessToken; + onTokenRefresh = () => + this.credentialVaultService.refreshOAuthTokens( + connectionId, + refreshConfig, + ); + } + } + } + + return runAllChecks({ + manifest, + accessToken, + credentials, + variables, + connectionId, + organizationId, + checkId, + onTokenRefresh, + logger: { + info: (msg, data) => this.logger.log(msg, data), + warn: (msg, data) => this.logger.warn(msg, data), + error: (msg, data) => this.logger.error(msg, data), + }, + }); + } +} diff --git a/apps/api/src/openapi/public-docs-quality.ts b/apps/api/src/openapi/public-docs-quality.ts index e133d52e3e..04e2d82be8 100644 --- a/apps/api/src/openapi/public-docs-quality.ts +++ b/apps/api/src/openapi/public-docs-quality.ts @@ -19,6 +19,7 @@ export const PUBLIC_DOCS_EXCLUDED_PREFIXES = [ '/v1/finding-template', '/v1/integrations/oauth', '/v1/integrations/oauth-apps', + '/v1/integrations/internal', '/v1/cloud-security/legacy', '/v1/cloud-security/remediation', '/v1/questionnaire/parse/upload/token', diff --git a/apps/api/src/trigger/integration-platform/run-checks-on-server.spec.ts b/apps/api/src/trigger/integration-platform/run-checks-on-server.spec.ts new file mode 100644 index 0000000000..ed59857003 --- /dev/null +++ b/apps/api/src/trigger/integration-platform/run-checks-on-server.spec.ts @@ -0,0 +1,100 @@ +import { runChecksOnServer } from './run-checks-on-server'; + +describe('runChecksOnServer', () => { + const ORIGINAL_TOKEN = process.env.SERVICE_TOKEN_TRIGGER; + const params = { + apiUrl: 'http://api', + connectionId: 'conn_1', + organizationId: 'org_1', + }; + + beforeEach(() => { + process.env.SERVICE_TOKEN_TRIGGER = 'svc-token'; + jest.restoreAllMocks(); + }); + + afterAll(() => { + if (ORIGINAL_TOKEN === undefined) delete process.env.SERVICE_TOKEN_TRIGGER; + else process.env.SERVICE_TOKEN_TRIGGER = ORIGINAL_TOKEN; + }); + + it('POSTs to the internal endpoint with service token + org header and returns the result', async () => { + const runResult = { results: [{}], totalFindings: 1, totalPassing: 2 }; + const fetchMock = jest.spyOn(global, 'fetch').mockResolvedValue({ + ok: true, + json: async () => runResult, + } as unknown as Response); + + const result = await runChecksOnServer({ + ...params, + checkId: 'aws-s3-public-access', + }); + + expect(fetchMock).toHaveBeenCalledWith( + 'http://api/v1/integrations/internal/run-connection-checks/conn_1', + expect.objectContaining({ + method: 'POST', + // An abort signal is wired up so a hung connection times out and the + // task can retry instead of blocking until maxDuration. + signal: expect.any(AbortSignal), + headers: expect.objectContaining({ + 'x-service-token': 'svc-token', + 'x-organization-id': 'org_1', + }), + body: JSON.stringify({ checkId: 'aws-s3-public-access' }), + }), + ); + expect(result).toEqual(runResult); + }); + + it('throws a timeout error when the request is aborted (hung connection)', async () => { + jest.useFakeTimers(); + jest.spyOn(global, 'fetch').mockImplementation( + (_url, opts) => + new Promise((_resolve, reject) => { + (opts as RequestInit).signal?.addEventListener('abort', () => + reject(new Error('aborted')), + ); + }), + ); + + const promise = runChecksOnServer(params); + // Surface the rejection without an unhandled-rejection warning. + const assertion = expect(promise).rejects.toThrow('timed out'); + await jest.advanceTimersByTimeAsync(10 * 60 * 1000); + await assertion; + + jest.useRealTimers(); + }); + + it('sends an empty body when no checkId is given (run all)', async () => { + const fetchMock = jest.spyOn(global, 'fetch').mockResolvedValue({ + ok: true, + json: async () => ({}), + } as unknown as Response); + + await runChecksOnServer(params); + + expect(fetchMock).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ body: JSON.stringify({}) }), + ); + }); + + it('throws with the server message on a non-2xx response', async () => { + jest.spyOn(global, 'fetch').mockResolvedValue({ + ok: false, + status: 500, + json: async () => ({ message: 'boom' }), + } as unknown as Response); + + await expect(runChecksOnServer(params)).rejects.toThrow('boom'); + }); + + it('throws when SERVICE_TOKEN_TRIGGER is not configured', async () => { + delete process.env.SERVICE_TOKEN_TRIGGER; + await expect(runChecksOnServer(params)).rejects.toThrow( + 'SERVICE_TOKEN_TRIGGER is not configured', + ); + }); +}); diff --git a/apps/api/src/trigger/integration-platform/run-checks-on-server.ts b/apps/api/src/trigger/integration-platform/run-checks-on-server.ts new file mode 100644 index 0000000000..e0e136992c --- /dev/null +++ b/apps/api/src/trigger/integration-platform/run-checks-on-server.ts @@ -0,0 +1,82 @@ +import type { runAllChecks } from '@trycompai/integration-platform'; + +export type RunAllChecksResult = Awaited>; + +// Generous backstop for a hung connection (no response). AWS checks legitimately +// take minutes across many buckets/regions, so this is deliberately well below +// the task's 15-minute maxDuration but high enough never to abort a real run — +// it only catches a stalled socket so the error surfaces and the task retries +// instead of blocking the whole 15 minutes. +const REQUEST_TIMEOUT_MS = 10 * 60 * 1000; + +/** + * Run a connection's checks ON OUR SERVER (ECS) and return the raw result. + * + * Used by the AWS Trigger tasks only: AWS S3 calls made from the Trigger.dev + * runtime egress Trigger.dev's VPC, whose endpoint policy blocks our + * cross-account reads. Running them on our server egresses our own VPC (where + * the endpoint allows the read) — matching the in-app manual "Run". The caller + * still persists the returned result, so AWS runs are recorded exactly like + * every other provider's. + * + * Pass `checkId` to run a single check (scheduled path); omit it to run all of + * the connection's checks (auto-run-after-connect path). + * + * Throws on a transport failure (endpoint unreachable / non-2xx) so the caller's + * existing try/catch handles it (the task fails and the orchestrator retries). + * Per-check execution errors come back inside the result as usual. + */ +export async function runChecksOnServer(params: { + apiUrl: string; + connectionId: string; + organizationId: string; + checkId?: string; +}): Promise { + const { apiUrl, connectionId, organizationId, checkId } = params; + + const serviceToken = process.env.SERVICE_TOKEN_TRIGGER; + if (!serviceToken) { + throw new Error('SERVICE_TOKEN_TRIGGER is not configured'); + } + + const abortController = new AbortController(); + const timeoutId = setTimeout( + () => abortController.abort(), + REQUEST_TIMEOUT_MS, + ); + + try { + const response = await fetch( + `${apiUrl}/v1/integrations/internal/run-connection-checks/${connectionId}`, + { + method: 'POST', + signal: abortController.signal, + headers: { + 'Content-Type': 'application/json', + 'x-service-token': serviceToken, + 'x-organization-id': organizationId, + }, + body: JSON.stringify(checkId ? { checkId } : {}), + }, + ); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + const message = + (errorData as { message?: string }).message || + `Server-side check run failed with status ${response.status}`; + throw new Error(message); + } + + return (await response.json()) as RunAllChecksResult; + } catch (error) { + if (abortController.signal.aborted) { + throw new Error( + `Server-side check run timed out after ${REQUEST_TIMEOUT_MS}ms`, + ); + } + throw error; + } finally { + clearTimeout(timeoutId); + } +} diff --git a/apps/api/src/trigger/integration-platform/run-connection-checks.ts b/apps/api/src/trigger/integration-platform/run-connection-checks.ts index e7515540ed..db590e209e 100644 --- a/apps/api/src/trigger/integration-platform/run-connection-checks.ts +++ b/apps/api/src/trigger/integration-platform/run-connection-checks.ts @@ -6,7 +6,7 @@ import { requestValidCredentials, type IntegrationCredentialValues, } from './ensure-valid-credentials'; -import { injectAwsResolvedSession } from './checks-aws-session'; +import { runChecksOnServer } from './run-checks-on-server'; /** * Trigger task that runs all checks for a connection. @@ -93,72 +93,69 @@ export const runConnectionChecks = task({ }; } - // Ensure we have valid credentials const apiUrl = process.env.BASE_URL || 'http://localhost:3333'; - let credentials: IntegrationCredentialValues; - logger.info('Ensuring valid credentials...'); - const credentialsResult = await requestValidCredentials({ - apiUrl, - connectionId, - organizationId, - }); - - if (!credentialsResult.success || !credentialsResult.credentials) { - const errorMessage = - credentialsResult.error || 'Failed to validate credentials'; - logger.error(errorMessage); - return { success: false, error: errorMessage }; - } - credentials = credentialsResult.credentials; - - const handleTokenRefresh = async (): Promise => { - logger.info('Force refreshing OAuth credentials after provider 401...'); - const refreshResult = await requestValidCredentials({ + // AWS checks run ON OUR SERVER (see below), which decrypts the credentials + // and assumes the cross-account role there. Skip the Trigger-side credential + // preflight for AWS — running it would add redundant failure points (a + // transient preflight error would falsely fail an AWS run that + // `runChecksOnServer` could have completed). + let credentials: IntegrationCredentialValues = {}; + let handleTokenRefresh: (() => Promise) | undefined; + + if (providerSlug !== 'aws') { + logger.info('Ensuring valid credentials...'); + const credentialsResult = await requestValidCredentials({ apiUrl, connectionId, organizationId, - forceRefresh: true, }); - if (!refreshResult.success || !refreshResult.credentials) { - logger.error(refreshResult.error || 'Forced token refresh failed'); - return null; + if (!credentialsResult.success || !credentialsResult.credentials) { + const errorMessage = + credentialsResult.error || 'Failed to validate credentials'; + logger.error(errorMessage); + return { success: false, error: errorMessage }; } + credentials = credentialsResult.credentials; + + handleTokenRefresh = async (): Promise => { + logger.info('Force refreshing OAuth credentials after provider 401...'); + const refreshResult = await requestValidCredentials({ + apiUrl, + connectionId, + organizationId, + forceRefresh: true, + }); + + if (!refreshResult.success || !refreshResult.credentials) { + logger.error(refreshResult.error || 'Forced token refresh failed'); + return null; + } - credentials = refreshResult.credentials; - return getAccessToken(credentials) ?? null; - }; + credentials = refreshResult.credentials; + return getAccessToken(credentials) ?? null; + }; - // Validate credentials based on auth type - if (manifest.auth.type === 'oauth2' && !getAccessToken(credentials)) { - logger.error( - `No OAuth access token found for connection: ${connectionId}`, - ); - return { success: false, error: 'No OAuth access token found' }; - } + // Validate credentials based on auth type + if (manifest.auth.type === 'oauth2' && !getAccessToken(credentials)) { + logger.error( + `No OAuth access token found for connection: ${connectionId}`, + ); + return { success: false, error: 'No OAuth access token found' }; + } - if ( - manifest.auth.type === 'custom' && - Object.keys(credentials).length === 0 - ) { - logger.error( - `No credentials found for custom integration: ${connectionId}`, - ); - return { success: false, error: 'No credentials found' }; + if ( + manifest.auth.type === 'custom' && + Object.keys(credentials).length === 0 + ) { + logger.error( + `No credentials found for custom integration: ${connectionId}`, + ); + return { success: false, error: 'No credentials found' }; + } } - // For AWS, resolve the cross-account session in ECS and inject the temp - // creds — the checks run in the Trigger.dev runtime, which cannot assume the - // role itself (no base creds / roleAssumer ARN there). - credentials = await injectAwsResolvedSession({ - credentials, - apiUrl, - connectionId, - organizationId, - providerSlug, - }); - const variables = (connection.variables as Record< string, @@ -180,22 +177,30 @@ export const runConnectionChecks = task({ let totalPassing = 0; try { - // Run all checks - const result = await runAllChecks({ - manifest, - accessToken: getAccessToken(credentials), - credentials, - variables, - connectionId, - organizationId, - onTokenRefresh: - manifest.auth.type === 'oauth2' ? handleTokenRefresh : undefined, - logger: { - info: (msg, data) => logger.info(msg, data), - warn: (msg, data) => logger.warn(msg, data), - error: (msg, data) => logger.error(msg, data), - }, - }); + // AWS checks run ON OUR SERVER so their S3 calls egress our VPC (allowed) + // instead of Trigger.dev's (blocked). Every other provider keeps running + // here in the Trigger.dev runtime, unchanged. Same result shape either + // way, so the persistence below is shared. + const result = + providerSlug === 'aws' + ? await runChecksOnServer({ apiUrl, connectionId, organizationId }) + : await runAllChecks({ + manifest, + accessToken: getAccessToken(credentials), + credentials, + variables, + connectionId, + organizationId, + onTokenRefresh: + manifest.auth.type === 'oauth2' + ? handleTokenRefresh + : undefined, + logger: { + info: (msg, data) => logger.info(msg, data), + warn: (msg, data) => logger.warn(msg, data), + error: (msg, data) => logger.error(msg, data), + }, + }); totalFindings = result.totalFindings; totalPassing = result.totalPassing; diff --git a/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts b/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts index 4d68611326..ad7c5d0722 100644 --- a/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts +++ b/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts @@ -10,7 +10,10 @@ import { requestValidCredentials, type IntegrationCredentialValues, } from './ensure-valid-credentials'; -import { injectAwsResolvedSession } from './checks-aws-session'; +import { + runChecksOnServer, + type RunAllChecksResult, +} from './run-checks-on-server'; /** * Send email notifications for task status change @@ -218,91 +221,88 @@ export const runTaskIntegrationChecks = task({ return { success: false, error: 'Connection not found or inactive' }; } - // Ensure we have valid credentials (refresh OAuth tokens if needed) const apiUrl = process.env.BASE_URL || 'http://localhost:3333'; - let credentials: IntegrationCredentialValues; - - logger.info('Ensuring valid credentials (refreshing if needed)...'); - const credentialsResult = await requestValidCredentials({ - apiUrl, - connectionId, - organizationId, - }); - - if (!credentialsResult.success || !credentialsResult.credentials) { - const errorMessage = - credentialsResult.error || 'Failed to validate credentials'; - logger.error(errorMessage); - - // If unauthorized, mark connection as error - if (credentialsResult.status === 401) { - await db.integrationConnection.update({ - where: { id: connectionId }, - data: { - status: 'error', - errorMessage: - 'OAuth token expired. Please reconnect the integration.', - }, - }); - } - return { success: false, error: errorMessage }; - } - credentials = credentialsResult.credentials; - logger.info('Credentials validated successfully'); - - const handleTokenRefresh = async (): Promise => { - logger.info('Force refreshing OAuth credentials after provider 401...'); - const refreshResult = await requestValidCredentials({ + // AWS checks run ON OUR SERVER (see the loop below), which decrypts the + // credentials and assumes the cross-account role there. So the Trigger-side + // credential/session preflight is skipped for AWS — running it would add + // redundant failure points (a transient preflight error would falsely fail + // an AWS run that `runChecksOnServer` could have completed). + let credentials: IntegrationCredentialValues = {}; + let handleTokenRefresh: (() => Promise) | undefined; + + if (providerSlug !== 'aws') { + logger.info('Ensuring valid credentials (refreshing if needed)...'); + const credentialsResult = await requestValidCredentials({ apiUrl, connectionId, organizationId, - forceRefresh: true, }); - if (!refreshResult.success || !refreshResult.credentials) { - logger.error(refreshResult.error || 'Forced token refresh failed'); - return null; + if (!credentialsResult.success || !credentialsResult.credentials) { + const errorMessage = + credentialsResult.error || 'Failed to validate credentials'; + logger.error(errorMessage); + + // If unauthorized, mark connection as error + if (credentialsResult.status === 401) { + await db.integrationConnection.update({ + where: { id: connectionId }, + data: { + status: 'error', + errorMessage: + 'OAuth token expired. Please reconnect the integration.', + }, + }); + } + + return { success: false, error: errorMessage }; } + credentials = credentialsResult.credentials; + logger.info('Credentials validated successfully'); - credentials = refreshResult.credentials; - return getAccessToken(credentials) ?? null; - }; + handleTokenRefresh = async (): Promise => { + logger.info('Force refreshing OAuth credentials after provider 401...'); + const refreshResult = await requestValidCredentials({ + apiUrl, + connectionId, + organizationId, + forceRefresh: true, + }); - // Validate credentials based on auth type - if (manifest.auth.type === 'oauth2' && !getAccessToken(credentials)) { - logger.error( - `No OAuth access token found for connection: ${connectionId}`, - ); - return { - success: false, - error: 'No OAuth access token found. Please reconnect.', - }; - } + if (!refreshResult.success || !refreshResult.credentials) { + logger.error(refreshResult.error || 'Forced token refresh failed'); + return null; + } - if ( - manifest.auth.type === 'custom' && - Object.keys(credentials).length === 0 - ) { - logger.error( - `No credentials found for custom integration: ${connectionId}`, - ); - return { - success: false, - error: 'No credentials found for custom integration', + credentials = refreshResult.credentials; + return getAccessToken(credentials) ?? null; }; - } - // For AWS, resolve the cross-account session in ECS and inject the temp - // creds — the checks run in the Trigger.dev runtime, which cannot assume the - // role itself (no base creds / roleAssumer ARN there). - credentials = await injectAwsResolvedSession({ - credentials, - apiUrl, - connectionId, - organizationId, - providerSlug, - }); + // Validate credentials based on auth type + if (manifest.auth.type === 'oauth2' && !getAccessToken(credentials)) { + logger.error( + `No OAuth access token found for connection: ${connectionId}`, + ); + return { + success: false, + error: 'No OAuth access token found. Please reconnect.', + }; + } + + if ( + manifest.auth.type === 'custom' && + Object.keys(credentials).length === 0 + ) { + logger.error( + `No credentials found for custom integration: ${connectionId}`, + ); + return { + success: false, + error: 'No credentials found for custom integration', + }; + } + } const variables = (connection.variables as Record< @@ -339,22 +339,81 @@ export const runTaskIntegrationChecks = task({ // Run only the checks that apply to this task try { for (const checkId of effectiveCheckIds) { - const result = await runAllChecks({ - manifest, - accessToken: getAccessToken(credentials), - credentials, - variables, - connectionId, - organizationId, - checkId, // Run specific check - onTokenRefresh: - manifest.auth.type === 'oauth2' ? handleTokenRefresh : undefined, - logger: { - info: (msg, data) => logger.info(msg, data), - warn: (msg, data) => logger.warn(msg, data), - error: (msg, data) => logger.error(msg, data), - }, - }); + // AWS checks run ON OUR SERVER so their S3 calls egress our VPC (whose + // endpoint allows the read) instead of Trigger.dev's (which blocks it). + // Every other provider keeps executing here in the Trigger.dev runtime, + // unchanged. The result shape is identical either way, so all the + // persistence / status / email logic below is shared. + let result: RunAllChecksResult; + try { + result = + providerSlug === 'aws' + ? await runChecksOnServer({ + apiUrl, + connectionId, + organizationId, + checkId, + }) + : await runAllChecks({ + manifest, + accessToken: getAccessToken(credentials), + credentials, + variables, + connectionId, + organizationId, + checkId, // Run specific check + onTokenRefresh: + manifest.auth.type === 'oauth2' + ? handleTokenRefresh + : undefined, + logger: { + info: (msg, data) => logger.info(msg, data), + warn: (msg, data) => logger.warn(msg, data), + error: (msg, data) => logger.error(msg, data), + }, + }); + } catch (error) { + // Only the AWS server-run path is degraded here. Non-AWS providers run + // in-process via runAllChecks, which catches per-check failures and + // returns status:'error' rather than throwing — so a throw on the + // non-AWS branch is unexpected and must NOT be silently downgraded. + // Re-throw it to preserve the pre-change behavior (it propagates to the + // outer catch and fails the task). + if (providerSlug !== 'aws') throw error; + + // AWS server-run threw, and only on a transport blip (network/non-2xx) + // — per-check AWS execution errors come back inside the result, not + // thrown. Record THIS check as errored and keep going so one blip + // doesn't abort its sibling checks (multiple AWS checks share a task) + // or skip the lastSyncAt/status updates, mirroring runAllChecks' + // per-check resilience. hasExecutionErrors keeps integrationLastRunAt + // unwritten, so the next orchestrator tick retries. + const message = + error instanceof Error ? error.message : String(error); + const checkDef = manifest.checks?.find((c) => c.id === checkId); + hasFailedChecks = true; + hasExecutionErrors = true; + await db.integrationCheckRun.create({ + data: { + connectionId, + taskId, + checkId, + checkName: checkDef?.name ?? checkId, + status: 'failed', + startedAt: new Date(), + completedAt: new Date(), + durationMs: 0, + totalChecked: 0, + passedCount: 0, + failedCount: 0, + errorMessage: message, + }, + }); + logger.error( + `Server-run failed for check ${checkId} on task ${taskId}: ${message}`, + ); + continue; + } const checkResult = result.results[0]; if (!checkResult) continue; diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.test.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.test.tsx new file mode 100644 index 0000000000..f567857b62 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.test.tsx @@ -0,0 +1,41 @@ +import { fireEvent, render, screen, within } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; +import { ExpandableDescription } from './ExpandableDescription'; + +const LONG = + 'Develop security and privacy plans for the system that are consistent with the enterprise architecture.'; + +describe('ExpandableDescription', () => { + it('renders the description inline with a read-more affordance', () => { + render(); + expect(screen.getByText(LONG)).toBeInTheDocument(); + expect( + screen.getByRole('button', { name: /read full description/i }), + ).toBeInTheDocument(); + }); + + it('opens a dialog with the full description and an identifier · name heading', () => { + render(); + fireEvent.click(screen.getByRole('button', { name: /read full description/i })); + const dialog = screen.getByRole('dialog'); + expect(within(dialog).getByText('PL-2 · System Security')).toBeInTheDocument(); + expect(within(dialog).getByText(LONG)).toBeInTheDocument(); + }); + + it('renders an em dash and no button when there is no description', () => { + render(); + expect(screen.getByText('—')).toBeInTheDocument(); + expect(screen.queryByRole('button', { name: /read full description/i })).toBeNull(); + }); + + it('does not trigger the clickable parent row when expanding', () => { + const onRowClick = vi.fn(); + render( +
+ +
, + ); + fireEvent.click(screen.getByRole('button', { name: /read full description/i })); + expect(onRowClick).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.tsx new file mode 100644 index 0000000000..57f5a67dee --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/ExpandableDescription.tsx @@ -0,0 +1,69 @@ +'use client'; + +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from '@trycompai/design-system'; +import { Maximize } from '@trycompai/design-system/icons'; +import { useState } from 'react'; + +interface ExpandableDescriptionProps { + description: string | null | undefined; + identifier?: string | null; + name?: string | null; +} + +/** + * Read-only requirement description cell. Shows the truncated text inline and, + * on hover, a maximize button that opens a dialog with the full description — + * long framework requirements (e.g. NIST SP800-53 PL-2) are otherwise + * unreadable behind the single-line truncation + native tooltip. + */ +export function ExpandableDescription({ + description, + identifier, + name, +}: ExpandableDescriptionProps) { + const [isOpen, setIsOpen] = useState(false); + + if (!description) { + return ; + } + + const heading = [identifier?.trim(), name].filter(Boolean).join(' · ') || 'Requirement'; + + return ( +
+ + {description} + + + + + + + {heading} + +
+ {description} +
+
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/FrameworkRequirements.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/FrameworkRequirements.tsx index 04bdbb62ab..1399b4e70b 100644 --- a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/FrameworkRequirements.tsx +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/FrameworkRequirements.tsx @@ -26,6 +26,7 @@ import { import { Search } from '@trycompai/design-system/icons'; import { useParams, useRouter } from 'next/navigation'; import { useEffect, useMemo, useState } from 'react'; +import { ExpandableDescription } from './ExpandableDescription'; import { REQUIREMENTS_TABLE_COLUMN_COUNT, REQUIREMENTS_TABLE_STYLE, @@ -215,9 +216,11 @@ export function FrameworkRequirements({ - - {item.description || '—'} - +
diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/GroupedRequirementRow.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/GroupedRequirementRow.tsx index 0e25a56c64..f6cc83fa7c 100644 --- a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/GroupedRequirementRow.tsx +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/GroupedRequirementRow.tsx @@ -4,6 +4,7 @@ import { getRequirementStatus } from '@/lib/control-compliance'; import { Badge, TableCell, TableRow, Text } from '@trycompai/design-system'; import { Launch } from '@trycompai/design-system/icons'; import Link from 'next/link'; +import { ExpandableDescription } from './ExpandableDescription'; import type { RequirementItem } from './framework-controls-shared'; export function GroupedRequirementRow({ @@ -42,9 +43,11 @@ export function GroupedRequirementRow({ - - {item.description || '—'} - +
diff --git a/apps/app/src/app/(app)/[orgId]/layout.tsx b/apps/app/src/app/(app)/[orgId]/layout.tsx index d85c2e2314..b284e75bea 100644 --- a/apps/app/src/app/(app)/[orgId]/layout.tsx +++ b/apps/app/src/app/(app)/[orgId]/layout.tsx @@ -15,7 +15,7 @@ import type { OrganizationFromMe } from '@/types'; import { auth } from '@/utils/auth'; import { GetObjectCommand } from '@aws-sdk/client-s3'; import { getSignedUrl } from '@/lib/s3-presigner'; -import { OrganizationIdentifier } from '@trycompai/analytics'; +import { OrganizationIdentifier, ServerFeatureFlagsProvider } from '@trycompai/analytics'; import { db, Role } from '@db/server'; import dynamic from 'next/dynamic'; import { cookies, headers } from 'next/headers'; @@ -146,21 +146,21 @@ export default async function Layout({ // Check feature flags for menu items. Security (penetration tests) is // always enabled now — the nav rail entry is gated solely by the // `pentest:read` permission downstream, matching `security/layout.tsx`. - let isQuestionnaireEnabled = false; - let isTrustNdaEnabled = false; - let isWebAutomationsEnabled = false; + // The full map is also provided to the client via ServerFeatureFlagsProvider + // so `useFeatureFlag` keeps working when posthog-js is blocked client-side. + const featureFlags = session?.user?.id + ? await getFeatureFlags(session.user.id, { + groups: { organization: organization.id }, + }) + : {}; + const isQuestionnaireEnabled = featureFlags['ai-vendor-questionnaire'] === true; + const isTrustNdaEnabled = + featureFlags['is-trust-nda-enabled'] === true || + featureFlags['is-trust-nda-enabled'] === 'true'; + const isWebAutomationsEnabled = + featureFlags['is-web-automations-enabled'] === true || + featureFlags['is-web-automations-enabled'] === 'true'; const isSecurityEnabled = true; - if (session?.user?.id) { - const flags = await getFeatureFlags(session.user.id, { - groups: { organization: organization.id }, - }); - isQuestionnaireEnabled = flags['ai-vendor-questionnaire'] === true; - isTrustNdaEnabled = - flags['is-trust-nda-enabled'] === true || flags['is-trust-nda-enabled'] === 'true'; - isWebAutomationsEnabled = - flags['is-web-automations-enabled'] === true || - flags['is-web-automations-enabled'] === 'true'; - } // Check auditor role const hasAuditorRole = roles.includes(Role.auditor); @@ -192,25 +192,27 @@ export default async function Layout({ initialToken={publicAccessToken || undefined} > - - {children} - + + + {children} + + ); diff --git a/apps/app/src/app/(app)/[orgId]/overview/components/OverviewTabs.test.tsx b/apps/app/src/app/(app)/[orgId]/overview/components/OverviewTabs.test.tsx new file mode 100644 index 0000000000..0114a7739d --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/overview/components/OverviewTabs.test.tsx @@ -0,0 +1,142 @@ +import { render, renderHook, screen } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +// The live posthog-js value is controlled per-test. `undefined` simulates a +// client whose /ingest/flags request is blocked (ad blocker, privacy browser, +// corporate proxy) — flags never load, so the hook never resolves. +const { useFeatureFlagEnabledMock } = vi.hoisted(() => ({ + useFeatureFlagEnabledMock: vi.fn<(flag: string) => boolean | undefined>(), +})); + +vi.mock('posthog-js/react', () => ({ + useFeatureFlagEnabled: (flag: string) => useFeatureFlagEnabledMock(flag), + usePostHog: () => null, + PostHogProvider: ({ children }: { children: React.ReactNode }) => children, +})); + +vi.mock('next/navigation', () => ({ + useParams: () => ({ orgId: 'org_test123' }), + usePathname: () => '/org_test123/overview', +})); + +vi.mock('@/hooks/use-findings-api', () => ({ + useOrganizationFindings: () => ({ data: undefined }), +})); + +vi.mock('@db', () => ({ + FindingStatus: { open: 'open' }, +})); + +import { ServerFeatureFlagsProvider, useFeatureFlag } from '@trycompai/analytics'; +import { OverviewTabs } from './OverviewTabs'; + +describe('useFeatureFlag server fallback', () => { + it('returns false when flags never load and no server flags are provided', () => { + useFeatureFlagEnabledMock.mockReturnValue(undefined); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled')); + + expect(result.current).toBe(false); + }); + + it('falls back to the server-evaluated value when flags never load', () => { + useFeatureFlagEnabledMock.mockReturnValue(undefined); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled'), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current).toBe(true); + }); + + it('treats multivariate (string) server values as enabled', () => { + useFeatureFlagEnabledMock.mockReturnValue(undefined); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled'), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current).toBe(true); + }); + + it('stays false when both live and server values are disabled', () => { + useFeatureFlagEnabledMock.mockReturnValue(false); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled'), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current).toBe(false); + }); + + it('prefers an enabled server value over a stale persisted live=false', () => { + // posthog-js serves flags persisted from an older session even when the + // network is blocked — those can predate the admin toggle. The fresher + // server-side evaluation must win for enable rollouts. + useFeatureFlagEnabledMock.mockReturnValue(false); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled'), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current).toBe(true); + }); + + it('returns true from the live value alone, without a provider', () => { + useFeatureFlagEnabledMock.mockReturnValue(true); + + const { result } = renderHook(() => useFeatureFlag('is-timeline-enabled')); + + expect(result.current).toBe(true); + }); +}); + +describe('OverviewTabs timeline gating', () => { + it('shows the Timeline tab via server flags when the client cannot load flags', () => { + useFeatureFlagEnabledMock.mockReturnValue(undefined); + + render( + + + , + ); + + expect(screen.getByText('Timeline')).toBeInTheDocument(); + }); + + it('hides the Timeline tab when the flag is off everywhere', () => { + useFeatureFlagEnabledMock.mockReturnValue(undefined); + + render( + + + , + ); + + expect(screen.queryByText('Timeline')).not.toBeInTheDocument(); + }); + + it('shows the Timeline tab from the live client flag without server flags', () => { + useFeatureFlagEnabledMock.mockReturnValue(true); + + render(); + + expect(screen.getByText('Timeline')).toBeInTheDocument(); + }); +}); diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.expandable.test.tsx b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.expandable.test.tsx index 41a0d4182c..ea4910bf86 100644 --- a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.expandable.test.tsx +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.expandable.test.tsx @@ -22,6 +22,12 @@ vi.mock('../../../components/table', () => ({ vi.mock('./components/EditFrameworkDialog', () => ({ EditFrameworkDialog: () => null })); vi.mock('./components/DeleteFrameworkDialog', () => ({ DeleteFrameworkDialog: () => null })); +vi.mock('./versions/components/PublishVersionDialog', () => ({ + PublishVersionDialog: () => null, +})); +vi.mock('./versions/hooks/useFrameworkVersions', () => ({ + useFrameworkVersions: () => ({ data: [], refetch: vi.fn() }), +})); vi.mock('@/app/lib/api-client', () => ({ apiClient: vi.fn() })); vi.mock('next/navigation', () => ({ useRouter: () => ({ refresh: vi.fn(), push: vi.fn() }) })); vi.mock('sonner', () => ({ toast: { success: vi.fn(), error: vi.fn() } })); @@ -71,7 +77,11 @@ describe('FrameworkRequirementsClientPage — Description column', () => { const description = editableCellProps.find((p) => p.columnId === 'description'); expect(description?.expandable).toBe(true); - expect(description?.expandTitle).toBe('Edit Requirement Description'); + // Identifier + name are appended so the editor dialog says which requirement + // is being edited (FRAME-7), e.g. "… - AC-2 - Account Management". + expect(description?.expandTitle).toBe( + 'Edit Requirement Description - AC-2 - Account Management', + ); // The short single-line columns stay as plain inline edits. for (const columnId of ['identifier', 'name']) { diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.toolbar.test.tsx b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.toolbar.test.tsx new file mode 100644 index 0000000000..48b8bf20fc --- /dev/null +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.toolbar.test.tsx @@ -0,0 +1,101 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Shared, hoisted handles so the mock factory and the assertions see the same +// references. publishProps records each render's `open` value. +const { handleCommit, handleCancel, publishProps } = vi.hoisted(() => ({ + handleCommit: vi.fn(async () => true), + handleCancel: vi.fn(), + publishProps: [] as Array<{ open: boolean }>, +})); + +vi.mock('../../../components/table', () => ({ + ComboboxCell: () => null, + DateCell: () => null, + RelationalCell: () => null, + EditableCell: () => null, +})); +vi.mock('./components/EditFrameworkDialog', () => ({ EditFrameworkDialog: () => null })); +vi.mock('./components/DeleteFrameworkDialog', () => ({ DeleteFrameworkDialog: () => null })); +vi.mock('./versions/components/PublishVersionDialog', () => ({ + PublishVersionDialog: (props: { open: boolean }) => { + publishProps.push({ open: props.open }); + return null; + }, +})); +vi.mock('./versions/hooks/useFrameworkVersions', () => ({ + useFrameworkVersions: () => ({ data: [{ version: '1.0.0' }], refetch: vi.fn() }), +})); +vi.mock('@/app/lib/api-client', () => ({ apiClient: vi.fn() })); +vi.mock('next/navigation', () => ({ useRouter: () => ({ refresh: vi.fn(), push: vi.fn() }) })); +vi.mock('sonner', () => ({ toast: { success: vi.fn(), error: vi.fn() } })); +vi.mock('@trycompai/ui', () => ({ + Button: ({ children, variant: _v, size: _s, ...props }: any) => ( + + ), +})); + +vi.mock('./hooks/useRequirementChangeTracking', () => ({ + simpleUUID: () => 'temp-id', + useRequirementChangeTracking: () => ({ + data: [], + updateCell: vi.fn(), + updateRelational: vi.fn(), + addRow: vi.fn(), + deleteRow: vi.fn(), + getRowClassName: () => '', + handleCommit, + handleCancel, + isDirty: true, + createdIds: new Set(), + changesSummary: '(2 changes)', + }), +})); + +import { FrameworkRequirementsClientPage } from './FrameworkRequirementsClientPage'; + +function renderPage() { + render( + , + ); +} + +describe('FrameworkRequirementsClientPage — Save as Draft / Save and Commit (FRAME-4)', () => { + beforeEach(() => { + vi.clearAllMocks(); + handleCommit.mockImplementation(async () => true); + publishProps.length = 0; + }); + + it('shows all three buttons when there are uncommitted changes', () => { + renderPage(); + expect(screen.getByRole('button', { name: 'Cancel' })).toBeTruthy(); + expect(screen.getByRole('button', { name: 'Save as Draft' })).toBeTruthy(); + expect(screen.getByRole('button', { name: 'Save and Commit' })).toBeTruthy(); + }); + + it('Save as Draft commits without opening the publish dialog', () => { + renderPage(); + fireEvent.click(screen.getByRole('button', { name: 'Save as Draft' })); + expect(handleCommit).toHaveBeenCalledTimes(1); + expect(publishProps.every((p) => p.open === false)).toBe(true); + }); + + it('Save and Commit saves then opens the publish dialog', async () => { + renderPage(); + fireEvent.click(screen.getByRole('button', { name: 'Save and Commit' })); + expect(handleCommit).toHaveBeenCalledTimes(1); + await waitFor(() => expect(publishProps.some((p) => p.open === true)).toBe(true)); + }); + + it('does not open the publish dialog when the save fails', async () => { + handleCommit.mockImplementation(async () => false); + renderPage(); + fireEvent.click(screen.getByRole('button', { name: 'Save and Commit' })); + await waitFor(() => expect(handleCommit).toHaveBeenCalled()); + expect(publishProps.every((p) => p.open === false)).toBe(true); + }); +}); diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.tsx b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.tsx index f6e9570a30..43e752d9b0 100644 --- a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.tsx +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/FrameworkRequirementsClientPage.tsx @@ -22,6 +22,8 @@ import { useRequirementChangeTracking, type RequirementGridRow, } from './hooks/useRequirementChangeTracking'; +import { PublishVersionDialog } from './versions/components/PublishVersionDialog'; +import { useFrameworkVersions } from './versions/hooks/useFrameworkVersions'; interface FrameworkDetails { id: string; @@ -73,6 +75,15 @@ export function FrameworkRequirementsClientPage({ const router = useRouter(); const [isEditDialogOpen, setIsEditDialogOpen] = useState(false); const [isDeleteDialogOpen, setIsDeleteDialogOpen] = useState(false); + // Row whose large description editor is currently open — highlighted so the + // edited row is obvious behind the (semi-transparent) editor dialog. + const [expandedRowId, setExpandedRowId] = useState(null); + // "Save and Commit" saves the edits then opens the publish flow (FRAME-4). + const [isPublishOpen, setIsPublishOpen] = useState(false); + const { data: publishedVersions, refetch: refetchVersions } = useFrameworkVersions( + frameworkDetails.id, + ); + const latestPublishedVersion = publishedVersions?.[0]?.version; const initialGridData: RequirementGridRow[] = useMemo( () => @@ -104,6 +115,13 @@ export function FrameworkRequirementsClientPage({ changesSummary, } = useRequirementChangeTracking(initialGridData, frameworkDetails.id); + // Save edits, then (only if they all persisted) open the publish dialog so + // the accumulated changes can be committed as a new version. + const handleSaveAndCommit = useCallback(async () => { + const ok = await handleCommit(); + if (ok) setIsPublishOpen(true); + }, [handleCommit]); + const uniqueFamilies = useMemo(() => { const families = new Set(); for (const row of data) { @@ -155,16 +173,27 @@ export function FrameworkRequirementsClientPage({ header: 'Description', size: 300, maxSize: 300, - cell: ({ row, getValue }) => ( - - ), + cell: ({ row, getValue }) => { + const { identifier, name } = row.original; + const titleSuffix = [identifier, name].filter(Boolean).join(' - '); + return ( + + setExpandedRowId(open ? row.original.id : null) + } + /> + ); + }, }), columnHelper.accessor('controlTemplates', { header: 'Linked Controls', @@ -290,8 +319,11 @@ export function FrameworkRequirementsClientPage({ - + )} @@ -369,7 +401,11 @@ export function FrameworkRequirementsClientPage({ {table.getRowModel().rows.map((row) => ( {row.getVisibleCells().map((cell) => ( )} + setIsPublishOpen(false)} + latestVersion={latestPublishedVersion} + onPublished={() => { + setIsPublishOpen(false); + void refetchVersions(); + router.refresh(); + }} + />
); } diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/hooks/useRequirementChangeTracking.ts b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/hooks/useRequirementChangeTracking.ts index 3d699e390d..ebf5d1213b 100644 --- a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/hooks/useRequirementChangeTracking.ts +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/hooks/useRequirementChangeTracking.ts @@ -256,6 +256,10 @@ export function useRequirementChangeTracking( // Re-sync the grid with server truth (ids, timestamps, links). router.refresh(); } + + // Report success so callers (e.g. "Save and Commit") can chain a publish + // only when every edit persisted cleanly. + return results.errors.length === 0; }, [data, createdIds, updatedIds, deletedIds, frameworkId, router]); const handleCancel = useCallback(() => { diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.test.tsx b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.test.tsx new file mode 100644 index 0000000000..5d17007c55 --- /dev/null +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.test.tsx @@ -0,0 +1,50 @@ +import { describe, expect, it } from 'vitest'; +import type { DraftDiff } from '../hooks/useFrameworkDraftDiff'; +import { hasAnyChanges } from './VersionDiffView'; + +function emptyDiff(): DraftDiff['diff'] { + const entity = { added: [], removed: [], updated: [] }; + const edge = { added: [], removed: [] }; + return { + controls: entity, + requirements: entity, + policies: entity, + tasks: entity, + requirementMapEdges: edge, + controlPolicyEdges: edge, + controlTaskEdges: edge, + controlDocumentTypeEdges: edge, + }; +} + +describe('hasAnyChanges', () => { + it('is false for an empty diff', () => { + expect(hasAnyChanges(emptyDiff())).toBe(false); + }); + + // FRAME-9: a name/description-only edit must count as a change so Publish + // doesn't stay greyed out with "no changes detected". + it('is true when only the framework name changed', () => { + const diff = { ...emptyDiff(), framework: { changed: true, name: { from: 'A', to: 'B' } } }; + expect(hasAnyChanges(diff)).toBe(true); + }); + + it('is true when only the framework description changed', () => { + const diff = { + ...emptyDiff(), + framework: { changed: true, description: { from: 'old', to: 'new' } }, + }; + expect(hasAnyChanges(diff)).toBe(true); + }); + + it('is false when framework.changed is false', () => { + const diff = { ...emptyDiff(), framework: { changed: false } }; + expect(hasAnyChanges(diff)).toBe(false); + }); + + it('still detects entity changes (sanity)', () => { + const diff = emptyDiff(); + diff.controls = { added: [{ id: 'c1', name: 'C1' }], removed: [], updated: [] }; + expect(hasAnyChanges(diff)).toBe(true); + }); +}); diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.tsx b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.tsx index c6c1c10d48..500be7aa28 100644 --- a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.tsx +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/components/VersionDiffView.tsx @@ -32,6 +32,7 @@ export function hasAnyChanges(diff: DraftDiff['diff']): boolean { } = diff; const docTypeEdges = controlDocumentTypeEdges ?? { added: [], removed: [] }; return ( + (diff.framework?.changed ?? false) || controls.added.length > 0 || controls.removed.length > 0 || controls.updated.length > 0 || @@ -58,6 +59,7 @@ export function hasAnyChanges(diff: DraftDiff['diff']): boolean { export function VersionDiffView({ diff, linkChanges }: VersionDiffViewProps) { return ( <> + +

+ Framework +

+
+ {framework.name && ( + + + Name:{' '} + {framework.name.from}{' '} + → {framework.name.to} + + + )} + {framework.description && ( + + Description updated + + )} +
+
+ ); +} + interface DiffDetailSectionProps { title: string; added: T[]; diff --git a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/hooks/useFrameworkDraftDiff.ts b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/hooks/useFrameworkDraftDiff.ts index 909d9b7994..6e87f74a54 100644 --- a/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/hooks/useFrameworkDraftDiff.ts +++ b/apps/framework-editor/app/(pages)/frameworks/[frameworkId]/versions/hooks/useFrameworkDraftDiff.ts @@ -44,6 +44,13 @@ export interface EdgeDiffCounts { export interface DraftDiff { latestVersion: { id: string; version: string } | null; diff: { + // Optional for older API responses / historical diffs that predate the + // framework-metadata diff (FRAME-9). + framework?: { + changed: boolean; + name?: { from: string; to: string }; + description?: { from: string | null; to: string | null }; + }; controls: EntityDiffCounts; requirements: EntityDiffCounts; policies: EntityDiffCounts; diff --git a/apps/framework-editor/app/components/HeaderFrameworks.tsx b/apps/framework-editor/app/components/HeaderFrameworks.tsx index 716e99dd2b..c65d20ac1a 100644 --- a/apps/framework-editor/app/components/HeaderFrameworks.tsx +++ b/apps/framework-editor/app/components/HeaderFrameworks.tsx @@ -1,6 +1,7 @@ import { Skeleton } from '@trycompai/ui/skeleton'; import Link from 'next/link'; import { Suspense } from 'react'; +import { ThemeToggle } from './theme-toggle'; import { UserMenu } from './user-menu'; export async function Header() { @@ -9,9 +10,12 @@ export async function Header() { Framework Editor - }> - - +
+ + }> + + +
); } diff --git a/apps/framework-editor/app/components/table/EditableCell.test.tsx b/apps/framework-editor/app/components/table/EditableCell.test.tsx index 518d0a1861..e04f032b7e 100644 --- a/apps/framework-editor/app/components/table/EditableCell.test.tsx +++ b/apps/framework-editor/app/components/table/EditableCell.test.tsx @@ -1,6 +1,7 @@ import { fireEvent, render, screen } from '@testing-library/react'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { EditableCell } from './EditableCell'; +import { clearEditorSize, saveEditorSize } from './editor-size-storage'; // The ui package ships untranspiled JSX in dist; stub the bits the cell uses. vi.mock('@trycompai/ui', () => ({ @@ -55,7 +56,10 @@ describe('EditableCell — non-expandable (default)', () => { }); describe('EditableCell — expandable', () => { - beforeEach(() => vi.clearAllMocks()); + beforeEach(() => { + vi.clearAllMocks(); + clearEditorSize(); + }); it('shows an expand affordance', () => { setup({ expandable: true }); @@ -120,4 +124,32 @@ describe('EditableCell — expandable', () => { setup({ expandable: true, disabled: true }); expect(screen.queryByRole('button', { name: /large editor/i })).toBeNull(); }); + + it('notifies onExpandedChange when the editor opens and on Save', () => { + const onExpandedChange = vi.fn(); + setup({ expandable: true, onExpandedChange }); + fireEvent.click(screen.getByRole('button', { name: /large editor/i })); + expect(onExpandedChange).toHaveBeenLastCalledWith(true); + fireEvent.change(screen.getByRole('textbox'), { target: { value: 'changed' } }); + fireEvent.click(screen.getByRole('button', { name: 'Save' })); + expect(onExpandedChange).toHaveBeenLastCalledWith(false); + }); + + it('notifies onExpandedChange(false) on Cancel', () => { + const onExpandedChange = vi.fn(); + setup({ expandable: true, onExpandedChange }); + fireEvent.contextMenu(screen.getByText(/assign account managers/i)); + expect(onExpandedChange).toHaveBeenLastCalledWith(true); + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })); + expect(onExpandedChange).toHaveBeenLastCalledWith(false); + }); + + it('reopens the editor at the persisted size (FRAME-3)', () => { + saveEditorSize({ width: 900, height: 500 }); + setup({ expandable: true }); + fireEvent.click(screen.getByRole('button', { name: /large editor/i })); + const textarea = screen.getByRole('textbox') as HTMLTextAreaElement; + expect(textarea.style.width).toBe('900px'); + expect(textarea.style.height).toBe('500px'); + }); }); diff --git a/apps/framework-editor/app/components/table/EditableCell.tsx b/apps/framework-editor/app/components/table/EditableCell.tsx index 54e1298732..02204891fe 100644 --- a/apps/framework-editor/app/components/table/EditableCell.tsx +++ b/apps/framework-editor/app/components/table/EditableCell.tsx @@ -10,7 +10,8 @@ import { Textarea, } from '@trycompai/ui'; import { Maximize2 } from 'lucide-react'; -import { useState } from 'react'; +import { useRef, useState } from 'react'; +import { loadEditorSize, saveEditorSize, type EditorSize } from './editor-size-storage'; interface EditableCellProps { value: string | null; @@ -24,6 +25,9 @@ interface EditableCellProps { // values like control descriptions. expandable?: boolean; expandTitle?: string; + // Notified when the large editor opens/closes so the parent can highlight + // the row currently being edited. + onExpandedChange?: (open: boolean) => void; } export function EditableCell({ @@ -35,11 +39,24 @@ export function EditableCell({ placeholder = 'Click to edit', expandable = false, expandTitle = 'Edit', + onExpandedChange, }: EditableCellProps) { const [isEditing, setIsEditing] = useState(false); const [editValue, setEditValue] = useState(value ?? ''); const [isExpanded, setIsExpanded] = useState(false); const [expandValue, setExpandValue] = useState(value ?? ''); + // Remembered editor size (FRAME-3): the large editor is resizable in both + // directions and reopens at the size the user last left it. + const [editorSize, setEditorSize] = useState(null); + const textareaRef = useRef(null); + + // Keep local open state and the parent notification in lockstep so the row + // highlight tracks the dialog exactly (open icon, right-click, save, cancel, + // Esc, and overlay click all route through here). + const setExpanded = (open: boolean) => { + setIsExpanded(open); + onExpandedChange?.(open); + }; const handleBlur = () => { setIsEditing(false); @@ -66,14 +83,28 @@ export function EditableCell({ const handleOpenExpanded = () => { if (disabled) return; setExpandValue(value ?? ''); - setIsExpanded(true); + setEditorSize(loadEditorSize()); + setExpanded(true); }; const handleExpandSave = () => { if (expandValue !== (value ?? '')) { onUpdate(rowId, columnId, expandValue); } - setIsExpanded(false); + setExpanded(false); + }; + + // Persist the editor size after a resize-handle drag (fires on pointer + // release). Skipped when unchanged so plain clicks don't thrash storage. + const handleEditorResizeEnd = () => { + const el = textareaRef.current; + if (!el) return; + const next: EditorSize = { width: el.offsetWidth, height: el.offsetHeight }; + if (editorSize && next.width === editorSize.width && next.height === editorSize.height) { + return; + } + setEditorSize(next); + saveEditorSize(next); }; if (disabled) { @@ -136,19 +167,26 @@ export function EditableCell({ - - + + {expandTitle}