diff --git a/packages/openops/src/lib/aws/azure-aws-federation.ts b/packages/openops/src/lib/aws/azure-aws-federation.ts new file mode 100644 index 0000000000..abec45ad71 --- /dev/null +++ b/packages/openops/src/lib/aws/azure-aws-federation.ts @@ -0,0 +1,117 @@ +import { + AssumeRoleCommand, + AssumeRoleWithWebIdentityCommand, + Credentials, + STSClient, +} from '@aws-sdk/client-sts'; +import { logger, SharedSystemProp, system } from '@openops/server-shared'; +import { v4 as uuidv4 } from 'uuid'; +import { getAwsClient } from './get-client'; + +let cachedCredentials: { + credentials: Credentials; + expiresAt: number; +} | null = null; + +export function clearAzureFederationCache() { + cachedCredentials = null; +} + +export async function assumeTargetRoleViaAzureFederation( + defaultRegion: string, + roleArn: string, + externalId?: string, + endpoint?: string | undefined | null, +): Promise { + const sourceCredentials = await getAwsCredentialsFromAzureIdentity( + defaultRegion, + ); + + if (!sourceCredentials?.AccessKeyId || !sourceCredentials.SecretAccessKey) { + throw new Error('Failed to get AWS credentials from Azure identity'); + } + + const client = getAwsClient( + STSClient, + { + accessKeyId: sourceCredentials.AccessKeyId, + secretAccessKey: sourceCredentials.SecretAccessKey, + sessionToken: sourceCredentials.SessionToken, + endpoint, + }, + defaultRegion, + ); + + const command = new AssumeRoleCommand({ + RoleArn: roleArn, + ExternalId: externalId || undefined, + RoleSessionName: 'openops-' + uuidv4(), + }); + + const response = await client.send(command); + + return response.Credentials; +} + +export async function getAwsCredentialsFromAzureIdentity( + defaultRegion: string, +): Promise { + const now = Date.now(); + const buffer = 5 * 60 * 1000; + + if (cachedCredentials && cachedCredentials.expiresAt > now + buffer) { + return cachedCredentials.credentials; + } + + const webIdentityToken = await getAzureOidcTokenForAws(); + const client = new STSClient({ + region: defaultRegion, + }); + + const federationRoleArn = system.getOrThrow( + SharedSystemProp.AWS_AZURE_FEDERATION_ROLE_ARN, + ); + + const command = new AssumeRoleWithWebIdentityCommand({ + RoleArn: federationRoleArn, + RoleSessionName: 'openops-' + uuidv4(), + WebIdentityToken: webIdentityToken, + }); + + const response = await client.send(command); + + if (response.Credentials) { + cachedCredentials = { + credentials: response.Credentials, + expiresAt: response.Credentials.Expiration + ? new Date(response.Credentials.Expiration).getTime() + : now + 3600 * 1000, + }; + } + + return response.Credentials; +} + +async function getAzureOidcTokenForAws(): Promise { + const resource = 'api://AzureADTokenExchange'; + + const url = + `http://169.254.169.254/metadata/identity/oauth2/token` + + `?api-version=2018-02-01` + + `&resource=${encodeURIComponent(resource)}`; + + const response = await fetch(url, { + headers: { + Metadata: 'true', + }, + }); + + if (!response.ok) { + logger.info('Failed to get Azure managed identity token.', response); + throw new Error('Failed to get Azure managed identity token.'); + } + + const data = (await response.json()) as { access_token: string }; + + return data.access_token; +} diff --git a/packages/openops/src/lib/aws/get-client.ts b/packages/openops/src/lib/aws/get-client.ts index 1b973c6a7b..ab917def26 100644 --- a/packages/openops/src/lib/aws/get-client.ts +++ b/packages/openops/src/lib/aws/get-client.ts @@ -1,26 +1,44 @@ import { SharedSystemProp, system } from '@openops/server-shared'; import { AwsCredentials } from './auth'; +import { getAwsCredentialsFromAzureIdentity } from './azure-aws-federation'; + +type AwsClientConfig = { + region: string; + credentials?: AwsCredentials | (() => Promise); + endpoint?: string; +}; + +type CachedAwsCredentials = AwsCredentials & { + expiration?: Date; +}; + +const azureCredentialCache = new Map< + string, + { + credentials: CachedAwsCredentials | null; + promise: Promise | null; + } +>(); export function getAwsClient( - ClientConstructor: new (config: { - region: string; - credentials: AwsCredentials | undefined; - endpoint?: string; - }) => T, + ClientConstructor: new (config: AwsClientConfig) => T, credentials: AwsCredentials, region: string, ): T { - const config: any = { region }; + const config: AwsClientConfig = { + region, + }; + if (credentials.accessKeyId) { - config.credentials = { - accessKeyId: credentials.accessKeyId, - secretAccessKey: credentials.secretAccessKey, - sessionToken: credentials.sessionToken, - }; + config.credentials = createStaticCredentials(credentials); } else if (!system.getBoolean(SharedSystemProp.AWS_ENABLE_IMPLICIT_ROLE)) { throw new Error( 'AWS credentials are required, please provide accessKeyId and secretAccessKey', ); + } else if ( + system.getBoolean(SharedSystemProp.AWS_USE_AZURE_MANAGED_IDENTITY) + ) { + config.credentials = createAzureManagedIdentityCredentialsProvider(region); } if (credentials.endpoint) { @@ -29,3 +47,85 @@ export function getAwsClient( return new ClientConstructor(config); } + +function createStaticCredentials(credentials: AwsCredentials): AwsCredentials { + return { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + sessionToken: credentials.sessionToken, + }; +} + +function createAzureManagedIdentityCredentialsProvider( + region: string, +): () => Promise { + const cache = getOrCreateAzureCredentialCache(region); + + return async () => { + if (hasValidCredentials(cache.credentials)) { + return cache.credentials; + } + + if (cache.promise) { + return cache.promise; + } + + cache.promise = fetchAzureManagedIdentityCredentials(region); + + try { + cache.credentials = await cache.promise; + return cache.credentials; + } finally { + cache.promise = null; + } + }; +} + +function getOrCreateAzureCredentialCache(region: string) { + let cache = azureCredentialCache.get(region); + + if (!cache) { + cache = { + credentials: null, + promise: null, + }; + + azureCredentialCache.set(region, cache); + } + + return cache; +} + +function hasValidCredentials( + credentials: CachedAwsCredentials | null, +): credentials is CachedAwsCredentials { + if (!credentials) { + return false; + } + + if (!credentials.expiration) { + return true; + } + + // Refresh 1 minute before expiration + return credentials.expiration.getTime() > Date.now() + 60_000; +} + +async function fetchAzureManagedIdentityCredentials( + region: string, +): Promise { + const stsCredentials = await getAwsCredentialsFromAzureIdentity(region); + + if (!stsCredentials?.AccessKeyId || !stsCredentials?.SecretAccessKey) { + throw new Error( + 'Failed to obtain AWS credentials from Azure managed identity', + ); + } + + return { + accessKeyId: stsCredentials.AccessKeyId, + secretAccessKey: stsCredentials.SecretAccessKey, + sessionToken: stsCredentials.SessionToken, + expiration: stsCredentials.Expiration, + }; +} diff --git a/packages/openops/src/lib/aws/sts-common.ts b/packages/openops/src/lib/aws/sts-common.ts index 2ec646b327..e3012706ca 100644 --- a/packages/openops/src/lib/aws/sts-common.ts +++ b/packages/openops/src/lib/aws/sts-common.ts @@ -4,7 +4,9 @@ import { GetCallerIdentityCommand, STSClient, } from '@aws-sdk/client-sts'; +import { SharedSystemProp, system } from '@openops/server-shared'; import { v4 as uuidv4 } from 'uuid'; +import { assumeTargetRoleViaAzureFederation } from './azure-aws-federation'; import { getAwsClient } from './get-client'; export async function getAccountId( @@ -26,16 +28,31 @@ export async function assumeRole( externalId?: string, endpoint?: string | undefined | null, ): Promise { + if ( + !accessKeyId && + system.getBoolean(SharedSystemProp.AWS_ENABLE_IMPLICIT_ROLE) && + system.getBoolean(SharedSystemProp.AWS_USE_AZURE_MANAGED_IDENTITY) + ) { + return assumeTargetRoleViaAzureFederation( + defaultRegion, + roleArn, + externalId, + endpoint, + ); + } + const client = getAwsClient( STSClient, { accessKeyId, secretAccessKey, endpoint }, defaultRegion, ); + const command = new AssumeRoleCommand({ RoleArn: roleArn, ExternalId: externalId || undefined, RoleSessionName: 'openops-' + uuidv4(), }); + const response = await client.send(command); return response.Credentials; diff --git a/packages/openops/test/aws/azure-aws-federation.test.ts b/packages/openops/test/aws/azure-aws-federation.test.ts new file mode 100644 index 0000000000..2e07081d31 --- /dev/null +++ b/packages/openops/test/aws/azure-aws-federation.test.ts @@ -0,0 +1,169 @@ +import { + AssumeRoleCommand, + AssumeRoleWithWebIdentityCommand, + STSClient, +} from '@aws-sdk/client-sts'; +import { logger, system } from '@openops/server-shared'; +import { v4 as uuidv4 } from 'uuid'; +import { + assumeTargetRoleViaAzureFederation, + clearAzureFederationCache, + getAwsCredentialsFromAzureIdentity, +} from '../../src/lib/aws/azure-aws-federation'; +import { getAwsClient } from '../../src/lib/aws/get-client'; + +jest.mock('@aws-sdk/client-sts', () => { + return { + STSClient: jest.fn().mockImplementation(() => ({ + send: jest.fn(), + })), + AssumeRoleCommand: jest.fn(), + AssumeRoleWithWebIdentityCommand: jest.fn(), + }; +}); +jest.mock('@openops/server-shared'); +jest.mock('uuid'); +jest.mock('../../src/lib/aws/get-client'); + +describe('azure-aws-federation', () => { + const mockRegion = 'us-east-1'; + const mockRoleArn = 'arn:aws:iam::123456789012:role/target-role'; + const mockFederationRoleArn = + 'arn:aws:iam::123456789012:role/federation-role'; + const mockExternalId = 'external-id'; + const mockAccessToken = 'azure-access-token'; + const mockUuid = 'mock-uuid'; + + beforeEach(() => { + jest.clearAllMocks(); + clearAzureFederationCache(); + (uuidv4 as jest.Mock).mockReturnValue(mockUuid); + globalThis.fetch = jest.fn(); + }); + + describe('getAwsCredentialsFromAzureIdentity', () => { + it('should return credentials when successful', async () => { + (globalThis.fetch as jest.Mock).mockResolvedValue({ + ok: true, + json: async () => ({ access_token: mockAccessToken }), + }); + + (system.getOrThrow as jest.Mock).mockReturnValue(mockFederationRoleArn); + + const mockCredentials = { + AccessKeyId: 'AKIA', + SecretAccessKey: 'SECRET', + SessionToken: 'TOKEN', + }; + + const mockSend = jest + .fn() + .mockResolvedValue({ Credentials: mockCredentials }); + (STSClient as jest.Mock).mockImplementation(() => ({ + send: mockSend, + })); + + const result = await getAwsCredentialsFromAzureIdentity(mockRegion); + + expect(result).toEqual(mockCredentials); + expect(globalThis.fetch).toHaveBeenCalledWith( + expect.stringContaining('resource=api%3A%2F%2FAzureADTokenExchange'), + expect.objectContaining({ + headers: { Metadata: 'true' }, + }), + ); + expect(STSClient).toHaveBeenCalledWith({ region: mockRegion }); + expect(AssumeRoleWithWebIdentityCommand).toHaveBeenCalledWith({ + RoleArn: mockFederationRoleArn, + RoleSessionName: `openops-${mockUuid}`, + WebIdentityToken: mockAccessToken, + }); + }); + + it('should throw error when fetch fails', async () => { + (globalThis.fetch as jest.Mock).mockResolvedValue({ + ok: false, + status: 500, + }); + + await expect( + getAwsCredentialsFromAzureIdentity(mockRegion), + ).rejects.toThrow('Failed to get Azure managed identity token.'); + expect(logger.info).toHaveBeenCalled(); + }); + }); + + describe('assumeTargetRoleViaAzureFederation', () => { + it('should assume role and return credentials', async () => { + const mockSourceCredentials = { + AccessKeyId: 'AKIA-SOURCE', + SecretAccessKey: 'SECRET-SOURCE', + SessionToken: 'TOKEN-SOURCE', + }; + + (globalThis.fetch as jest.Mock).mockResolvedValue({ + ok: true, + json: async () => ({ access_token: mockAccessToken }), + }); + (system.getOrThrow as jest.Mock).mockReturnValue(mockFederationRoleArn); + + const mockStsClientForFederation = { + send: jest + .fn() + .mockResolvedValue({ Credentials: mockSourceCredentials }), + }; + const mockStsClientForTarget = { + send: jest + .fn() + .mockResolvedValue({ Credentials: { AccessKeyId: 'AKIA-TARGET' } }), + }; + + (STSClient as jest.Mock).mockImplementationOnce( + () => mockStsClientForFederation, + ); + (getAwsClient as jest.Mock).mockReturnValue(mockStsClientForTarget); + + const result = await assumeTargetRoleViaAzureFederation( + mockRegion, + mockRoleArn, + mockExternalId, + ); + + expect(result).toEqual({ AccessKeyId: 'AKIA-TARGET' }); + expect(getAwsClient).toHaveBeenCalledWith( + STSClient, + { + accessKeyId: mockSourceCredentials.AccessKeyId, + secretAccessKey: mockSourceCredentials.SecretAccessKey, + sessionToken: mockSourceCredentials.SessionToken, + endpoint: undefined, + }, + mockRegion, + ); + expect(AssumeRoleCommand).toHaveBeenCalledWith({ + RoleArn: mockRoleArn, + ExternalId: mockExternalId, + RoleSessionName: `openops-${mockUuid}`, + }); + }); + + it('should throw error if source credentials are missing required fields', async () => { + (globalThis.fetch as jest.Mock).mockResolvedValue({ + ok: true, + json: async () => ({ access_token: mockAccessToken }), + }); + (system.getOrThrow as jest.Mock).mockReturnValue(mockFederationRoleArn); + + const mockSend = jest + .fn() + .mockResolvedValue({ Credentials: { AccessKeyId: 'AKIA' } }); + (STSClient as jest.Mock).mockImplementation(() => ({ + send: mockSend, + })); + + await expect( + assumeTargetRoleViaAzureFederation(mockRegion, mockRoleArn), + ).rejects.toThrow('Failed to get AWS credentials from Azure identity'); + }); + }); +}); diff --git a/packages/openops/test/aws/get-client.test.ts b/packages/openops/test/aws/get-client.test.ts index c9986396b0..cf59d0e10c 100644 --- a/packages/openops/test/aws/get-client.test.ts +++ b/packages/openops/test/aws/get-client.test.ts @@ -3,9 +3,15 @@ jest.mock('@openops/server-shared', () => ({ system: mockSystem, SharedSystemProp: { AWS_ENABLE_IMPLICIT_ROLE: 'AWS_ENABLE_IMPLICIT_ROLE', + AWS_USE_AZURE_MANAGED_IDENTITY: 'AWS_USE_AZURE_MANAGED_IDENTITY', }, })); +jest.mock('../../src/lib/aws/azure-aws-federation', () => ({ + getAwsCredentialsFromAzureIdentity: jest.fn(), +})); + +import { getAwsCredentialsFromAzureIdentity } from '../../src/lib/aws/azure-aws-federation'; import { getAwsClient } from '../../src/lib/aws/get-client'; class MockServiceClient { @@ -75,7 +81,9 @@ describe('getClient', () => { }); test('should not throw an error if credentials are not required', () => { - mockSystem.getBoolean.mockReturnValue(true); + mockSystem.getBoolean.mockReturnValueOnce(true); + mockSystem.getBoolean.mockReturnValueOnce(false); + const credentials = { accessKeyId: '', secretAccessKey: '', @@ -91,4 +99,44 @@ describe('getClient', () => { mockSystem.getBoolean.mockReturnValue(false); } }); + + test('should use Azure managed identity when configured', async () => { + mockSystem.getBoolean.mockImplementation((prop) => { + if (prop === 'AWS_ENABLE_IMPLICIT_ROLE') { + return true; + } + if (prop === 'AWS_USE_AZURE_MANAGED_IDENTITY') { + return true; + } + return false; + }); + + const mockCreds = { + AccessKeyId: 'azure-key', + SecretAccessKey: 'azure-secret', + }; + (getAwsCredentialsFromAzureIdentity as jest.Mock).mockResolvedValue( + mockCreds, + ); + + const credentials = { + accessKeyId: '', + secretAccessKey: '', + }; + + try { + const client = getAwsClient(MockServiceClient, credentials, region); + expect(client).toBeInstanceOf(MockServiceClient); + expect(typeof client.config.credentials).toBe('function'); + + const result = await client.config.credentials(); + expect(result).toEqual({ + accessKeyId: 'azure-key', + secretAccessKey: 'azure-secret', + }); + expect(getAwsCredentialsFromAzureIdentity).toHaveBeenCalledWith(region); + } finally { + mockSystem.getBoolean.mockReturnValue(false); + } + }); }); diff --git a/packages/openops/test/sts-common.test.ts b/packages/openops/test/sts-common.test.ts index 364b122be4..f823c5b2a4 100644 --- a/packages/openops/test/sts-common.test.ts +++ b/packages/openops/test/sts-common.test.ts @@ -35,6 +35,24 @@ const ACCESS_KEY_ID = 'random accessKeyId'; const SECRET_ACCESS_KEY = 'random secretAccessKey'; const DEFAULT_REGION = 'random defaultRegion'; +const mockAssumeTargetRoleViaAzureFederation = jest.fn(); + +jest.mock('../src/lib/aws/azure-aws-federation', () => ({ + assumeTargetRoleViaAzureFederation: (...args: any[]) => + mockAssumeTargetRoleViaAzureFederation(...args), +})); + +const mockSystemGetBoolean = jest.fn(); +jest.mock('@openops/server-shared', () => ({ + SharedSystemProp: { + AWS_ENABLE_IMPLICIT_ROLE: 'AWS_ENABLE_IMPLICIT_ROLE', + AWS_USE_AZURE_MANAGED_IDENTITY: 'AWS_USE_AZURE_MANAGED_IDENTITY', + }, + system: { + getBoolean: (...args: any[]) => mockSystemGetBoolean(...args), + }, +})); + import { assumeRole, getAccountId } from '../src/lib/aws/sts-common'; describe('assumeRole tests', () => { @@ -123,4 +141,93 @@ describe('getAccountId tests', () => { }, }); }); + + test('should return empty string if account is missing', async () => { + mockSend.mockResolvedValueOnce({}); + const result = await getAccountId( + { + accessKeyId: ACCESS_KEY_ID, + secretAccessKey: SECRET_ACCESS_KEY, + }, + DEFAULT_REGION, + ); + + expect(result).toBe(''); + }); +}); + +describe('assumeRole with Azure Federation', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test('should use Azure Federation when credentials are missing and enabled', async () => { + mockSystemGetBoolean.mockImplementation((prop) => { + if (prop === 'AWS_ENABLE_IMPLICIT_ROLE') { + return true; + } + if (prop === 'AWS_USE_AZURE_MANAGED_IDENTITY') { + return true; + } + return false; + }); + mockAssumeTargetRoleViaAzureFederation.mockResolvedValue( + 'azure credentials', + ); + + const result = await assumeRole( + '', + '', + DEFAULT_REGION, + 'some role arn', + 'external id', + 'some endpoint', + ); + + expect(result).toBe('azure credentials'); + expect(mockAssumeTargetRoleViaAzureFederation).toHaveBeenCalledWith( + DEFAULT_REGION, + 'some role arn', + 'external id', + 'some endpoint', + ); + expect(mockCreateStsClient).not.toHaveBeenCalled(); + }); + + test('should NOT use Azure Federation when AWS_ENABLE_IMPLICIT_ROLE is false', async () => { + mockSystemGetBoolean.mockImplementation((prop) => { + if (prop === 'AWS_ENABLE_IMPLICIT_ROLE') { + return false; + } + if (prop === 'AWS_USE_AZURE_MANAGED_IDENTITY') { + return true; + } + return false; + }); + + await expect( + assumeRole('', '', DEFAULT_REGION, 'some role arn', 'external id'), + ).rejects.toThrow( + 'AWS credentials are required, please provide accessKeyId and secretAccessKey', + ); + + expect(mockAssumeTargetRoleViaAzureFederation).not.toHaveBeenCalled(); + }); + + test('should NOT use Azure Federation when AWS_USE_AZURE_MANAGED_IDENTITY is false', async () => { + mockSystemGetBoolean.mockImplementation((prop) => { + if (prop === 'AWS_ENABLE_IMPLICIT_ROLE') { + return true; + } + if (prop === 'AWS_USE_AZURE_MANAGED_IDENTITY') { + return false; + } + return false; + }); + + await assumeRole('', '', DEFAULT_REGION, 'some role arn', 'external id'); + + expect(mockAssumeTargetRoleViaAzureFederation).not.toHaveBeenCalled(); + expect(mockCreateStsClient).toHaveBeenCalled(); + }); }); diff --git a/packages/server/shared/src/lib/system/system-prop.ts b/packages/server/shared/src/lib/system/system-prop.ts index eb03efee51..37dd3ec847 100644 --- a/packages/server/shared/src/lib/system/system-prop.ts +++ b/packages/server/shared/src/lib/system/system-prop.ts @@ -152,6 +152,8 @@ export enum SharedSystemProp { SLACK_ENABLE_INTERACTIONS = 'SLACK_ENABLE_INTERACTIONS', AWS_ENABLE_IMPLICIT_ROLE = 'AWS_ENABLE_IMPLICIT_ROLE', + AWS_USE_AZURE_MANAGED_IDENTITY = 'AWS_USE_AZURE_MANAGED_IDENTITY', + AWS_AZURE_FEDERATION_ROLE_ARN = 'AWS_AZURE_FEDERATION_ROLE_ARN', LANGFUSE_SECRET_KEY = 'LANGFUSE_SECRET_KEY', LANGFUSE_PUBLIC_KEY = 'LANGFUSE_PUBLIC_KEY',