Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ getClient(const S3::URI & url, const S3Settings & settings, ContextPtr context,
LOG_DEBUG(getLogger("getClient"), "Got new access tokens {} {} {}", access_key_id, secret_access_key, session_token);
}
}

auto shared_cache = S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(url.endpoint, url.bucket);

return S3::ClientFactory::instance().create(
client_configuration,
client_settings,
Expand All @@ -233,7 +236,8 @@ getClient(const S3::URI & url, const S3Settings & settings, ContextPtr context,
auth_settings.server_side_encryption_kms_config,
auth_settings.getHeaders(),
credentials_configuration,
session_token);
session_token,
shared_cache);
}

}
Expand Down
94 changes: 71 additions & 23 deletions src/IO/S3/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <Poco/MD5Engine.h>
#include <Common/CurrentThread.h>
#include <Common/Exception.h>
#include <Common/SipHash.h>

#include <aws/core/Aws.h>
#include <aws/core/client/CoreErrors.h>
Expand Down Expand Up @@ -219,11 +220,12 @@ std::unique_ptr<Client> Client::create(
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider,
const PocoHTTPClientConfiguration & client_configuration,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads,
const ClientSettings & client_settings)
const ClientSettings & client_settings,
const std::shared_ptr<ClientCache> & shared_cache)
{
verifyClientConfiguration(client_configuration);
return std::unique_ptr<Client>(
new Client(max_redirects_, std::move(sse_kms_config_), credentials_provider, client_configuration, sign_payloads, client_settings));
new Client(max_redirects_, std::move(sse_kms_config_), credentials_provider, client_configuration, sign_payloads, client_settings, shared_cache));
}

std::unique_ptr<Client> Client::clone() const
Expand Down Expand Up @@ -258,7 +260,8 @@ Client::Client(
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider_,
const PocoHTTPClientConfiguration & client_configuration_,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads_,
const ClientSettings & client_settings_)
const ClientSettings & client_settings_,
const std::shared_ptr<ClientCache> & shared_cache)
: Aws::S3::S3Client(credentials_provider_, client_configuration_, sign_payloads_, client_settings_.use_virtual_addressing)
, credentials_provider(credentials_provider_)
, client_configuration(client_configuration_)
Expand Down Expand Up @@ -298,7 +301,10 @@ Client::Client(

detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL;

cache = std::make_shared<ClientCache>();
if (shared_cache)
cache = shared_cache;
else
cache = std::make_shared<ClientCache>();
ClientCacheRegistry::instance().registerClient(cache);

ProfileEvents::increment(ProfileEvents::S3Clients);
Expand All @@ -321,7 +327,7 @@ Client::Client(
, sse_kms_config(other.sse_kms_config)
, log(getLogger("S3Client"))
{
cache = std::make_shared<ClientCache>(*other.cache);
cache = other.cache;
ClientCacheRegistry::instance().registerClient(cache);

logConfiguration();
Expand Down Expand Up @@ -1108,37 +1114,77 @@ void ClientCache::clearCache()
void ClientCacheRegistry::registerClient(const std::shared_ptr<ClientCache> & client_cache)
{
std::lock_guard lock(clients_mutex);
auto [it, inserted] = client_caches.emplace(client_cache.get(), client_cache);
if (!inserted)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Same S3 client registered twice");
auto it = client_caches.find(client_cache.get());
if (it != client_caches.end())
{
++it->second.second;
return;
}
client_caches.emplace(client_cache.get(), std::pair{std::weak_ptr<ClientCache>(client_cache), size_t(1)});
}

void ClientCacheRegistry::unregisterClient(ClientCache * client)
{
std::lock_guard lock(clients_mutex);
auto erased = client_caches.erase(client);
if (erased == 0)
auto it = client_caches.find(client);
if (it == client_caches.end())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Can't unregister S3 client, either it was already unregistered or not registered at all");
if (--it->second.second == 0)
client_caches.erase(it);
}

void ClientCacheRegistry::clearCacheForAll()
void ClientCacheRegistry::pruneExpiredCachesLocked()
{
std::lock_guard lock(clients_mutex);
std::erase_if(cache_by_endpoint_bucket, [](const auto & pair) { return pair.second.expired(); });
}

for (auto it = client_caches.begin(); it != client_caches.end();)
std::shared_ptr<ClientCache> ClientCacheRegistry::getOrCreateCacheForKey(const std::string & endpoint, const std::string & bucket)
{
SipHash hash;
hash.update(endpoint.size());
hash.update(endpoint);
hash.update(bucket);
UInt128 key = hash.get128();

std::lock_guard lock(cache_by_key_mutex);
if (auto it = cache_by_endpoint_bucket.find(key); it != cache_by_endpoint_bucket.end())
{
if (auto locked_client = it->second.lock(); locked_client)
{
locked_client->clearCache();
++it;
}
else
if (auto cached = it->second.lock(); cached)
return cached;
cache_by_endpoint_bucket.erase(it);
}
auto cache = std::make_shared<ClientCache>();
cache_by_endpoint_bucket[key] = cache;

pruneExpiredCachesLocked();

return cache;
}

void ClientCacheRegistry::clearCacheForAll()
{
{
std::lock_guard lock(clients_mutex);

for (auto it = client_caches.begin(); it != client_caches.end();)
{
LOG_INFO(getLogger("ClientCacheRegistry"), "Deleting leftover S3 client cache");
it = client_caches.erase(it);
if (auto locked_client = it->second.first.lock(); locked_client)
{
locked_client->clearCache();
++it;
}
else
{
LOG_INFO(getLogger("ClientCacheRegistry"), "Deleting leftover S3 client cache");
it = client_caches.erase(it);
}
}
}

{
std::lock_guard lock(cache_by_key_mutex);
pruneExpiredCachesLocked();
}
}

ClientFactory::ClientFactory()
Expand Down Expand Up @@ -1183,7 +1229,8 @@ std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT
ServerSideEncryptionKMSConfig sse_kms_config,
HTTPHeaderEntries headers,
CredentialsConfiguration credentials_configuration,
const String & session_token)
const String & session_token,
const std::shared_ptr<ClientCache> & shared_cache)
{
PocoHTTPClientConfiguration client_configuration = cfg_;
client_configuration.updateSchemeAndRegion();
Expand Down Expand Up @@ -1237,7 +1284,8 @@ std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT
client_configuration, // Client configuration.
client_settings.is_s3express_bucket ? Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::RequestDependent
: Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
client_settings);
client_settings,
shared_cache);
}

PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT
Expand Down
21 changes: 17 additions & 4 deletions src/IO/S3/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct ServerSideEncryptionKMSConfig
#include <IO/S3/PocoHTTPClient.h>
#include <IO/S3/Credentials.h>
#include <IO/S3/ProviderType.h>
#include <Common/HashTable/Hash.h>

#include <aws/core/Aws.h>
#include <aws/core/client/DefaultRetryStrategy.h>
Expand Down Expand Up @@ -77,11 +78,17 @@ class ClientCacheRegistry
void registerClient(const std::shared_ptr<ClientCache> & client_cache);
void unregisterClient(ClientCache * client);
void clearCacheForAll();
std::shared_ptr<ClientCache> getOrCreateCacheForKey(const std::string & endpoint, const std::string & bucket);

private:
ClientCacheRegistry() = default;

void pruneExpiredCachesLocked() TSA_REQUIRES(cache_by_key_mutex);

std::mutex clients_mutex;
std::unordered_map<ClientCache *, std::weak_ptr<ClientCache>> client_caches TSA_GUARDED_BY(clients_mutex);
std::unordered_map<ClientCache *, std::pair<std::weak_ptr<ClientCache>, size_t>> client_caches TSA_GUARDED_BY(clients_mutex);
std::mutex cache_by_key_mutex;
std::unordered_map<UInt128, std::weak_ptr<ClientCache>, UInt128Hash> cache_by_endpoint_bucket TSA_GUARDED_BY(cache_by_key_mutex);
};

bool isS3ExpressEndpoint(const std::string & endpoint);
Expand Down Expand Up @@ -128,7 +135,8 @@ class Client : private Aws::S3::S3Client
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider,
const PocoHTTPClientConfiguration & client_configuration,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads,
const ClientSettings & client_settings);
const ClientSettings & client_settings,
const std::shared_ptr<ClientCache> & shared_cache = nullptr);

std::unique_ptr<Client> clone() const;

Expand Down Expand Up @@ -240,14 +248,18 @@ class Client : private Aws::S3::S3Client

const PocoHTTPClientConfiguration & getClientConfiguration() const { return client_configuration; }

/// For testing purposes only
ClientCache * getRawCache() const { return cache.get(); }

protected:
// visible for testing
Client(size_t max_redirects_,
ServerSideEncryptionKMSConfig sse_kms_config_,
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider_,
const PocoHTTPClientConfiguration & client_configuration,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads,
const ClientSettings & client_settings_);
const ClientSettings & client_settings_,
const std::shared_ptr<ClientCache> & shared_cache = nullptr);

private:
Client(
Expand Down Expand Up @@ -346,7 +358,8 @@ class ClientFactory
ServerSideEncryptionKMSConfig sse_kms_config,
HTTPHeaderEntries headers,
CredentialsConfiguration credentials_configuration,
const String & session_token = "");
const String & session_token = "",
const std::shared_ptr<ClientCache> & shared_cache = nullptr);

PocoHTTPClientConfiguration createClientConfiguration(
const String & force_region,
Expand Down
124 changes: 124 additions & 0 deletions src/IO/S3/tests/gtest_aws_s3_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,4 +539,128 @@ TEST(IOTestAwsS3Client, AssumeRole)
}
}

/// Checks how ClientCacheRegistry::getOrCreateCacheForKey maps (endpoint, bucket) pairs to cache objects:
/// an identical pair yields the same object, while a different bucket, a different endpoint,
/// or a shifted endpoint/bucket boundary yields a distinct object.
TEST(IOTestAwsS3Client, ClientCacheRegistryGetOrCreateCacheForKey)
{
    auto & cache_registry = DB::S3::ClientCacheRegistry::instance();

    /// Every handle obtained below lives until the end of the test body,
    /// so all pointer comparisons are made between live objects.
    const auto first_lookup = cache_registry.getOrCreateCacheForKey("endpoint1", "bucket1");
    const auto repeated_lookup = cache_registry.getOrCreateCacheForKey("endpoint1", "bucket1");
    EXPECT_EQ(first_lookup.get(), repeated_lookup.get()) << "Same (endpoint, bucket) should return the same cache";

    const auto other_bucket = cache_registry.getOrCreateCacheForKey("endpoint1", "bucket2");
    EXPECT_NE(first_lookup.get(), other_bucket.get()) << "Different bucket should return different cache";

    const auto other_endpoint = cache_registry.getOrCreateCacheForKey("endpoint2", "bucket1");
    EXPECT_NE(first_lookup.get(), other_endpoint.get()) << "Different endpoint should return different cache";

    /// "ab" + "c" and "a" + "bc" concatenate to the same string; the keys must still differ.
    const auto split_after_two = cache_registry.getOrCreateCacheForKey("ab", "c");
    const auto split_after_one = cache_registry.getOrCreateCacheForKey("a", "bc");
    EXPECT_NE(split_after_two.get(), split_after_one.get())
        << "Pairs with identical concatenation but different boundary must not share a cache";
}

/// Builds a client on top of a cache obtained from the registry, clones it, and checks that
/// the client uses exactly that cache and that the clone points at the same cache object.
TEST(IOTestAwsS3Client, ClientSharesCacheWithClone)
{
    DB::RemoteHostFilter host_filter;
    DB::S3::URI s3_uri("https://s3.eu-central-1.amazonaws.com/my-bucket/key");

    auto http_config = DB::S3::ClientFactory::instance().createClientConfiguration(
        "eu-central-1",
        host_filter,
        10,
        DB::S3::PocoHTTPClientConfiguration::RetryStrategy{.max_retries = 0},
        true,
        true,
        false,
        false,
        {},
        {},
        "https");
    http_config.endpointOverride = s3_uri.endpoint;

    DB::S3::ClientSettings settings{
        .use_virtual_addressing = s3_uri.is_virtual_hosted_style,
        .disable_checksum = false,
        .gcs_issue_compose_request = false,
        .is_s3express_bucket = false,
    };

    /// The cache handle is declared before the clients, so it outlives both of them.
    auto registry_cache = DB::S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(s3_uri.endpoint, s3_uri.bucket);

    auto original_client = DB::S3::ClientFactory::instance().create(
        http_config,
        settings,
        "access",
        "secret",
        "",
        {},
        {},
        DB::S3::CredentialsConfiguration{.use_environment_credentials = false, .use_insecure_imds_request = false},
        "",
        registry_cache);
    ASSERT_TRUE(original_client);

    auto cloned_client = original_client->clone();
    ASSERT_TRUE(cloned_client);

    EXPECT_EQ(original_client->getRawCache(), registry_cache.get()) << "Client should use the shared cache";
    EXPECT_EQ(cloned_client->getRawCache(), original_client->getRawCache()) << "Clone should share the same cache as original";
}

/// Creates two independent clients on top of one registry-provided cache and destroys them one after
/// another, exercising registration/unregistration of the same cache object by more than one client.
TEST(IOTestAwsS3Client, TwoClientsWithSharedCacheUnregisterRefcount)
{
    DB::RemoteHostFilter host_filter;
    DB::S3::URI s3_uri("https://s3.us-east-1.amazonaws.com/another-bucket/key");

    auto http_config = DB::S3::ClientFactory::instance().createClientConfiguration(
        "us-east-1",
        host_filter,
        10,
        DB::S3::PocoHTTPClientConfiguration::RetryStrategy{.max_retries = 0},
        true,
        true,
        false,
        false,
        {},
        {},
        "https");
    http_config.endpointOverride = s3_uri.endpoint;

    DB::S3::ClientSettings settings{
        .use_virtual_addressing = s3_uri.is_virtual_hosted_style,
        .disable_checksum = false,
        .gcs_issue_compose_request = false,
        .is_s3express_bucket = false,
    };

    auto registry_cache = DB::S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(s3_uri.endpoint, s3_uri.bucket);

    /// Both clients are built from identical arguments; only the resulting objects differ.
    const auto make_client = [&]
    {
        return DB::S3::ClientFactory::instance().create(
            http_config,
            settings,
            "ak",
            "sk",
            "",
            {},
            {},
            DB::S3::CredentialsConfiguration{.use_environment_credentials = false, .use_insecure_imds_request = false},
            "",
            registry_cache);
    };

    auto first_client = make_client();
    auto second_client = make_client();

    ASSERT_TRUE(first_client);
    ASSERT_TRUE(second_client);
    EXPECT_EQ(first_client->getRawCache(), second_client->getRawCache());

    /// No explicit expectation follows: the test passes if both destructions complete.
    /// With a wrong registration count, unregisterClient would raise a LOGICAL_ERROR for the second
    /// client (assuming the Client destructor unregisters its cache; the destructor is not shown here).
    first_client.reset();
    second_client.reset();
}

#endif
Loading