Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "bugfix",
"description": "Fixed `AWSCRTHTTPClient` to send HTTP/1.1 request bodies via `body_stream` instead of the HTTP/2-only `request_body_generator`."
}
30 changes: 23 additions & 7 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from copy import deepcopy
from dataclasses import dataclass
from inspect import iscoroutinefunction
from io import BytesIO
from typing import TYPE_CHECKING, Any

from awscrt.exceptions import AwsCrtError
Expand Down Expand Up @@ -36,6 +37,7 @@
from smithy_core import interfaces as core_interfaces
from smithy_core.aio import interfaces as core_aio_interfaces
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.aio.utils import read_streaming_blob_async
from smithy_core.exceptions import MissingDependencyError

from .. import Field, Fields
Expand Down Expand Up @@ -175,13 +177,18 @@ async def send(
crt_request = self._marshal_request(request)
connection = await self._get_connection(request.destination)

# Convert body to async iterator for request_body_generator
body_generator = self._create_body_generator(request.body)

crt_stream = connection.request(
crt_request,
request_body_generator=body_generator,
)
# request_body_generator is HTTP/2-only in CRT; HTTP/1.1 must use body_stream
if connection.version == crt_http.HttpVersion.Http2:
crt_stream = connection.request(
crt_request,
request_body_generator=self._create_body_generator(request.body),
)
else:
if (
body_stream := await self._create_body_stream(request.body)
) is not None:
crt_request.body_stream = body_stream
crt_stream = connection.request(crt_request)

return await self._await_response(crt_stream)
except AwsCrtError as e:
Expand Down Expand Up @@ -308,6 +315,15 @@ def _marshal_request(
)
return crt_request

async def _create_body_stream(
self, body: core_aio_interfaces.StreamingBlob
) -> core_interfaces.BytesReader | None:
"""Convert various body types to a bytes reader for CRT HTTP/1.1."""
if core_interfaces.is_bytes_reader(body):
return body
buffered = await read_streaming_blob_async(body)
return BytesIO(buffered) if buffered else None

async def _create_body_generator(
self, body: core_aio_interfaces.StreamingBlob
) -> AsyncGenerator[bytes, None]:
Expand Down
110 changes: 110 additions & 0 deletions packages/smithy-http/tests/unit/aio/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,69 @@ async def test_port_included_in_host_header(host: str, expected: str) -> None:
assert crt_request.headers.get("host") == expected # type: ignore


async def test_send_http1_uses_body_stream() -> None:
"""Test HTTP/1.1 requests use the CRT body stream."""
client = AWSCRTHTTPClient()
request = HTTPRequest(
method="POST",
destination=URI(scheme="https", host="example.com", path="/"),
body=AsyncBytesReader(b"Action=Test&Version=2020-01-08"),
fields=Fields(),
)
mock_stream = Mock()
mock_response = Mock()
mock_connection = Mock()
mock_connection.version = crt_http.HttpVersion.Http1_1
mock_connection.request = Mock(return_value=mock_stream)

with (
patch.object(
client, "_get_connection", AsyncMock(return_value=mock_connection)
),
patch.object(client, "_await_response", AsyncMock(return_value=mock_response)),
):
actual = await client.send(request)

assert actual is mock_response
mock_connection.request.assert_called_once()
assert "request_body_generator" not in mock_connection.request.call_args.kwargs
crt_request = mock_connection.request.call_args.args[0]
assert crt_request.body_stream is not None


async def test_send_http2_uses_body_generator() -> None:
"""Test HTTP/2 requests use the CRT body generator."""
client = AWSCRTHTTPClient()
request = HTTPRequest(
method="POST",
destination=URI(scheme="https", host="example.com", path="/"),
body=AsyncBytesReader(b"Action=Test&Version=2020-01-08"),
fields=Fields(),
)
mock_stream = Mock()
mock_response = Mock()
mock_connection = Mock()
mock_connection.version = crt_http.HttpVersion.Http2
mock_connection.request = Mock(return_value=mock_stream)

with (
patch.object(
client, "_get_connection", AsyncMock(return_value=mock_connection)
),
patch.object(client, "_await_response", AsyncMock(return_value=mock_response)),
):
actual = await client.send(request)

assert actual is mock_response
mock_connection.request.assert_called_once()
crt_request = mock_connection.request.call_args.args[0]
body_generator = mock_connection.request.call_args.kwargs["request_body_generator"]
assert crt_request.body_stream is None
assert [chunk async for chunk in body_generator] == [
b"Action=Test&Version=2020-01-08"
]


async def test_body_generator_bytes() -> None:
"""Test body generator with bytes input."""
client = AWSCRTHTTPClient()
Expand Down Expand Up @@ -191,6 +254,53 @@ async def test_body_generator_empty_bytes() -> None:
assert chunks == [b""]


async def test_body_stream_bytes() -> None:
"""Test body stream with bytes input."""
client = AWSCRTHTTPClient()

body_stream = await client._create_body_stream(b"Hello, World!")

assert body_stream is not None
assert body_stream.read() == b"Hello, World!"


async def test_body_stream_bytesio() -> None:
"""Test body stream with BytesIO."""
client = AWSCRTHTTPClient()
body = BytesIO(b"data from BytesIO")

body_stream = await client._create_body_stream(body)

assert body_stream is not None
assert body_stream is body
assert body_stream.read() == b"data from BytesIO"


async def test_body_stream_async_iterable() -> None:
"""Test body stream with custom AsyncIterable."""

async def custom_generator() -> AsyncIterator[bytes]:
yield b"chunk1"
yield b"chunk2"
yield b"chunk3"

client = AWSCRTHTTPClient()

body_stream = await client._create_body_stream(custom_generator())

assert body_stream is not None
assert body_stream.read() == b"chunk1chunk2chunk3"


async def test_body_stream_empty_bytes() -> None:
"""Test body stream with empty bytes."""
client = AWSCRTHTTPClient()

body_stream = await client._create_body_stream(b"")

assert body_stream is None


async def test_build_connection_http() -> None:
"""Test building HTTP connection."""
client = AWSCRTHTTPClient()
Expand Down
Loading