diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index 9cb9a867118..d26990b3b89 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -73,16 +73,27 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement this.scheduledExecutorService = checkNotNull(scheduledExecutorService, "scheduledExecutorService"); this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.certFile = Paths.get(checkNotNull(certFile, "certFile")); - this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); - checkArgument((trustFile != null || spiffeTrustMapFile != null), - "either trustFile or spiffeTrustMapFile must be present"); + checkArgument(certFile == null || keyFile != null, + "keyFile must be set when certFile is set"); + checkArgument(keyFile == null || certFile != null, + "certFile must be set when keyFile is set"); + checkArgument(certFile != null || trustFile != null || spiffeTrustMapFile != null, + "at least one of identity (certFile/keyFile), trustFile, or spiffeTrustMapFile must be" + + " present"); + if (notifyCertUpdates && certFile == null) { + // UnsupportedOperationException so CertificateProviderStore.createOrGetProvider's catch + // block falls back to notifyCertUpdates=false for roots-only configs. + throw new UnsupportedOperationException( + "certFile/keyFile must be set when notifyCertUpdates is true"); + } + this.certFile = certFile == null ? null : Paths.get(certFile); + this.keyFile = keyFile == null ? null : Paths.get(keyFile); if (spiffeTrustMapFile != null) { this.spiffeTrustMapFile = Paths.get(spiffeTrustMapFile); this.trustFile = null; } else { this.spiffeTrustMapFile = null; - this.trustFile = Paths.get(trustFile); + this.trustFile = trustFile == null ? null : Paths.get(trustFile); } this.refreshIntervalInSeconds = refreshIntervalInSeconds; } @@ -112,28 +123,31 @@ private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) { @VisibleForTesting void checkAndReloadCertificates() { try { - try { - FileTime currentCertTime = Files.getLastModifiedTime(certFile); - FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); - if (!currentCertTime.equals(lastModifiedTimeCert) - || !currentKeyTime.equals(lastModifiedTimeKey)) { - byte[] certFileContents = Files.readAllBytes(certFile); - byte[] keyFileContents = Files.readAllBytes(keyFile); - FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); - FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); - if (currentCertTime2.equals(currentCertTime) && currentKeyTime2.equals(currentKeyTime)) { - try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); - ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + if (certFile != null) { + try { + FileTime currentCertTime = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); + if (!currentCertTime.equals(lastModifiedTimeCert) + || !currentKeyTime.equals(lastModifiedTimeKey)) { + byte[] certFileContents = Files.readAllBytes(certFile); + byte[] keyFileContents = Files.readAllBytes(keyFile); + FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); + FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); + if (currentCertTime2.equals(currentCertTime) + && currentKeyTime2.equals(currentKeyTime)) { + try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); + ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { + PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); + X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); + getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + } + lastModifiedTimeCert = currentCertTime; + lastModifiedTimeKey = currentKeyTime; } - lastModifiedTimeCert = currentCertTime; - lastModifiedTimeKey = currentKeyTime; } + } catch (Throwable t) { + generateErrorIfCurrentCertExpired(t); } - } catch (Throwable t) { - generateErrorIfCurrentCertExpired(t); } try { if (spiffeTrustMapFile != null) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java index e4871dc4c84..add90e9d70b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -82,6 +81,13 @@ public CertificateProvider createCertificateProvider( Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { Config configObj = validateAndTranslateConfig(config); + if (notifyCertUpdates && configObj.certFile == null) { + // Throw UnsupportedOperationException so CertificateProviderStore.createOrGetProvider's + // catch block falls back to notifyCertUpdates=false for roots-only configs. + throw new UnsupportedOperationException( + "'" + CERT_FILE_KEY + "' and '" + KEY_FILE_KEY + + "' are required when notifyCertUpdates is true"); + } return fileWatcherCertificateProviderFactory.create( watcher, notifyCertUpdates, @@ -94,30 +100,43 @@ public CertificateProvider createCertificateProvider( timeProvider); } - private static String checkForNullAndGet(Map map, String key) { - return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config"); - } - private static Config validateAndTranslateConfig(Object config) { checkArgument(config instanceof Map, "Only Map supported for config"); @SuppressWarnings("unchecked") Map map = (Map)config; Config configObj = new Config(); - configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); - configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); + configObj.certFile = JsonUtil.getString(map, CERT_FILE_KEY); + configObj.keyFile = JsonUtil.getString(map, KEY_FILE_KEY); + if (configObj.certFile != null && configObj.keyFile == null) { + throw new NullPointerException( + "'" + KEY_FILE_KEY + "' is required when '" + CERT_FILE_KEY + "' is set"); + } + if (configObj.keyFile != null && configObj.certFile == null) { + throw new NullPointerException( + "'" + CERT_FILE_KEY + "' is required when '" + KEY_FILE_KEY + "' is set"); + } if (enableSpiffe) { - if (!map.containsKey(ROOT_FILE_KEY) && !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { - throw new NullPointerException( - String.format("either '%s' or '%s' is required in the config", - ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY)); - } if (map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY); - } else { + } else if (map.containsKey(ROOT_FILE_KEY)) { configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); } + if (configObj.certFile == null + && configObj.rootFile == null + && configObj.spiffeTrustMapFile == null) { + throw new NullPointerException( + String.format( + "config must specify ('%s' and '%s'), '%s', or '%s'", + CERT_FILE_KEY, KEY_FILE_KEY, ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY)); + } } else { - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); + if (configObj.certFile == null && configObj.rootFile == null) { + throw new NullPointerException( + String.format( + "config must specify ('%s' and '%s') or '%s'", + CERT_FILE_KEY, KEY_FILE_KEY, ROOT_FILE_KEY)); + } } String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY); if (refreshIntervalString != null) { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java index 304a2dd5441..c76dc070fa7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -197,7 +197,7 @@ public void createProvider_zeroRefreshInterval() throws IOException { } @Test - public void createProvider_missingCert_expectException() throws IOException { + public void createProvider_keyWithoutCert_expectException() throws IOException { CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") @@ -206,12 +206,13 @@ public void createProvider_missingCert_expectException() throws IOException { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'certificate_file' is required in the config"); + assertThat(npe).hasMessageThat() + .isEqualTo("'certificate_file' is required when 'private_key_file' is set"); } } @Test - public void createProvider_missingKey_expectException() throws IOException { + public void createProvider_certWithoutKey_expectException() throws IOException { CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") @@ -220,19 +221,137 @@ public void createProvider_missingKey_expectException() throws IOException { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'private_key_file' is required in the config"); + assertThat(npe).hasMessageThat() + .isEqualTo("'private_key_file' is required when 'certificate_file' is set"); } } @Test - public void createProvider_missingRoot_expectException() throws IOException { - String expectedMessage = enableSpiffe ? "either 'ca_certificate_file' or " - + "'spiffe_trust_bundle_map_file' is required in the config" - : "'ca_certificate_file' is required in the config"; + public void createProvider_identityOnly_succeeds() throws IOException { CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_ROOT_AND_SPIFFE_CONFIG); + Map map = (Map) JsonParser.parse(IDENTITY_ONLY_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates.pem"), + eq("/var/run/gke-spiffe/certs/private_key.pem"), + eq(null), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_caRootsOnly_succeeds() throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(CA_ROOTS_ONLY_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, false); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(false), + eq(null), + eq(null), + eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_spiffeRootsOnly_succeeds() throws IOException { + Assume.assumeTrue(enableSpiffe); + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(SPIFFE_ROOTS_ONLY_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, false); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(false), + eq(null), + eq(null), + eq(null), + eq("/var/run/gke-spiffe/certs/spiffe_bundle.json"), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_rootsOnlyWithNotifyCertUpdates_throwsUnsupportedOperation() + throws IOException { + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(CA_ROOTS_ONLY_CONFIG); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("exception expected"); + } catch (UnsupportedOperationException uoe) { + assertThat(uoe).hasMessageThat().isEqualTo( + "'certificate_file' and 'private_key_file' are required when notifyCertUpdates is true"); + } + } + + @Test + public void rootsOnlyConfig_storeFallbackProbesTrueThenFalse() throws IOException { + // Regression: CertificateProviderStore.createOrGetProvider always probes notifyCertUpdates=true + // first and only falls back to the caller's value (false) when UnsupportedOperationException is + // thrown. Any other exception type would escape the try/catch and break legitimate roots-only + // configs. + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(CA_ROOTS_ONLY_CONFIG); + try { + provider.createCertificateProvider(map, distWatcher, true); + fail("first probe must throw UnsupportedOperationException so the store falls back"); + } catch (UnsupportedOperationException expected) { + // expected — this is what the store's catch block in createOrGetProvider relies on + } + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, false); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(false), + eq(null), + eq(null), + eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_emptyConfig_expectException() throws IOException { + String expectedMessage = enableSpiffe + ? "config must specify ('certificate_file' and 'private_key_file'), 'ca_certificate_file'," + + " or 'spiffe_trust_bundle_map_file'" + : "config must specify ('certificate_file' and 'private_key_file') or" + + " 'ca_certificate_file'"; + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(EMPTY_CONFIG); try { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); @@ -286,12 +405,25 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + " }"; - private static final String MISSING_ROOT_AND_SPIFFE_CONFIG = + private static final String IDENTITY_ONLY_CONFIG = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" + " }"; + private static final String CA_ROOTS_ONLY_CONFIG = + "{\n" + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + + " }"; + + private static final String SPIFFE_ROOTS_ONLY_CONFIG = + "{\n" + + " \"spiffe_trust_bundle_map_file\":" + + " \"/var/run/gke-spiffe/certs/spiffe_bundle.json\"" + + " }"; + + private static final String EMPTY_CONFIG = "{}"; + private static final String ZERO_REFRESH_INTERVAL = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index f6fdc51dece..7e610292ec3 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -199,6 +199,94 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte verifyTimeServiceAndScheduledFuture(); } + @Test + public void identityOnlyProvider_reloadsOnlyCert() + throws IOException, CertificateException, InterruptedException { + provider = new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, null, null, + 600L, timeService, timeProvider); + TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(CLIENT_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void caRootsOnlyProvider_reloadsOnlyRoots() + throws IOException, CertificateException, InterruptedException { + provider = new FileWatcherCertificateProvider(watcher, false, null, null, rootFile, null, + 600L, timeService, timeProvider); + TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(null, null, CA_PEM_FILE, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(null, CA_PEM_FILE, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void spiffeRootsOnlyProvider_reloadsOnlySpiffeMap() throws Exception { + provider = new FileWatcherCertificateProvider(watcher, false, null, null, null, + spiffeTrustMapFile, 600L, timeService, timeProvider); + TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(null, null, null, SPIFFE_TRUST_MAP_FILE, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(null, null, SPIFFE_TRUST_MAP_FILE); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void provider_constructor_rejectsRootsOnlyWhenNotifyCertUpdatesTrue() { + try { + new FileWatcherCertificateProvider(watcher, true, null, null, rootFile, null, 600L, + timeService, timeProvider); + org.junit.Assert.fail("exception expected"); + } catch (UnsupportedOperationException expected) { + assertThat(expected).hasMessageThat().contains("notifyCertUpdates"); + } + } + + @Test + public void provider_constructor_rejectsBothMissing() { + try { + new FileWatcherCertificateProvider(watcher, true, null, null, null, null, 600L, timeService, + timeProvider); + org.junit.Assert.fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("identity"); + } + } + + @Test + public void provider_constructor_rejectsKeyWithoutCert() { + try { + new FileWatcherCertificateProvider(watcher, true, null, keyFile, rootFile, null, 600L, + timeService, timeProvider); + org.junit.Assert.fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("certFile"); + } + } + + @Test + public void provider_constructor_rejectsCertWithoutKey() { + try { + new FileWatcherCertificateProvider(watcher, true, certFile, null, rootFile, null, 600L, + timeService, timeProvider); + org.junit.Assert.fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("keyFile"); + } + } + @Test public void closeDoesNotScheduleNext() throws IOException, CertificateException { TestScheduledFuture scheduledFuture =