diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 3b66a0cb..1ccf744e 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -479,12 +479,18 @@ async def _after_connect(self): # Find the backend subprotocol returned in the response headers try: self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + # If the server does not send the subprotocol header, use + # the apollo subprotocol by default unless we didn't ask for it + if ( + self.adapter.subprotocols is None + or self.APOLLO_SUBPROTOCOL in self.adapter.subprotocols + ): + self.subprotocol = self.APOLLO_SUBPROTOCOL + else: + self.subprotocol = self.GRAPHQLWS_SUBPROTOCOL + log.debug(f"backend returned no subprotocol, using: {self.subprotocol!r}") async def _after_initialize(self): diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 416726aa..a65e4895 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -899,3 +899,41 @@ async def test_graphqlws_subscription_reconnecting_session( break assert transport._connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_no_server_protocol(server, subscription_str): + """The goal of this test is to verify that if the client requests only the + graphqlws subprotocol AND the server is not returning its subprotocol + in its header, then the client will assume that the protocol used is + the graphqlws subprotocol (See PR #586). + """ + + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport( + url=url, + subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + keep_alive_timeout=3, + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1