diff --git a/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java b/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java index 0c09b9a5..e0b0b00b 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java @@ -7,6 +7,7 @@ import javasabr.rlib.network.impl.DefaultBufferAllocator; import javasabr.rlib.network.impl.DefaultConnection; import javasabr.rlib.network.impl.StringDataConnection; +import javasabr.rlib.network.impl.StringDataMtlsServerConnection; import javasabr.rlib.network.impl.StringDataSslConnection; import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; @@ -140,7 +141,11 @@ public static ClientNetwork stringDataSslClientNetwork( SSLContext sslContext) { return clientNetwork( networkConfig, - (network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true)); + (network, channel) -> { + StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true); + connection.beginHandshake(); + return connection; + }); } /** @@ -196,7 +201,11 @@ public static ServerNetwork stringDataSslServerNetwork( SSLContext sslContext) { return serverNetwork( networkConfig, - (network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false)); + (network, channel) -> { + StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false); + connection.beginHandshake(); + return connection; + }); } /** @@ -231,4 +240,26 @@ public static ServerNetwork defaultServerNetwork( networkConfig, (network, channel) -> new DefaultConnection(network, channel, bufferAllocator, packetRegistry)); } + + /** + * Create string packet based asynchronous Mutual TLS server network. + * + * @param networkConfig the server network configuration + * @param bufferAllocator the buffer allocator + * @param sslContext SSL context + * @return a new mTLS server network + * @since 10.0.0 + */ + public static ServerNetwork stringDataMtlsServerNetwork( + ServerNetworkConfig networkConfig, + BufferAllocator bufferAllocator, + SSLContext sslContext) { + return serverNetwork( + networkConfig, + (network, channel) -> { + StringDataMtlsServerConnection connection = new StringDataMtlsServerConnection(network, channel, bufferAllocator, sslContext); + connection.beginHandshake(); + return connection; + }); + } } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java new file mode 100644 index 00000000..c5e90114 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java @@ -0,0 +1,12 @@ +package javasabr.rlib.network.exception; + +public class ConnectionClosedException extends NetworkException { + + public ConnectionClosedException(String remoteAddress) { + super("Connection closed: %s".formatted(remoteAddress)); + } + + public ConnectionClosedException(String remoteAddress, Throwable cause) { + super("Connection closed: %s".formatted(remoteAddress), cause); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java index 7e8e04c6..3d4c2ee9 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java @@ -16,6 +16,7 @@ import javasabr.rlib.network.Connection; import javasabr.rlib.network.Network; import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.exception.ConnectionClosedException; import javasabr.rlib.network.packet.NetworkPacketReader; import javasabr.rlib.network.packet.NetworkPacketWriter; import javasabr.rlib.network.packet.ReadableNetworkPacket; @@ -64,6 +65,7 @@ public WritablePacketWithFeedback(CompletableFuture attachment, Writabl final MutableArray>> validPacketSubscribers; final MutableArray>> invalidPacketSubscribers; + final MutableArray> activeSinks; final int maxPacketsByRead; @@ -84,6 +86,7 @@ public AbstractConnection( this.closed = new AtomicBoolean(false); this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); + this.activeSinks = ArrayFactory.stampedLockBasedArray(FluxSink.class); this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel)); } @@ -134,10 +137,12 @@ protected void registerFluxOnReceivedEvents( validPacketSubscribers.add(validListener); invalidPacketSubscribers.add(invalidListener); + activeSinks.add(sink); sink.onDispose(() -> { validPacketSubscribers.remove(validListener); validPacketSubscribers.remove(invalidListener); + activeSinks.remove(sink); }); network.inNetworkThread(() -> packetReader().startRead()); @@ -146,14 +151,22 @@ protected void registerFluxOnReceivedEvents( protected void registerFluxOnReceivedValidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); validPacketSubscribers.add(listener); - sink.onDispose(() -> validPacketSubscribers.remove(listener)); + activeSinks.add(sink); + sink.onDispose(() -> { + validPacketSubscribers.remove(listener); + activeSinks.remove(sink); + }); network.inNetworkThread(() -> packetReader().startRead()); } protected void registerFluxOnReceivedInvalidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); invalidPacketSubscribers.add(listener); - sink.onDispose(() -> invalidPacketSubscribers.remove(listener)); + activeSinks.add(sink); + sink.onDispose(() -> { + invalidPacketSubscribers.remove(listener); + activeSinks.remove(sink); + }); network.inNetworkThread(() -> packetReader().startRead()); } @@ -184,6 +197,24 @@ protected void doClose() { clearWaitPackets(); packetReader().close(); packetWriter().close(); + notifySinksOnError(); + } + + protected void notifySinksOnError() { + if (activeSinks.isEmpty()) { + return; + } + ConnectionClosedException error = new ConnectionClosedException(remoteAddress); + activeSinks + .iterations() + .forEach(error, (sink, exc) -> { + try { + sink.error(exc); + } catch (RuntimeException e) { + log.error(e.getMessage(), "Failed to notify sink of connection closure: "::formatted); + } + }); + activeSinks.clear(); } /** diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java index 9ae853d7..f015070b 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java @@ -26,6 +26,9 @@ public AbstractSslConnection( super(network, channel, bufferAllocator, maxPacketsByRead); this.sslEngine = sslContext.createSSLEngine(); this.sslEngine.setUseClientMode(clientMode); + } + + public void beginHandshake() { try { this.sslEngine.beginHandshake(); } catch (SSLException e) { diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java new file mode 100644 index 00000000..9bd21203 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java @@ -0,0 +1,28 @@ +package javasabr.rlib.network.impl; + +import javasabr.rlib.network.BufferAllocator; +import javasabr.rlib.network.Network; +import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket; + +import javax.net.ssl.SSLContext; +import java.nio.channels.AsynchronousSocketChannel; + +/** + * @author crazyrokr + */ +public class StringDataMtlsServerConnection extends DefaultDataSslConnection { + + public StringDataMtlsServerConnection( + Network network, + AsynchronousSocketChannel channel, + BufferAllocator bufferAllocator, + SSLContext sslContext) { + super(network, channel, bufferAllocator, sslContext, 100, 2, false); + sslEngine.setNeedClientAuth(true); + } + + @Override + protected StringReadableNetworkPacket createReadablePacket() { + return new StringReadableNetworkPacket<>(); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java index de7e5b8a..ee683870 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff retryReadLater(); } } - case AsynchronousCloseException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); - case ClosedChannelException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + case AsynchronousCloseException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } + case ClosedChannelException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } default -> { log.error(exception); connection.close(); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 6ab75309..a69eb4e2 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader( protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) { if (receivedBytes == -1) { doHandshake(sslNetworkBuffer(), -1); + connection.close(); return; } super.handleReceivedData(receivedBytes, readingBuffer); @@ -159,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) { case NEED_WRAP: { log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted); packetWriter.accept(SslWrapRequestNetworkPacket.getInstance()); + if (networkBuffer.hasRemaining()) { + return decryptAndRead(networkBuffer); + } NetworkUtils.cleanNetworkBuffer(networkBuffer); return SKIP_READ_PACKETS; } @@ -203,6 +207,10 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) { } switch (result.getStatus()) { case OK: { + if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) { + log.debug(remoteAddress, "[%s] No progress during decryption, stop processing"::formatted); + return SKIP_READ_PACKETS; + } sslDataBuffer.flip(); logDataAfterDecrypt(remoteAddress, sslDataBuffer); total += readPackets(sslDataBuffer, sslDataPendingBuffer); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java index f15f9aee..b2b0fb57 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java @@ -197,7 +197,7 @@ protected ByteBuffer doHandshake(HandshakeStatus handshakeStatus) { break; } case NEED_UNWRAP: { - break; + return EMPTY_BUFFER; } default: { throw new IllegalStateException("Invalid SSL status:" + handshakeStatus); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java new file mode 100644 index 00000000..42a70f76 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java @@ -0,0 +1,49 @@ +package javasabr.rlib.network; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javasabr.rlib.network.exception.ConnectionClosedException; +import javasabr.rlib.network.impl.AbstractConnection; +import javasabr.rlib.network.impl.DefaultConnection; +import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; +import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; +import org.junit.jupiter.api.Test; + +public class ConnectionCloseTest extends BaseNetworkTest { + + @Test + void shouldPropagateConnectionCloseToClient() throws InterruptedException { + // given + var packetRegistry = ReadableNetworkPacketRegistry.of( + DefaultReadableNetworkPacket.class, + DefaultConnection.class, + DefaultNetworkTest.ServerPackets.RequestEchoMessage.class, + DefaultNetworkTest.ServerPackets.RequestServerTime.class); + var serverNetwork = NetworkFactory.defaultServerNetwork(packetRegistry); + InetSocketAddress serverAddress = serverNetwork.start(); + serverNetwork.onAccept(AbstractConnection::close); + var clientNetwork = NetworkFactory.defaultClientNetwork(packetRegistry); + CountDownLatch closeLatch = new CountDownLatch(1); + + // when + clientNetwork + .connectReactive(serverAddress) + .flatMapMany(AbstractConnection::receivedEvents) + .doOnError(e -> { + if (e instanceof ConnectionClosedException) { + closeLatch.countDown(); + } + }) + .subscribe(); + + // then + assertThat(closeLatch.await(5000, TimeUnit.MILLISECONDS)) + .as("Client should be notified that connection is closed") + .isTrue(); + clientNetwork.shutdown(); + serverNetwork.shutdown(); + } +} diff --git a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java index a91a4419..43d7bfe3 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java @@ -17,12 +17,15 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import javasabr.rlib.common.util.ObjectUtils; import javasabr.rlib.common.util.StringUtils; import javasabr.rlib.common.util.Utils; import javasabr.rlib.network.client.ClientNetwork; +import javasabr.rlib.network.exception.ConnectionClosedException; import javasabr.rlib.network.impl.DefaultBufferAllocator; +import javasabr.rlib.network.impl.StringDataMtlsServerConnection; import javasabr.rlib.network.impl.StringDataSslConnection; import javasabr.rlib.network.packet.ReadableNetworkPacket; import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket; @@ -328,6 +331,63 @@ void shouldReceiveManyPacketsFromSmallToBigSize() { } } + @Test + @SneakyThrows + void shouldRejectClientWithoutCertificateWithinMutualTls() { + InputStream serverKeystoreFile = StringSslNetworkTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12"); + SSLContext serverSslContext = NetworkUtils.createSslContext(serverKeystoreFile, "test"); + ServerNetworkConfig serverConfig = ServerNetworkConfig.SimpleServerNetworkConfig.builder().build(); + BufferAllocator bufferAllocator = new DefaultBufferAllocator(serverConfig); + + ServerNetwork serverNetwork = + NetworkFactory.stringDataMtlsServerNetwork(serverConfig, bufferAllocator, serverSslContext); + + InetSocketAddress serverAddress = serverNetwork.start(); + CountDownLatch dataReceivedByServer = new CountDownLatch(1); + + serverNetwork + .accepted() + .flatMap(Connection::receivedEvents) + .subscribe(event -> dataReceivedByServer.countDown()); + + SSLContext clientWithoutCertContext = NetworkUtils.createAllTrustedClientSslContext(); + ClientNetwork clientNetwork = NetworkFactory.stringDataSslClientNetwork( + NetworkConfig.DEFAULT_CLIENT, + new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT), + clientWithoutCertContext); + + AtomicReference connectionError = new AtomicReference<>(); + CountDownLatch errorReceived = new CountDownLatch(1); + + try { + clientNetwork + .connectReactive(serverAddress) + .doOnNext(connection -> connection.sendInBackground(new StringWritableNetworkPacket<>("no cert"))) + .flatMapMany(Connection::receivedEvents) + .subscribe( + event -> {}, + ex -> { + connectionError.set(ex); + errorReceived.countDown(); + }); + + assertThat(errorReceived.await(5, TimeUnit.SECONDS)) + .as("Client subscriber must receive an error when the server closes the mTLS connection.") + .isTrue(); + + assertThat(connectionError.get()) + .as("Client must receive ConnectionClosedException, not a timeout.") + .isInstanceOf(ConnectionClosedException.class); + + assertThat(dataReceivedByServer.getCount()) + .as("Server must not receive data from an unauthenticated client.") + .isEqualTo(1); + } finally { + serverNetwork.shutdown(); + clientNetwork.shutdown(); + } + } + private static StringWritableNetworkPacket newMessage(int minMessageLength, int maxMessageLength) { return new StringWritableNetworkPacket<>(StringUtils.generate(minMessageLength, maxMessageLength)); } diff --git a/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java new file mode 100644 index 00000000..d37d2af5 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java @@ -0,0 +1,187 @@ +package javasabr.rlib.network.packet.impl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import java.time.Duration; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import javasabr.rlib.network.BufferAllocator; +import javasabr.rlib.network.Network; +import javasabr.rlib.network.NetworkConfig; +import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.impl.DefaultBufferAllocator; +import javasabr.rlib.network.packet.ReadableNetworkPacket; +import javasabr.rlib.network.packet.WritableNetworkPacket; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLSession; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +/** + * The tests of SSL packet reader + * + * @author crazyrokr + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class SslPacketReaderTest { + + private interface TestConnection extends UnsafeConnection {} + + @Mock + private TestConnection connection; + + @Mock + private Network network; + + @Mock + private SSLEngine sslEngine; + + @Mock + private SSLSession sslSession; + + @Mock + private Consumer> packetHandler; + + @Mock + private Consumer> packetWriter; + + private BufferAllocator bufferAllocator; + + @BeforeEach + void setUp() { + bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT); + when(connection.bufferAllocator()).thenReturn(bufferAllocator); + when(connection.network()).thenReturn((Network) network); + when(connection.remoteAddress()).thenReturn("test-address"); + when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT); + when(sslEngine.getSession()).thenReturn(sslSession); + when(sslSession.getApplicationBufferSize()).thenReturn(1024); + when(sslSession.getPacketBufferSize()).thenReturn(1024); + } + + private static class TestSslPacketReader extends + AbstractSslNetworkPacketReader, TestConnection> { + + private final AtomicInteger readPacketsCount = new AtomicInteger(); + + protected TestSslPacketReader( + TestConnection connection, + Consumer> packetHandler, + SSLEngine sslEngine, + Consumer> packetWriter) { + super(connection, () -> {}, packetHandler, packetHandler, sslEngine, packetWriter, 100); + } + + @Override + protected boolean canStartReadPacket(ByteBuffer buffer) { + return buffer.remaining() >= 1; + } + + @Override + protected int readFullPacketLength(ByteBuffer buffer) { + return 1; + } + + @Override + protected ReadableNetworkPacket createPacketFor( + ByteBuffer buffer, + int startPacketPosition, + int packetFullLength, + int packetDataLength) { + buffer.get(); // consume 1 byte + readPacketsCount.incrementAndGet(); + return mock(ReadableNetworkPacket.class); + } + } + + @Test + void testShouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception { + // given + var reader = new TestSslPacketReader(connection, packetHandler, sslEngine, packetWriter); + + // Initial state: NEED_UNWRAP + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP); + + // First unwrap will result in NEED_WRAP and status OK, consuming some data. + // Simulate a single network buffer containing 5 bytes of handshake data followed by + // 5 bytes of application data, so the remaining bytes can still be processed afterward. + ByteBuffer networkData = ByteBuffer.allocate(10); + networkData.put(new byte[10]); + networkData.flip(); + + // doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + in.position(in.position() + 5); // consume 5 bytes of handshake + // Change status to NEED_WRAP for next getHandshakeStatus() call + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP); + return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0); + }); + + // decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + ByteBuffer out = invocation.getArgument(1); + int remaining = in.remaining(); + in.position(in.limit()); // consume all + out.put(new byte[remaining]); // put decrypted data (mocked) + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING); + return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining); + }); + + // when + reader.readPackets(networkData); + + // then + // readPackets should have been called for the remaining 5 bytes, + // since each packet is 1 byte, it should have read 5 packets + assertThat(reader.readPacketsCount.get()).isEqualTo(5); + verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class)); + } + + @Test + void testShouldNotDeadLoopWhenNeedWrapAndNoProgress() throws Exception { + // given + var reader = new TestSslPacketReader(connection, packetHandler, sslEngine, packetWriter); + + // Initial state: NEED_WRAP + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP); + + // Network buffer has data + ByteBuffer networkData = ByteBuffer.allocate(10); + networkData.put(new byte[10]); + networkData.flip(); + + // Mock unwrap in decryptAndRead to return OK with 0 progress + // This happens if engine is in NEED_WRAP and can't decrypt application data + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenReturn( + new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 0, 0) + ); + + // when + // We expect this NOT to hang indefinitely. + // If it dead-loops, the test will fail by timeout. + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> + reader.readPackets(networkData) + ); + + // then + // Should have requested wrap + verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class)); + } +}