diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/DefaultSrtSigner.java b/sdk/src/main/java/io/opentdf/platform/sdk/DefaultSrtSigner.java new file mode 100644 index 00000000..102f9986 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/DefaultSrtSigner.java @@ -0,0 +1,34 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.jwk.RSAKey; + +final class DefaultSrtSigner implements SrtSigner { + private static final JWSHeader HEADER = new JWSHeader.Builder(JWSAlgorithm.RS256).build(); + private final RSASSASigner signer; + + DefaultSrtSigner(RSAKey rsaKey) { + try { + this.signer = new RSASSASigner(rsaKey); + } catch (JOSEException e) { + throw new SDKException("error creating SRT signer", e); + } + } + + @Override + public byte[] sign(byte[] input) throws java.security.GeneralSecurityException { + try { + return signer.sign(HEADER, input).decode(); + } catch (JOSEException e) { + throw new java.security.GeneralSecurityException("error signing SRT payload", e); + } + } + + @Override + public String alg() { + return JWSAlgorithm.RS256.getName(); + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java index b1c00085..91ada8b0 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java @@ -7,8 +7,9 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; -import com.nimbusds.jose.crypto.RSASSASigner; -import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.util.Base64URL; +import com.nimbusds.jose.jca.JCAContext; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import io.opentdf.platform.kas.AccessServiceClient; @@ -28,6 +29,7 @@ import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.Set; import java.util.function.BiFunction; import static io.opentdf.platform.sdk.TDF.GLOBAL_KEY_SALT; @@ -42,7 +44,7 @@ class KASClient implements SDK.KAS { private final OkHttpClient httpClient; private final BiFunction protocolClientFactory; private final boolean usePlaintext; - private final RSASSASigner signer; + private final JWSSigner signer; private AsymDecryption decryptor; private String clientPublicKey; private KASKeyCache kasKeyCache; @@ -53,17 +55,16 @@ class KASClient implements SDK.KAS { * A client that communicates with KAS * * communicate - * @param dpopKey + * @param srtSigner */ - KASClient(OkHttpClient httpClient, BiFunction protocolClientFactory, RSAKey dpopKey, boolean usePlaintext) { + KASClient(OkHttpClient httpClient, BiFunction protocolClientFactory, SrtSigner srtSigner, boolean usePlaintext) { this.httpClient = httpClient; this.protocolClientFactory = protocolClientFactory; this.usePlaintext = usePlaintext; - try { - this.signer = new RSASSASigner(dpopKey); - } catch (JOSEException e) { - throw new SDKException("error creating dpop signer", e); + if (srtSigner == null) { + throw new SDKException("srtSigner must be provided"); } + this.signer = new SrtJwsSigner(srtSigner); this.kasKeyCache = new KASKeyCache(); } @@ -197,4 +198,40 @@ synchronized AccessServiceClient getStub(String url) { return new AccessServiceClient(client); }); } + + private static final class SrtJwsSigner implements JWSSigner { + private static final JWSAlgorithm EXPECTED_ALG = JWSAlgorithm.RS256; + private final SrtSigner srtSigner; + private final JCAContext jcaContext = new JCAContext(); + + private SrtJwsSigner(SrtSigner srtSigner) { + this.srtSigner = srtSigner; + if (!EXPECTED_ALG.getName().equals(srtSigner.alg())) { + throw new SDKException("unsupported SRT signing algorithm: " + srtSigner.alg()); + } + } + + @Override + public Base64URL sign(JWSHeader header, byte[] signingInput) throws JOSEException { + if (!EXPECTED_ALG.equals(header.getAlgorithm())) { + throw new JOSEException("SRT signer algorithm mismatch: " + header.getAlgorithm()); + } + + try { + return Base64URL.encode(srtSigner.sign(signingInput)); + } catch (java.security.GeneralSecurityException e) { + throw new JOSEException("error signing SRT payload", e); + } + } + + @Override + public Set supportedJWSAlgorithms() { + return Collections.singleton(EXPECTED_ALG); + } + + @Override + public JCAContext getJCAContext() { + return jcaContext; + } + } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java index 1194bdd7..9e80b0ab 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java @@ -31,6 +31,7 @@ public class SDK implements AutoCloseable { private final Interceptor authInterceptor; private final String platformUrl; private final ProtocolClient platformServicesClient; + private final SrtSigner srtSigner; /** * Closes the SDK, including its associated services. @@ -87,12 +88,13 @@ public Optional getAuthInterceptor() { return Optional.ofNullable(authInterceptor); } - SDK(Services services, TrustManager trustManager, Interceptor authInterceptor, ProtocolClient platformServicesClient, String platformUrl) { + SDK(Services services, TrustManager trustManager, Interceptor authInterceptor, ProtocolClient platformServicesClient, String platformUrl, SrtSigner srtSigner) { this.platformUrl = platformUrl; this.services = services; this.trustManager = trustManager; this.authInterceptor = authInterceptor; this.platformServicesClient = platformServicesClient; + this.srtSigner = srtSigner; } public Services getServices() { @@ -122,6 +124,10 @@ public ProtocolClient getPlatformServicesClient() { return this.platformServicesClient; } + public Optional getSrtSigner() { + return Optional.ofNullable(srtSigner); + } + /** * Checks to see if this has the structure of a Z-TDF in that it is a zip file * containing diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index 698953f0..bb350b66 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -66,6 +66,7 @@ public class SDKBuilder { private SSLFactory sslFactory; private AuthorizationGrant authzGrant; private ProtocolType protocolType = ProtocolType.CONNECT; + private SrtSigner srtSigner; private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class); @@ -177,6 +178,11 @@ public SDKBuilder protocol(ProtocolType protocolType) { return this; } + public SDKBuilder srtSigner(SrtSigner signer) { + this.srtSigner = signer; + return this; + } + private Interceptor getAuthInterceptor(RSAKey rsaKey) { if (platformEndpoint == null) { throw new SDKException("cannot build an SDK without specifying the platform endpoint"); @@ -236,14 +242,16 @@ static class ServicesAndInternals { final Interceptor interceptor; final TrustManager trustManager; final ProtocolClient protocolClient; + final SrtSigner srtSigner; final SDK.Services services; - ServicesAndInternals(Interceptor interceptor, TrustManager trustManager, SDK.Services services, ProtocolClient protocolClient) { + ServicesAndInternals(Interceptor interceptor, TrustManager trustManager, SDK.Services services, ProtocolClient protocolClient, SrtSigner srtSigner) { this.interceptor = interceptor; this.trustManager = trustManager; this.services = services; this.protocolClient = protocolClient; + this.srtSigner = srtSigner; } } @@ -267,7 +275,8 @@ ServicesAndInternals buildServices() { this.platformEndpoint = AddressNormalizer.normalizeAddress(this.platformEndpoint, this.usePlainText); var authInterceptor = getAuthInterceptor(dpopKey); - var kasClient = getKASClient(dpopKey, authInterceptor); + var srtSignerToUse = this.srtSigner == null ? new DefaultSrtSigner(dpopKey) : this.srtSigner; + var kasClient = getKASClient(srtSignerToUse, authInterceptor); var httpClient = getHttpClient(); var client = getProtocolClient(platformEndpoint, httpClient, authInterceptor); var attributeService = new AttributesServiceClient(client); @@ -337,18 +346,19 @@ public SDK.KAS kas() { authInterceptor, sslFactory == null ? null : sslFactory.getTrustManager().orElse(null), services, - client); + client, + srtSignerToUse); } @Nonnull - private KASClient getKASClient(RSAKey dpopKey, Interceptor interceptor) { + private KASClient getKASClient(SrtSigner srtSigner, Interceptor interceptor) { BiFunction protocolClientFactory = (OkHttpClient client, String address) -> getProtocolClient(address, client, interceptor); - return new KASClient(getHttpClient(), protocolClientFactory, dpopKey, usePlainText); + return new KASClient(getHttpClient(), protocolClientFactory, srtSigner, usePlainText); } public SDK build() { var services = buildServices(); - return new SDK(services.services, services.trustManager, services.interceptor, services.protocolClient, platformEndpoint); + return new SDK(services.services, services.trustManager, services.interceptor, services.protocolClient, platformEndpoint, services.srtSigner); } private ProtocolClient getUnauthenticatedProtocolClient(String endpoint, OkHttpClient httpClient) { diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SrtSigner.java b/sdk/src/main/java/io/opentdf/platform/sdk/SrtSigner.java new file mode 100644 index 00000000..b9a04092 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SrtSigner.java @@ -0,0 +1,7 @@ +package io.opentdf.platform.sdk; + +public interface SrtSigner { + byte[] sign(byte[] input) throws java.security.GeneralSecurityException; + + String alg(); +} diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java index a4f85eb0..013f3bc4 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java @@ -9,7 +9,10 @@ import com.google.gson.Gson; import com.google.protobuf.ByteString; import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSVerifier; +import com.nimbusds.jose.crypto.RSASSASigner; import com.nimbusds.jose.crypto.RSASSAVerifier; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jwt.SignedJWT; @@ -33,9 +36,11 @@ import java.util.Random; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.nio.charset.StandardCharsets; import static io.opentdf.platform.sdk.SDKBuilderTest.getRandomPort; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; public class KASClientTest { OkHttpClient httpClient = new OkHttpClient.Builder() @@ -68,7 +73,7 @@ public void publicKey(PublicKeyRequest request, StreamObserver respons rewrapServer = startServer(accessService); byte[] plaintextKey; byte[] rewrapResponse; - try (var kas = new KASClient(httpClient, aclientFactory, dpopKey, true)) { + try (var kas = new KASClient(httpClient, aclientFactory, new DefaultSrtSigner(dpopKey), true)) { Manifest.KeyAccess keyAccess = new Manifest.KeyAccess(); keyAccess.url = "http://localhost:" + rewrapServer.getPort(); @@ -182,6 +187,131 @@ public void rewrap(RewrapRequest request, StreamObserver respons } } + @Test + void testCustomSrtSignerIsUsed() throws IOException { + var serverKeypair = CryptoUtils.generateRSAKeypair(); + var signingInput = new AtomicReference(); + var signedToken = new AtomicReference(); + var signingKeypair = CryptoUtils.generateRSAKeypair(); + var signingKey = new RSAKey.Builder((RSAPublicKey) signingKeypair.getPublic()) + .privateKey(signingKeypair.getPrivate()) + .build(); + SrtSigner srtSigner = new SrtSigner() { + @Override + public byte[] sign(byte[] input) { + signingInput.set(input); + try { + return new RSASSASigner(signingKey) + .sign(new JWSHeader.Builder(JWSAlgorithm.RS256).build(), input) + .decode(); + } catch (JOSEException e) { + throw new AssertionError("Signing failed unexpectedly in test", e); + } + } + + @Override + public String alg() { + return "RS256"; + } + }; + + AccessServiceGrpc.AccessServiceImplBase accessService = new AccessServiceGrpc.AccessServiceImplBase() { + @Override + public void rewrap(RewrapRequest request, StreamObserver responseObserver) { + signedToken.set(request.getSignedRequestToken()); + SignedJWT signedJWT; + try { + signedJWT = SignedJWT.parse(request.getSignedRequestToken()); + JWSVerifier verifier = new RSASSAVerifier(new RSAKey.Builder((RSAPublicKey) signingKeypair.getPublic()).build()); + if (!signedJWT.verify(verifier)) { + responseObserver.onError(new JOSEException("Unable to verify signature")); + responseObserver.onCompleted(); + return; + } + } catch (ParseException e) { + responseObserver.onError(e); + responseObserver.onCompleted(); + return; + } catch (JOSEException e) { + responseObserver.onError(e); + responseObserver.onCompleted(); + return; + } + + String requestBodyJson; + try { + requestBodyJson = signedJWT.getJWTClaimsSet().getStringClaim("requestBody"); + } catch (ParseException e) { + responseObserver.onError(e); + responseObserver.onCompleted(); + return; + } + + var gson = new Gson(); + var req = gson.fromJson(requestBodyJson, KASClient.RewrapRequestBody.class); + + byte[] decryptedKey; + try { + decryptedKey = new AsymDecryption(serverKeypair.getPrivate()) + .decrypt(Base64.getDecoder().decode(req.keyAccess.wrappedKey)); + } catch (Exception e) { + responseObserver.onError(e); + responseObserver.onCompleted(); + return; + } + var encryptedKey = new AsymEncryption(req.clientPublicKey).encrypt(decryptedKey); + + responseObserver.onNext( + RewrapResponse.newBuilder().setEntityWrappedKey(ByteString.copyFrom(encryptedKey)).build()); + responseObserver.onCompleted(); + } + }; + + Server rewrapServer = null; + try { + rewrapServer = startServer(accessService); + byte[] plaintextKey; + byte[] rewrapResponse; + try (var kas = new KASClient(httpClient, aclientFactory, srtSigner, true)) { + Manifest.KeyAccess keyAccess = new Manifest.KeyAccess(); + keyAccess.url = "http://localhost:" + rewrapServer.getPort(); + plaintextKey = new byte[32]; + new Random().nextBytes(plaintextKey); + var serverWrappedKey = new AsymEncryption(serverKeypair.getPublic()).encrypt(plaintextKey); + keyAccess.wrappedKey = Base64.getEncoder().encodeToString(serverWrappedKey); + + rewrapResponse = kas.unwrap(keyAccess, "the policy", KeyType.RSA2048Key); + } + assertThat(rewrapResponse).containsExactly(plaintextKey); + assertThat(signingInput.get()).isNotNull(); + var tokenParts = signedToken.get().split("\\.", 3); + assertThat(tokenParts.length).isEqualTo(3); + var expectedSigningInput = (tokenParts[0] + "." + tokenParts[1]).getBytes(StandardCharsets.US_ASCII); + assertThat(signingInput.get()).containsExactly(expectedSigningInput); + } finally { + if (rewrapServer != null) { + rewrapServer.shutdownNow(); + } + } + } + + @Test + void testSrtSignerAlgMismatchRejected() { + SrtSigner srtSigner = new SrtSigner() { + @Override + public byte[] sign(byte[] input) { + return new byte[0]; + } + + @Override + public String alg() { + return "none"; + } + }; + + assertThrows(SDKException.class, () -> new KASClient(httpClient, aclientFactory, srtSigner, true)); + } + @Test void testAddressNormalizationWithHTTPSClient() { var lastAddress = new AtomicReference(); @@ -191,7 +321,7 @@ void testAddressNormalizationWithHTTPSClient() { var httpsKASClient = new KASClient(httpClient, (client, addr) -> { lastAddress.set(addr); return aclientFactory.apply(client, addr); - }, dpopKey, false); + }, new DefaultSrtSigner(dpopKey), false); var stub = httpsKASClient.getStub("http://localhost:8080"); assertThat(lastAddress.get()).isEqualTo("https://localhost:8080"); @@ -209,7 +339,7 @@ void testAddressNormalizationWithInsecureHTTPClient() { var httpsKASClient = new KASClient(httpClient, (client, addr) -> { lastAddress.set(addr); return aclientFactory.apply(client, addr); - }, dpopKey, true); + }, new DefaultSrtSigner(dpopKey), true); var c1 = httpsKASClient.getStub("http://example.org"); assertThat(lastAddress.get()).isEqualTo("http://example.org:80"); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java index 3a07c17b..cebc9928 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java @@ -108,6 +108,55 @@ void testCreatingSDKServicesPlainText() throws Exception { sdkServicesSetup(false, false); } + @Test + void testInjectedSrtSignerIsExposedOnSdk() throws Exception { + WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() { + @Override + public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(GetWellKnownConfigurationResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + + Server platformServices = null; + SrtSigner signer = new SrtSigner() { + @Override + public byte[] sign(byte[] input) { + return new byte[] { 7 }; + } + + @Override + public String alg() { + return "RS256"; + } + }; + + try { + platformServices = ServerBuilder + .forPort(getRandomPort()) + .directExecutor() + .addService(wellKnownService) + .build() + .start(); + + var sdk = SDKBuilder.newBuilder() + .clientSecret("user", "password") + .platformEndpoint("http://localhost:" + platformServices.getPort()) + .useInsecurePlaintextConnection(true) + .protocol(ProtocolType.GRPC) + .srtSigner(signer) + .build(); + + assertThat(sdk.getSrtSigner()).isPresent(); + assertThat(sdk.getSrtSigner().get()).isSameAs(signer); + } finally { + if (platformServices != null) { + platformServices.shutdownNow(); + } + } + } + void sdkServicesSetup(boolean useSSLPlatform, boolean useSSLIDP) throws Exception { HeldCertificate rootCertificate = new HeldCertificate.Builder() diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/SDKTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/SDKTest.java index 44e30dce..289d2f42 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/SDKTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/SDKTest.java @@ -32,7 +32,7 @@ void testExaminingValidZTDF() throws IOException { @Test void testReadingProtocolClient() { var platformServicesClient = mock(ProtocolClient.class); - var sdk = new SDK(new FakeServicesBuilder().build(), null, null, platformServicesClient, null); + var sdk = new SDK(new FakeServicesBuilder().build(), null, null, platformServicesClient, null, null); assertThat(sdk.getPlatformServicesClient()).isSameAs(platformServicesClient); } @@ -52,7 +52,7 @@ void testGettingBaseKey() { .thenReturn(TestUtil.successfulUnaryCall(response)); var services = new FakeServicesBuilder().setWellknownService(wellknownService).build(); - var sdk = new SDK(services, null, null, platformServicesClient, null); + var sdk = new SDK(services, null, null, platformServicesClient, null, null); var baseKey = sdk.getBaseKey(); assertThat(baseKey).isPresent(); @@ -68,7 +68,7 @@ void testAuthorizationServiceClientV2() { var platformServicesClient = mock(ProtocolClient.class); io.opentdf.platform.authorization.v2.AuthorizationServiceClientInterface authSvcV2 = mock(io.opentdf.platform.authorization.v2.AuthorizationServiceClientInterface.class); var fakeServiceBuilder = new FakeServicesBuilder().setAuthorizationServiceV2(authSvcV2).build(); - var sdk = new SDK(fakeServiceBuilder, null, null, platformServicesClient, null); + var sdk = new SDK(fakeServiceBuilder, null, null, platformServicesClient, null, null); assertThat(sdk.getServices().authorizationV2()).isSameAs(fakeServiceBuilder.authorizationV2()); }