Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions gql/transport/websockets_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,18 @@ async def _after_connect(self):
# Find the backend subprotocol returned in the response headers
try:
self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"]
log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
except KeyError:
# If the server does not send the subprotocol header, using
# the apollo subprotocol by default
self.subprotocol = self.APOLLO_SUBPROTOCOL

log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
# If the server does not send the subprotocol header, use
# the apollo subprotocol by default unless we didn't ask for it
if (
self.adapter.subprotocols is None
or self.APOLLO_SUBPROTOCOL in self.adapter.subprotocols
):
self.subprotocol = self.APOLLO_SUBPROTOCOL
else:
self.subprotocol = self.GRAPHQLWS_SUBPROTOCOL
log.debug(f"backend returned no subprotocol, using: {self.subprotocol!r}")

async def _after_initialize(self):

Expand Down
38 changes: 38 additions & 0 deletions tests/test_graphqlws_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,41 @@ async def test_graphqlws_subscription_reconnecting_session(
break

assert transport._connected is False


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_graphqlws_subscription_no_server_protocol(server, subscription_str):
"""The goal of this test is to verify that if the client requests only the
graphqlws subprotocol AND the server is not returning its subprotocol
in its header, then the client will assume that the protocol used is
the graphqlws subprotocol (See PR #586).
"""

from gql.transport.websockets import WebsocketsTransport

url = f"ws://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

transport = WebsocketsTransport(
url=url,
subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
keep_alive_timeout=3,
)

client = Client(transport=transport)

count = 10
subscription = gql(subscription_str.format(count=count))

async with client as session:
async for result in session.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert count == -1
Loading