diff --git a/test/org/apache/tomcat/util/net/TestSSLHostConfigProtocol.java b/test/org/apache/tomcat/util/net/TestSSLHostConfigProtocol.java index 2db6cde512b9..b90d5a59d0e9 100644 --- a/test/org/apache/tomcat/util/net/TestSSLHostConfigProtocol.java +++ b/test/org/apache/tomcat/util/net/TestSSLHostConfigProtocol.java @@ -20,12 +20,18 @@ import java.util.Collection; import java.util.List; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; + import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; +import org.apache.catalina.Context; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; @@ -95,6 +101,44 @@ private void doTestIgnoreProtocol(String protocol) throws Exception { Assert.assertEquals("TLSv1.2", enabledProtocols[0]); } + @Test(expected = SSLHandshakeException.class) + public void testTlsVersionMismatchServerTls13ClientTls12() throws Exception { + SSLHostConfig sslHostConfig = getSSLHostConfig(); + sslHostConfig.setProtocols(Constants.SSL_PROTO_TLSv1_3); + + Context ctx = getProgrammaticRootContext(); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + Tomcat tomcat = getTomcatInstance(); + tomcat.start(); + + TesterSupport.configureClientSsl(true); + + getUrl("https://localhost:" + getPort() + "/"); + } + + @Test(expected = SSLHandshakeException.class) + public void testTlsVersionMismatchServerTls12ClientTls13() throws Exception { + SSLHostConfig sslHostConfig = getSSLHostConfig(); + sslHostConfig.setProtocols(Constants.SSL_PROTO_TLSv1_2); + + Context ctx = getProgrammaticRootContext(); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + Tomcat tomcat = getTomcatInstance(); + tomcat.start(); + + SSLContext sc = SSLContext.getInstance(Constants.SSL_PROTO_TLSv1_3); + sc.init(null, new TrustManager[] { new TesterSupport.TrustAllCerts() }, null); + TesterSupport.ClientSSLSocketFactory clientSSLSocketFactory = new TesterSupport.ClientSSLSocketFactory(sc.getSocketFactory()); + clientSSLSocketFactory.setProtocols(new String[] { Constants.SSL_PROTO_TLSv1_3 }); + HttpsURLConnection.setDefaultSSLSocketFactory(clientSSLSocketFactory); + + getUrl("https://localhost:" + getPort() + "/"); + } + private SSLHostConfig getSSLHostConfig() { Tomcat tomcat = getTomcatInstance(); diff --git a/test/org/apache/tomcat/util/net/TesterSupport.java b/test/org/apache/tomcat/util/net/TesterSupport.java index 9c3e0d209525..a3dd1abb05f9 100644 --- a/test/org/apache/tomcat/util/net/TesterSupport.java +++ b/test/org/apache/tomcat/util/net/TesterSupport.java @@ -658,6 +658,7 @@ public static class ClientSSLSocketFactory extends SSLSocketFactory { private final SSLSocketFactory delegate; private String[] ciphers = null; + private String[] protocols = null; public ClientSSLSocketFactory(SSLSocketFactory delegate) { @@ -673,6 +674,15 @@ public void setCipher(String[] ciphers) { this.ciphers = ciphers; } + /** + * Forces the use of the specified protocols. + * + * @param protocols Array of standard protocols to use + */ + public void setProtocols(String[] protocols) { + this.protocols = protocols; + } + @Override public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { Socket result = delegate.createSocket(s, host, port, autoClose); @@ -724,6 +734,9 @@ private Socket reconfigureSocket(Socket socket) { if (ciphers != null) { ((SSLSocket) socket).setEnabledCipherSuites(ciphers); } + if (protocols != null) { + ((SSLSocket) socket).setEnabledProtocols(protocols); + } return socket; } }