diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java index e3ef05f2064e..7efa0446f051 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java @@ -25,6 +25,7 @@ import org.eclipse.jetty.client.Response; import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.JettyUpgradeListener; import org.jspecify.annotations.Nullable; @@ -107,8 +108,10 @@ public void onHandshakeResponse(Request request, Response response) { }; Sinks.Empty completion = Sinks.empty(); - JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(handler, session -> - new JettyWebSocketSession(session, Objects.requireNonNull(handshakeInfo.get()), DefaultDataBufferFactory.sharedInstance, completion)); + JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(handler, session -> { + configureSession(session); + return new JettyWebSocketSession(session, Objects.requireNonNull(handshakeInfo.get()), DefaultDataBufferFactory.sharedInstance, completion); + }); try { this.client.connect(handlerAdapter, upgradeRequest, jettyUpgradeListener) .exceptionally(throwable -> { @@ -123,4 +126,15 @@ public void onHandshakeResponse(Request request, Response response) { return Mono.error(ex); } } + + private void configureSession(Session session) { + session.setMaxFrameSize(this.client.getMaxFrameSize()); + session.setMaxBinaryMessageSize(this.client.getMaxBinaryMessageSize()); + session.setMaxTextMessageSize(this.client.getMaxTextMessageSize()); + session.setMaxOutgoingFrames(this.client.getMaxOutgoingFrames()); + session.setIdleTimeout(this.client.getIdleTimeout()); + session.setAutoFragment(this.client.isAutoFragment()); + session.setInputBufferSize(this.client.getInputBufferSize()); + session.setOutputBufferSize(this.client.getOutputBufferSize()); + } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java index 416cc2d9fe96..ea91a82191be 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java @@ -127,16 +127,18 @@ public Mono execute(URI url, WebSocketHandler handler) { @Override public Mono execute(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { String protocols = StringUtils.collectionToCommaDelimitedString(handler.getSubProtocols()); + WebsocketClientSpec wsClientSpec = buildSpec(protocols); return getHttpClient() .headers(nettyHeaders -> setNettyHeaders(requestHeaders, nettyHeaders)) - .websocket(buildSpec(protocols)) + .websocket(wsClientSpec) .uri(url.toString()) .handle((inbound, outbound) -> { HttpHeaders responseHeaders = toHttpHeaders(inbound); String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol"); HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol); NettyDataBufferFactory factory = new NettyDataBufferFactory(outbound.alloc()); - WebSocketSession session = new ReactorNettyWebSocketSession(inbound, outbound, info, factory); + WebSocketSession session = new ReactorNettyWebSocketSession(inbound, outbound, info, factory, + wsClientSpec.maxFramePayloadLength()); if (logger.isDebugEnabled()) { logger.debug("Started session '" + session.getId() + "' for " + url); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractReactiveWebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractReactiveWebSocketIntegrationTests.java index 30c384d5c6d4..44f8c8403efa 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractReactiveWebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractReactiveWebSocketIntegrationTests.java @@ -148,7 +148,9 @@ void stopServer() { if (this.client instanceof Lifecycle lifecycle) { lifecycle.stop(); } - this.server.stop(); + if (this.server != null) { + this.server.stop(); + } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java index 9f0820bf009b..6754116f0d6f 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java @@ -16,7 +16,9 @@ package org.springframework.web.reactive.socket; +import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -27,21 +29,28 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.netty.http.client.WebsocketClientSpec; import reactor.util.retry.Retry; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; +import org.springframework.web.reactive.socket.adapter.NettyWebSocketSessionSupport; +import org.springframework.web.reactive.socket.client.JettyWebSocketClient; +import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient; +import org.springframework.web.reactive.socket.client.TomcatWebSocketClient; import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.server.WebFilter; import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer; import org.springframework.web.testfixture.http.server.reactive.bootstrap.TomcatHttpServer; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; /** * Integration tests with server-side {@link WebSocketHandler}s. @@ -186,6 +195,51 @@ void cookie(WebSocketClient client, HttpServer server, Class serverConfigClas assertThat(cookie.get()).isEqualTo("project=spring"); } + @ParameterizedWebSocketTest + void largePayload(WebSocketClient client, HttpServer server, Class serverConfigClass) throws Exception { + + int defaultFrameMaxSize = NettyWebSocketSessionSupport.DEFAULT_FRAME_MAX_SIZE; + int extendedLimit = 2 * defaultFrameMaxSize; + + WebSocketClient extendedClient = extendLimits(client, extendedLimit); + + startServer(extendedClient, server, serverConfigClass); + + AtomicReference payloadSizeRef = new AtomicReference<>(); + assertThatCode(() -> extendedClient.execute(getUrl("/large-payload"), + session -> session.receive() + .map(WebSocketMessage::getPayload) + .map(DataBuffer::readableByteCount) + .reduce(Integer::sum) + .doOnNext(payloadSizeRef::set) + .then()) + .block(TIMEOUT)) + .doesNotThrowAnyException(); + + assertThat(payloadSizeRef.get()).isGreaterThan(defaultFrameMaxSize); + assertThat(payloadSizeRef.get()).isEqualTo(extendedLimit); + } + + private WebSocketClient extendLimits(WebSocketClient client, int limit) { + if (client instanceof ReactorNettyWebSocketClient netty) { + client = new ReactorNettyWebSocketClient( + netty.getHttpClient(), + () -> WebsocketClientSpec.builder().maxFramePayloadLength(limit)); + } + + if (client instanceof TomcatWebSocketClient tomcat) { + tomcat.getWebSocketContainer().setDefaultMaxTextMessageBufferSize(limit); + } + + if (client instanceof JettyWebSocketClient) { + org.eclipse.jetty.websocket.client.WebSocketClient jetty = + new org.eclipse.jetty.websocket.client.WebSocketClient(); + jetty.setMaxTextMessageSize(limit); + client = new JettyWebSocketClient(jetty); + } + + return client; + } @Configuration static class WebConfig { @@ -198,6 +252,7 @@ public HandlerMapping handlerMapping() { map.put("/custom-header", new CustomHeaderHandler()); map.put("/close", new SessionClosingHandler()); map.put("/cookie", new CookieHandler()); + map.put("/large-payload", new LargePayloadHandler()); return new SimpleUrlHandlerMapping(map); } @@ -274,4 +329,16 @@ public Mono handle(WebSocketSession session) { } } + private static class LargePayloadHandler implements WebSocketHandler { + + @Override + public Mono handle(WebSocketSession session) { + int doubledFrameSize = 2 * NettyWebSocketSessionSupport.DEFAULT_FRAME_MAX_SIZE; + byte[] payload = new byte[doubledFrameSize]; + Arrays.fill(payload, (byte) 'x'); + String text = new String(payload, StandardCharsets.UTF_8); + WebSocketMessage message = session.textMessage(text); + return session.send(Mono.just(message)); + } + } }