Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri,
AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "Request failed, now waiting " << sleepMillis << " ms before attempting again.");
if(request.GetBody())
{
if (request.GetBody()->tellg() == EOF) {
// Save checksum information from the original request if we haven't already and stream is finalized
RetryContext context = request.GetRetryContext();
if (context.m_requestHash == nullptr) {
auto originalRequestHash = httpRequest->GetRequestHash();
if (originalRequestHash.second != nullptr) {
context.m_requestHash = Aws::MakeShared<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>>(AWS_CLIENT_LOG_TAG, originalRequestHash);
request.SetRetryContext(context);
}
}
}
request.GetBody()->clear();
request.GetBody()->seekg(0);
}
Expand All @@ -397,16 +408,6 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri,
newUri.SetAuthority(newEndpoint);
}

// Save checksum information from the original request if we haven't already - safe to assume that the checksum has been finalized, since we have sent and received a response
RetryContext context = request.GetRetryContext();
if (context.m_requestHash == nullptr) {
auto originalRequestHash = httpRequest->GetRequestHash();
if (originalRequestHash.second != nullptr) {
context.m_requestHash = Aws::MakeShared<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>>(AWS_CLIENT_LOG_TAG, originalRequestHash);
request.SetRetryContext(context);
}
}

httpRequest = CreateHttpRequest(newUri, method, request.GetResponseStreamFactory());

httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId);
Expand Down
55 changes: 55 additions & 0 deletions tests/aws-cpp-sdk-s3-unit-tests/S3UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <aws/core/auth/AWSCredentials.h>
#include <aws/core/client/RetryStrategy.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/base64/Base64.h>
#include <aws/core/utils/crypto/CRC64.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/DeleteObjectsRequest.h>
#include <aws/s3/model/PutObjectRequest.h>
Expand Down Expand Up @@ -609,4 +611,57 @@ TEST_F(S3UnitTest, testLegacyApi)
"SignatureV4");

EXPECT_TRUE(outcome2.IsSuccess());
}

TEST_F(S3UnitTest, PartiallyConsumedStreamChecksumReuse) {
auto request = PutObjectRequest().WithBucket("(iamthou").WithKey("thouarti");
// the body has to be over 8K as the checksum is read as we read in chunks, in this case
// we set the chunk size to 8K and we need the body to be larger than that.
const Aws::String bodyString(9216, 'a');
request.SetBody(Aws::MakeShared<StringStream>(ALLOCATION_TAG, bodyString));

const auto errorResponseStream = Aws::MakeShared<Standard::StandardHttpRequest>(ALLOCATION_TAG, "mockuri", HttpMethod::HTTP_POST);
errorResponseStream->SetResponseStreamFactory(Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
auto errorResponse = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, errorResponseStream);
errorResponse->SetResponseCode(HttpResponseCode::REQUEST_TIMEOUT);
_mockHttpClient->AddResponseToReturn(
errorResponse, [](IOStream&) -> void {},
[](const std::shared_ptr<Aws::Http::HttpRequest>& request) -> void {
// Partially read the buffer, such that the request checksum ends up in a bad state.
Aws::Array<char, 12> tempBuffer;
request->GetContentBody()->read(tempBuffer.data(), 12);
});

const auto successResponseStream = Aws::MakeShared<Standard::StandardHttpRequest>(ALLOCATION_TAG, "mockuri", HttpMethod::HTTP_POST);
successResponseStream->SetResponseStreamFactory(Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
auto successResponse = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, errorResponseStream);
successResponse->SetResponseCode(HttpResponseCode::OK);
_mockHttpClient->AddResponseToReturn(
successResponse, [](IOStream& stream) -> void {EXPECT_EQ(stream.tellg(), 0);}, [](const std::shared_ptr<Aws::Http::HttpRequest>& request) -> void {
Aws::Array<char, 9216> tempBuffer;
request->GetContentBody()->read(tempBuffer.data(), 9216);
});

// The top level test has a no retry policy so we have to create one that retries
const AWSCredentials credentials{"mock", "credentials"};
S3ClientConfiguration configuration;
configuration.httpClientChunkedMode = HttpClientChunkedMode::DEFAULT;
// Smallest chunk size allowed
configuration.awsChunkedBufferSize = 8192UL;
const S3Client clientWithRetries{credentials, nullptr, configuration};

const auto response = clientWithRetries.PutObject(request);
AWS_EXPECT_SUCCESS(response);

Aws::Utils::Crypto::CRC64 crc64Hash{};
const auto expectedChecksum = crc64Hash.Calculate(bodyString);
EXPECT_TRUE(expectedChecksum.IsSuccess());
const Aws::Utils::Base64::Base64 base64{};
const auto expectedChecksumBase64 = base64.Encode(expectedChecksum.GetResult());

const auto retriedRequest = _mockHttpClient->GetMostRecentHttpRequest();
const auto seenChecksum = retriedRequest.GetRequestHash().second->GetHash();
EXPECT_TRUE(seenChecksum.IsSuccess());
const auto seenChecksumBase64 = base64.Encode(seenChecksum.GetResult());
EXPECT_EQ(seenChecksumBase64, expectedChecksumBase64);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class MockHttpClient : public Aws::Http::HttpClient
{
public:
using ResponseCallbackTuple = std::pair<std::shared_ptr<Aws::Http::HttpResponse>, std::function<void (Aws::IOStream&)>>;
using ResponseAndRequestCallbackTuple = std::tuple<std::shared_ptr<Aws::Http::HttpResponse>,
std::function<void (Aws::IOStream&)>,
std::function<void (const std::shared_ptr<Aws::Http::HttpRequest>&)>>;

std::shared_ptr<Aws::Http::HttpResponse> MakeRequest(const std::shared_ptr<Aws::Http::HttpRequest>& request,
Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr,
Expand All @@ -46,6 +49,16 @@ class MockHttpClient : public Aws::Http::HttpClient
}
return responseToUse.first;
}
if (!m_responseAndRequestsCallback.empty()) {
auto responseToUse = m_responseAndRequestsCallback.front();
m_responseAndRequestsCallback.pop();
if (std::get<0>(responseToUse)) {
std::get<0>(responseToUse)->SetOriginatingRequest(request);
std::get<1>(responseToUse)(std::get<0>(responseToUse)->GetResponseBody());
std::get<2>(responseToUse)(request);
}
return std::get<0>(responseToUse);
}
return Aws::MakeShared<Aws::Http::Standard::StandardHttpResponse>(MockHttpAllocationTag, request);
}

Expand All @@ -60,6 +73,11 @@ class MockHttpClient : public Aws::Http::HttpClient
//when you are finished.
void AddResponseToReturn(const std::shared_ptr<Aws::Http::HttpResponse>& response) { m_responsesToUse.emplace(response, [](Aws::IOStream&) -> void {}); }
void AddResponseToReturn(const std::shared_ptr<Aws::Http::HttpResponse>& response, const std::function<void (Aws::IOStream&)>& callbackFucntion) { m_responsesToUse.emplace(response, callbackFucntion); }
void AddResponseToReturn(const std::shared_ptr<Aws::Http::HttpResponse>& response,
const std::function<void(Aws::IOStream&)>& callbackFucntion,
const std::function<void(const std::shared_ptr<Aws::Http::HttpRequest>&)>& requestCallback) {
m_responseAndRequestsCallback.emplace(response, callbackFucntion, requestCallback);
}

void Reset()
{
Expand All @@ -68,9 +86,12 @@ class MockHttpClient : public Aws::Http::HttpClient
std::swap(m_responsesToUse, empty);
}



private:
mutable Aws::Vector<Aws::Http::Standard::StandardHttpRequest> m_requestsMade;
mutable Aws::Queue<ResponseCallbackTuple> m_responsesToUse;
mutable Aws::Queue<ResponseAndRequestCallbackTuple> m_responseAndRequestsCallback;
};

class MockHttpClientFactory : public Aws::Http::HttpClientFactory
Expand Down
Loading