Skip to content

Commit 3e29572

Browse files
committed
Introduce integration test for TLS session tickets
Adds isolated SessionTicketsIT which checks if session tickets mechanism behaves as expected. It does that by watching Java's SSL debug logs and ensuring that specific substrings appear in them. Since the server supports tickets with TLSv1.3, only that version is tested. In that version the session resumption without server-side state is done through pre-shared keys. The details are described in rfc8446 (see section 2.2).
1 parent c800792 commit 3e29572

1 file changed

Lines changed: 248 additions & 0 deletions

File tree

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
package com.datastax.oss.driver.core.ssl;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
5+
import com.datastax.oss.driver.api.core.CqlSession;
6+
import com.datastax.oss.driver.api.testinfra.ccm.CcmBridge;
7+
import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule;
8+
import com.datastax.oss.driver.api.testinfra.requirement.BackendRequirement;
9+
import com.datastax.oss.driver.api.testinfra.requirement.BackendType;
10+
import com.datastax.oss.driver.api.testinfra.session.SessionUtils;
11+
import com.datastax.oss.driver.categories.IsolatedTests;
12+
import com.google.common.collect.ImmutableList;
13+
import java.io.ByteArrayOutputStream;
14+
import java.io.IOException;
15+
import java.io.InputStream;
16+
import java.nio.file.Files;
17+
import java.nio.file.Paths;
18+
import java.security.KeyStore;
19+
import java.security.SecureRandom;
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
import java.util.function.Function;
24+
import java.util.logging.Handler;
25+
import java.util.logging.Level;
26+
import java.util.logging.LogRecord;
27+
import java.util.logging.Logger;
28+
import javax.net.ssl.KeyManagerFactory;
29+
import javax.net.ssl.SSLContext;
30+
import javax.net.ssl.TrustManagerFactory;
31+
import org.junit.ClassRule;
32+
import org.junit.Test;
33+
import org.junit.experimental.categories.Category;
34+
35+
@Category(IsolatedTests.class)
36+
@BackendRequirement(
37+
type = BackendType.SCYLLA,
38+
minInclusive = "2025.2.0",
39+
description = "Requires server side support for session tickets")
40+
public class SessionTicketsIT {
41+
@ClassRule
42+
public static final CustomCcmRule CCM_RULE =
43+
CustomCcmRule.builder()
44+
.withSsl()
45+
.withCassandraConfiguration("client_encryption_options.enable_session_tickets", "true")
46+
.build();
47+
48+
@Test
49+
public void should_use_session_tickets_when_TLSv13() throws Exception {
50+
// This test relies on parsing SSL debug logs which may not be identical across different
51+
// Java versions. Additionally, the expected number of specific messages depends on default
52+
// driver behavior.
53+
// Currently the driver negotiates the protocol version starting from DSE_V2 down to V4, and
54+
// does not
55+
// try to use V6. If that ever changes or Scylla starts supporting V5 or starts to send
56+
// different
57+
// number of tickets after hello, the expected counts will need to be updated.
58+
59+
// Set up SSL debug logging
60+
System.setProperty("javax.net.debug", "");
61+
Logger sslLogger = Logger.getLogger("javax.net.ssl");
62+
Level originalLevel = sslLogger.getLevel();
63+
sslLogger.setLevel(Level.ALL);
64+
65+
final int expectedNumOfClientHellos =
66+
6; // 4 come from ControlConnection (DSE_V2, DSE_V1, V5, V4 protocols) + 2 from dedicated
67+
// Channels to 2 shards
68+
69+
List<OccurrenceCounter> counters =
70+
ImmutableList.of(
71+
new OccurrenceCounter(
72+
"Consuming CertificateVerify handshake message", count -> count == 1),
73+
// It looks like server currently sends two tickets after each handshake
74+
new OccurrenceCounter(
75+
"Consuming NewSessionTicket", count -> count == expectedNumOfClientHellos * 2),
76+
new OccurrenceCounter(
77+
"Produced ClientHello handshake message",
78+
count -> count == expectedNumOfClientHellos),
79+
new OccurrenceCounter(
80+
"Consuming ServerHello handshake message",
81+
count -> count == expectedNumOfClientHellos),
82+
new OccurrenceCounter(
83+
"Try resuming session", count -> count == expectedNumOfClientHellos - 1),
84+
new OccurrenceCounter(
85+
"Consumed extension: pre_shared_key",
86+
count -> count == expectedNumOfClientHellos - 1),
87+
new OccurrenceCounter(
88+
"Resuming session:", count -> count == expectedNumOfClientHellos - 1),
89+
new OccurrenceCounter(
90+
"Using PSK to derive early secret",
91+
count -> count == expectedNumOfClientHellos - 1),
92+
new OccurrenceCounter(
93+
"Negotiated protocol version: TLSv1.3",
94+
count -> count == expectedNumOfClientHellos));
95+
96+
// Custom handler to capture log messages
97+
ByteArrayOutputStream logCapture = new ByteArrayOutputStream();
98+
TlsDebugLogHandler handler = new TlsDebugLogHandler(logCapture, counters);
99+
sslLogger.setUseParentHandlers(false);
100+
sslLogger.addHandler(handler);
101+
102+
try {
103+
SSLContext context = createSslContext("TLSv1.3");
104+
try (CqlSession session =
105+
(CqlSession)
106+
SessionUtils.baseBuilder()
107+
.addContactEndPoints(CCM_RULE.getContactPoints())
108+
.withSslContext(context)
109+
.build()) {
110+
session.execute("select * from system.local where key='local'");
111+
}
112+
List<String> errors = new ArrayList<>();
113+
for (OccurrenceCounter counter : counters) {
114+
if (!counter.isWithinNorm()) {
115+
errors.add(
116+
"Counter not within norm: "
117+
+ counter
118+
+ ", inspect the function passed to constructor at "
119+
+ counter.getCreationLocation());
120+
}
121+
}
122+
if (!errors.isEmpty()) {
123+
StringBuilder sb = new StringBuilder();
124+
sb.append("One or more assertions failed:\n");
125+
for (String error : errors) {
126+
sb.append(error).append("\n");
127+
}
128+
throw new AssertionError(sb.toString());
129+
}
130+
} finally {
131+
sslLogger.removeHandler(handler);
132+
sslLogger.setLevel(originalLevel);
133+
}
134+
}
135+
136+
private SSLContext createSslContext(String protocol) throws Exception {
137+
SSLContext context = SSLContext.getInstance(protocol);
138+
139+
TrustManagerFactory tmf =
140+
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
141+
try (InputStream tsf =
142+
Files.newInputStream(
143+
Paths.get(CcmBridge.DEFAULT_CLIENT_TRUSTSTORE_FILE.getAbsolutePath()))) {
144+
KeyStore ts = KeyStore.getInstance("JKS");
145+
char[] password = CcmBridge.DEFAULT_CLIENT_TRUSTSTORE_PASSWORD.toCharArray();
146+
ts.load(tsf, password);
147+
tmf.init(ts);
148+
}
149+
150+
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
151+
try (InputStream ksf =
152+
Files.newInputStream(Paths.get(CcmBridge.DEFAULT_CLIENT_KEYSTORE_FILE.getAbsolutePath()))) {
153+
KeyStore ks = KeyStore.getInstance("JKS");
154+
char[] password = CcmBridge.DEFAULT_CLIENT_KEYSTORE_PASSWORD.toCharArray();
155+
ks.load(ksf, password);
156+
kmf.init(ks, password);
157+
}
158+
159+
context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
160+
return context;
161+
}
162+
163+
static class TlsDebugLogHandler extends Handler {
164+
private final ByteArrayOutputStream outputStream;
165+
private final List<OccurrenceCounter> counters;
166+
167+
TlsDebugLogHandler(ByteArrayOutputStream outputStream, List<OccurrenceCounter> counters) {
168+
this.outputStream = outputStream;
169+
this.counters = counters;
170+
}
171+
172+
@Override
173+
public void publish(LogRecord record) {
174+
try {
175+
for (OccurrenceCounter counter : counters) {
176+
counter.incrementIfFound(record.getMessage());
177+
}
178+
outputStream.write((record.getMessage() + "\n").getBytes(UTF_8));
179+
} catch (IOException e) {
180+
throw new RuntimeException(e);
181+
}
182+
}
183+
184+
@Override
185+
public void flush() {
186+
try {
187+
outputStream.flush();
188+
} catch (IOException e) {
189+
throw new RuntimeException(e);
190+
}
191+
}
192+
193+
@Override
194+
public void close() throws SecurityException {
195+
try {
196+
outputStream.close();
197+
} catch (IOException e) {
198+
throw new RuntimeException(e);
199+
}
200+
}
201+
}
202+
203+
static class OccurrenceCounter {
204+
private final AtomicInteger count = new AtomicInteger(0);
205+
private final String substring; // Exact substring to look for
206+
private final Function<Integer, Boolean> isWithinNorm;
207+
private final StackTraceElement creationLocation;
208+
209+
public OccurrenceCounter(String substring, Function<Integer, Boolean> isWithinNorm) {
210+
this.substring = substring;
211+
this.isWithinNorm = isWithinNorm;
212+
this.creationLocation = new Throwable().getStackTrace()[1]; // [1] gets the caller's location
213+
}
214+
215+
/**
216+
* Increment the counter if the substring is found in the log line. Multiple occurrences count
217+
* as one.
218+
*
219+
* @param logLine log line to check
220+
*/
221+
public void incrementIfFound(String logLine) {
222+
if (logLine.contains(substring)) {
223+
count.incrementAndGet();
224+
}
225+
}
226+
227+
public int get() {
228+
return count.get();
229+
}
230+
231+
public boolean isWithinNorm() {
232+
return isWithinNorm.apply(count.get());
233+
}
234+
235+
public String getSubstring() {
236+
return substring;
237+
}
238+
239+
public String getCreationLocation() {
240+
return creationLocation.toString();
241+
}
242+
243+
@Override
244+
public String toString() {
245+
return "OccurrenceCounter{substring='" + substring + "', count=" + count.get() + "}";
246+
}
247+
}
248+
}

0 commit comments

Comments
 (0)