From bd2c6012e2362e3ec022576adab562a89e146cf1 Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Sat, 9 May 2026 18:29:14 +0800 Subject: [PATCH 1/6] *: fix reload tls certs in reconnect and fix region cache Signed-off-by: Ray Yan --- src/common/security.rs | 92 +++++++++++++++---- src/pd/client.rs | 73 +++++++++++++-- src/request/mod.rs | 198 ++++++++++++++++++++++++++++++++++++++++- src/request/plan.rs | 19 ++-- src/store/client.rs | 8 ++ 5 files changed, 353 insertions(+), 37 deletions(-) diff --git a/src/common/security.rs b/src/common/security.rs index 89e074b3..9c038151 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -43,12 +43,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result> { /// Manages the TLS protocol #[derive(Default)] pub struct SecurityManager { - /// The PEM encoding of the server’s CA certificates. - ca: Vec, - /// The PEM encoding of the server’s certificate chain. - cert: Vec, + /// The path to the PEM encoding of the server’s CA certificates. + ca_path: Option, + /// The path to the PEM encoding of the server’s certificate chain. + cert_path: Option, /// The path to the file that contains the PEM encoding of the server’s private key. - key: PathBuf, + key_path: Option, } impl SecurityManager { @@ -58,15 +58,23 @@ impl SecurityManager { cert_path: impl AsRef, key_path: impl Into, ) -> Result { + let ca_path = ca_path.as_ref().to_path_buf(); + let cert_path = cert_path.as_ref().to_path_buf(); let key_path = key_path.into(); + check_pem_file("ca", &ca_path)?; + check_pem_file("certificate", &cert_path)?; check_pem_file("private key", &key_path)?; Ok(SecurityManager { - ca: load_pem_file("ca", ca_path.as_ref())?, - cert: load_pem_file("certificate", cert_path.as_ref())?, - key: key_path, + ca_path: Some(ca_path), + cert_path: Some(cert_path), + key_path: Some(key_path), }) } + pub(crate) fn tls_configured(&self) -> bool { + self.ca_path.is_some() + } + /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection. pub async fn connect( &self, @@ -78,7 +86,7 @@ impl SecurityManager { Factory: FnOnce(Channel) -> Client, { info!("connect to rpc server at endpoint: {:?}", addr); - let channel = if !self.ca.is_empty() { + let channel = if self.tls_configured() { self.tls_channel(addr).await? } else { self.default_channel(addr).await? @@ -89,18 +97,37 @@ impl SecurityManager { } async fn tls_channel(&self, addr: &str) -> Result { + let (ca, cert, key) = self.load_tls_materials()?; let addr = "https://".to_string() + &SCHEME_REG.replace(addr, ""); let builder = self.endpoint(addr.to_string())?; let tls = ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(&self.ca)) - .identity(Identity::from_pem( - &self.cert, - load_pem_file("private key", &self.key)?, - )); + .ca_certificate(Certificate::from_pem(ca)) + .identity(Identity::from_pem(cert, key)); let builder = builder.tls_config(tls)?; Ok(builder) } + fn load_tls_materials(&self) -> Result<(Vec, Vec, Vec)> { + let ca_path = self + .ca_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let cert_path = self + .cert_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let key_path = self + .key_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + + Ok(( + load_pem_file("ca", ca_path)?, + load_pem_file("certificate", cert_path)?, + load_pem_file("private key", key_path)?, + )) + } + async fn default_channel(&self, addr: &str) -> Result { let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); self.endpoint(addr) @@ -140,9 +167,40 @@ mod tests { let key_path: PathBuf = format!("{}", example_pem.display()).into(); let ca_path: PathBuf = format!("{}", example_ca.display()).into(); let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap(); - assert_eq!(mgr.ca, vec![0]); - assert_eq!(mgr.cert, vec![1]); - let key = load_pem_file("private key", &key_path).unwrap(); + assert!(mgr.tls_configured()); + let (ca, cert, key) = mgr.load_tls_materials().unwrap(); + assert_eq!(ca, vec![0]); + assert_eq!(cert, vec![1]); assert_eq!(key, vec![2]); } + + #[test] + fn test_security_reload() { + let temp = tempfile::tempdir().unwrap(); + let example_ca = temp.path().join("ca"); + let example_cert = temp.path().join("cert"); + let example_pem = temp.path().join("key"); + for (id, f) in [&example_ca, &example_cert, &example_pem] + .iter() + .enumerate() + { + File::create(f).unwrap().write_all(&[id as u8]).unwrap(); + } + + let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); + let first = mgr.load_tls_materials().unwrap(); + + File::create(&example_ca).unwrap().write_all(&[9]).unwrap(); + File::create(&example_cert) + .unwrap() + .write_all(&[8]) + .unwrap(); + File::create(&example_pem).unwrap().write_all(&[7]).unwrap(); + + let second = mgr.load_tls_materials().unwrap(); + assert_ne!(first, second); + assert_eq!(second.0, vec![9]); + assert_eq!(second.1, vec![8]); + assert_eq!(second.2, vec![7]); + } } diff --git a/src/pd/client.rs b/src/pd/client.rs index 05b9c07c..1155dcf1 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -338,16 +338,21 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - if let Some(client) = self.kv_client_cache.read().await.get(address) { - return Ok(client.clone()); + let cache_connections = self.kv_connect.cache_connections(); + if cache_connections { + if let Some(client) = self.kv_client_cache.read().await.get(address) { + return Ok(client.clone()); + } }; info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - self.kv_client_cache - .write() - .await - .insert(address.to_owned(), client.clone()); + if cache_connections { + self.kv_client_cache + .write() + .await + .insert(address.to_owned(), client.clone()); + } Ok(client) } Err(e) => Err(e), @@ -364,11 +369,18 @@ fn make_key_range(start_key: Vec, end_key: Vec) -> kvrpcpb::KeyRange { #[cfg(test)] pub mod test { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use async_trait::async_trait; use futures::executor; use futures::executor::block_on; use super::*; use crate::mock::*; + use crate::pd::RetryClient; + use crate::store::KvConnect; + use crate::Config; #[tokio::test] async fn test_kv_client_caching() { @@ -384,6 +396,55 @@ pub mod test { assert_eq!(kv2.addr, kv3.addr); } + #[tokio::test] + async fn test_kv_client_reloadable_connections_are_not_cached() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + + fn cache_connections(&self) -> bool { + false + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let connects_clone = connects.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 2); + } + #[test] fn test_group_keys_by_region() { let client = MockPdClient::default(); diff --git a/src/request/mod.rs b/src/request/mod.rs index c8fd07be..0cde241b 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -90,21 +90,25 @@ impl RetryOptions { mod test { use std::any::Any; use std::iter; - use std::sync::atomic::AtomicUsize; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; + use async_trait::async_trait; use tonic::transport::Channel; use super::*; use crate::mock::MockKvClient; use crate::mock::MockPdClient; + use crate::proto::keyspacepb; use crate::proto::kvrpcpb; + use crate::proto::metapb::{self, RegionEpoch}; use crate::proto::pdpb::Timestamp; use crate::proto::tikvpb::tikv_client::TikvClient; - use crate::region::RegionWithLeader; + use crate::region::{RegionId, RegionVerId, RegionWithLeader, StoreId}; use crate::store::region_stream_for_keys; use crate::store::HasRegionError; + use crate::store::{RegionStore, Store}; use crate::transaction::lowering::new_commit_request; use crate::Error; use crate::Key; @@ -206,6 +210,196 @@ mod test { assert_eq!(invoking_count.load(std::sync::atomic::Ordering::SeqCst), 4); } + #[tokio::test] + async fn test_region_store_mapping_retry() { + #[derive(Debug, Clone)] + struct MockOkResponse; + + impl HasKeyErrors for MockOkResponse { + fn key_errors(&mut self) -> Option> { + None + } + } + + impl HasRegionError for MockOkResponse { + fn region_error(&mut self) -> Option { + None + } + } + + impl HasLocks for MockOkResponse {} + + struct FlakyStoreMappingPdClient { + client: MockKvClient, + invalidated: AtomicBool, + invalidation_count: AtomicUsize, + } + + impl FlakyStoreMappingPdClient { + fn region(store_id: StoreId) -> RegionWithLeader { + let mut region = RegionWithLeader::default(); + region.region.id = 1; + region.region.start_key = vec![]; + region.region.end_key = vec![]; + region.region.region_epoch = Some(RegionEpoch { + conf_ver: 0, + version: 0, + }); + region.leader = Some(metapb::Peer { + store_id, + ..Default::default() + }); + region + } + } + + #[async_trait] + impl crate::pd::PdClient for FlakyStoreMappingPdClient { + type KvClient = MockKvClient; + + async fn map_region_to_store( + self: Arc, + region: RegionWithLeader, + ) -> Result { + match region.get_store_id()? { + 41 => Err(Error::InternalError { + message: "invalid store ID 41, not found".to_owned(), + }), + _ => Ok(RegionStore::new(region, Arc::new(self.client.clone()))), + } + } + + async fn region_for_key(&self, _: &Key) -> Result { + let store_id = if self.invalidated.load(Ordering::SeqCst) { + 42 + } else { + 41 + }; + Ok(Self::region(store_id)) + } + + async fn region_for_id(&self, id: RegionId) -> Result { + match id { + 1 => self.region_for_key(&Key::EMPTY).await, + _ => Err(Error::RegionNotFoundInResponse { region_id: id }), + } + } + + async fn all_stores(&self) -> Result> { + Ok(vec![Store::new(Arc::new(self.client.clone()))]) + } + + async fn get_timestamp(self: Arc) -> Result { + Ok(Timestamp::default()) + } + + async fn update_safepoint(self: Arc, _safepoint: u64) -> Result { + unimplemented!() + } + + async fn load_keyspace(&self, _keyspace: &str) -> Result { + unimplemented!() + } + + async fn update_leader( + &self, + _ver_id: RegionVerId, + _leader: metapb::Peer, + ) -> Result<()> { + Ok(()) + } + + async fn invalidate_region_cache(&self, _ver_id: RegionVerId) { + self.invalidated.store(true, Ordering::SeqCst); + self.invalidation_count.fetch_add(1, Ordering::SeqCst); + } + + async fn invalidate_store_cache(&self, _store_id: StoreId) {} + } + + #[derive(Clone)] + struct MockKvRequest { + shard_invoking_count: Arc, + } + + #[async_trait] + impl Request for MockKvRequest { + async fn dispatch(&self, _: &TikvClient, _: Duration) -> Result> { + Ok(Box::new(MockOkResponse)) + } + + fn label(&self) -> &'static str { + "mock" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn set_leader(&mut self, _: &RegionWithLeader) -> Result<()> { + Ok(()) + } + + fn set_api_version(&mut self, _: kvrpcpb::ApiVersion) {} + } + + #[async_trait] + impl KvRequest for MockKvRequest { + type Response = MockOkResponse; + } + + impl Shardable for MockKvRequest { + type Shard = Vec>; + + fn shards( + &self, + pd_client: &Arc, + ) -> futures::stream::BoxStream< + 'static, + crate::Result<(Self::Shard, crate::region::RegionWithLeader)>, + > { + self.shard_invoking_count.fetch_add(1, Ordering::SeqCst); + region_stream_for_keys( + Some(Key::from("mock_key".to_owned())).into_iter(), + pd_client.clone(), + ) + } + + fn apply_shard(&mut self, _shard: Self::Shard) {} + + fn apply_store(&mut self, _store: &crate::store::RegionStore) -> crate::Result<()> { + Ok(()) + } + } + + let dispatch_count = Arc::new(AtomicUsize::new(0)); + let shard_invoking_count = Arc::new(AtomicUsize::new(0)); + let dispatch_count_clone = dispatch_count.clone(); + + let pd_client = Arc::new(FlakyStoreMappingPdClient { + client: MockKvClient::with_dispatch_hook(move |_: &dyn Any| { + dispatch_count_clone.fetch_add(1, Ordering::SeqCst); + Ok(Box::new(MockOkResponse) as Box) + }), + invalidated: AtomicBool::new(false), + invalidation_count: AtomicUsize::new(0), + }); + + let request = MockKvRequest { + shard_invoking_count: shard_invoking_count.clone(), + }; + + let plan = crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, request) + .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 3)) + .plan(); + + let response = plan.execute().await; + assert!(response.is_ok()); + assert_eq!(dispatch_count.load(Ordering::SeqCst), 1); + assert_eq!(shard_invoking_count.load(Ordering::SeqCst), 2); + assert_eq!(pd_client.invalidation_count.load(Ordering::SeqCst), 1); + } + #[tokio::test] async fn test_extract_error() { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( diff --git a/src/request/plan.rs b/src/request/plan.rs index 8bd15bb5..c8acb5da 100644 --- a/src/request/plan.rs +++ b/src/request/plan.rs @@ -163,6 +163,8 @@ where preserve_region_results: bool, ) -> Result<::Result> { debug!("single_shard_handler"); + let region_ver_id = region.ver_id(); + let store_id = region.get_store_id().ok(); let region_store = match pd_client .clone() .map_region_to_store(region) @@ -172,27 +174,20 @@ where Ok(region_store) }) { Ok(region_store) => region_store, - Err(Error::LeaderNotFound { region }) => { - debug!( - "single_shard_handler::sharding: leader not found: {:?}", - region - ); + Err(err) => { + debug!("single_shard_handler::sharding, error: {:?}", err); return Self::handle_other_error( pd_client, plan, - region.clone(), - None, + region_ver_id, + store_id, backoff, permits, preserve_region_results, - Error::LeaderNotFound { region }, + err, ) .await; } - Err(err) => { - debug!("single_shard_handler::sharding, error: {:?}", err); - return Err(err); - } }; // limit concurrent requests diff --git a/src/store/client.rs b/src/store/client.rs index 1c873285..34d3b466 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -20,6 +20,10 @@ pub trait KvConnect: Sized + Send + Sync + 'static { type KvClient: KvClient + Clone + Send + Sync + 'static; async fn connect(&self, address: &str) -> Result; + + fn cache_connections(&self) -> bool { + true + } } #[derive(new, Clone)] @@ -43,6 +47,10 @@ impl KvConnect for TikvConnect { .await .map(|c| KvRpcClient::new(c, self.timeout)) } + + fn cache_connections(&self) -> bool { + !self.security_mgr.tls_configured() + } } #[async_trait] From 5e975b087d502fa769ad1d489f3885fa6526123a Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Sat, 9 May 2026 19:00:54 +0800 Subject: [PATCH 2/6] support client cache in tls Signed-off-by: Ray Yan --- src/common/security.rs | 31 +++++++++++++++ src/pd/client.rs | 89 +++++++++++++++++++++++++++++++++++------- src/store/client.rs | 8 ++-- 3 files changed, 110 insertions(+), 18 deletions(-) diff --git a/src/common/security.rs b/src/common/security.rs index 9c038151..7fe41328 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -1,10 +1,15 @@ // Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. +use std::collections::hash_map::DefaultHasher; +use std::fs; use std::fs::File; +use std::hash::Hash; +use std::hash::Hasher; use std::io::Read; use std::path::Path; use std::path::PathBuf; use std::time::Duration; +use std::time::SystemTime; use log::info; use regex::Regex; @@ -75,6 +80,18 @@ impl SecurityManager { self.ca_path.is_some() } + pub(crate) fn connection_cache_key(&self) -> Result> { + if !self.tls_configured() { + return Ok(None); + } + + let mut hasher = DefaultHasher::new(); + file_signature(self.ca_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + file_signature(self.cert_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + file_signature(self.key_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + Ok(Some(hasher.finish())) + } + /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection. pub async fn connect( &self, @@ -141,6 +158,17 @@ impl SecurityManager { } } +fn file_signature(path: &Path) -> Result<(u64, Option)> { + let metadata = fs::metadata(path) + .map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?; + let modified = metadata.modified().ok().and_then(|t: SystemTime| { + t.duration_since(SystemTime::UNIX_EPOCH) + .ok() + .map(|d| d.as_nanos()) + }); + Ok((metadata.len(), modified)) +} + #[cfg(test)] mod tests { use std::fs::File; @@ -189,6 +217,7 @@ mod tests { let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); let first = mgr.load_tls_materials().unwrap(); + let key1 = mgr.connection_cache_key().unwrap(); File::create(&example_ca).unwrap().write_all(&[9]).unwrap(); File::create(&example_cert) @@ -198,9 +227,11 @@ mod tests { File::create(&example_pem).unwrap().write_all(&[7]).unwrap(); let second = mgr.load_tls_materials().unwrap(); + let key2 = mgr.connection_cache_key().unwrap(); assert_ne!(first, second); assert_eq!(second.0, vec![9]); assert_eq!(second.1, vec![8]); assert_eq!(second.2, vec![7]); + assert_ne!(key1, key2); } } diff --git a/src/pd/client.rs b/src/pd/client.rs index 1155dcf1..10d6b3a8 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -214,11 +214,17 @@ pub trait PdClient: Send + Sync + 'static { pub struct PdRpcClient { pd: Arc>, kv_connect: KvC, - kv_client_cache: Arc>>, + kv_client_cache: Arc>>>, enable_codec: bool, region_cache: RegionCache>, } +#[derive(Clone)] +struct CachedKvClient { + cache_key: Option, + client: Client, +} + #[async_trait] impl PdClient for PdRpcClient { type KvClient = KvC::KvClient; @@ -338,21 +344,22 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - let cache_connections = self.kv_connect.cache_connections(); - if cache_connections { - if let Some(client) = self.kv_client_cache.read().await.get(address) { - return Ok(client.clone()); + let cache_key = self.kv_connect.connection_cache_key()?; + if let Some(cached) = self.kv_client_cache.read().await.get(address) { + if cached.cache_key == cache_key { + return Ok(cached.client.clone()); } }; info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - if cache_connections { - self.kv_client_cache - .write() - .await - .insert(address.to_owned(), client.clone()); - } + self.kv_client_cache.write().await.insert( + address.to_owned(), + CachedKvClient { + cache_key, + client: client.clone(), + }, + ); Ok(client) } Err(e) => Err(e), @@ -397,10 +404,60 @@ pub mod test { } #[tokio::test] - async fn test_kv_client_reloadable_connections_are_not_cached() { + async fn test_kv_client_cache_hits_when_key_is_stable() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + + fn connection_cache_key(&self) -> Result> { + Ok(Some(0)) + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let connects_clone = connects.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_kv_client_cache_invalidate_on_key_change() { #[derive(Clone)] struct CountingConnect { connects: Arc, + cache_key: Arc, } #[async_trait] @@ -414,17 +471,20 @@ pub mod test { Ok(client) } - fn cache_connections(&self) -> bool { - false + fn connection_cache_key(&self) -> Result> { + Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64)) } } let connects = Arc::new(AtomicUsize::new(0)); + let cache_key = Arc::new(AtomicUsize::new(1)); let connects_clone = connects.clone(); + let cache_key_clone = cache_key.clone(); let client = PdRpcClient::new( Config::default(), move |_| CountingConnect { connects: connects_clone.clone(), + cache_key: cache_key_clone.clone(), }, |sm| async move { Ok(RetryClient::new_with_cluster( @@ -439,6 +499,7 @@ pub mod test { .unwrap(); let kv1 = client.kv_client("foo").await.unwrap(); + cache_key.store(2, Ordering::SeqCst); let kv2 = client.kv_client("foo").await.unwrap(); assert_eq!(kv1.addr, "foo"); assert_eq!(kv2.addr, "foo"); diff --git a/src/store/client.rs b/src/store/client.rs index 34d3b466..5a3163f8 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -21,8 +21,8 @@ pub trait KvConnect: Sized + Send + Sync + 'static { async fn connect(&self, address: &str) -> Result; - fn cache_connections(&self) -> bool { - true + fn connection_cache_key(&self) -> Result> { + Ok(None) } } @@ -48,8 +48,8 @@ impl KvConnect for TikvConnect { .map(|c| KvRpcClient::new(c, self.timeout)) } - fn cache_connections(&self) -> bool { - !self.security_mgr.tls_configured() + fn connection_cache_key(&self) -> Result> { + self.security_mgr.connection_cache_key() } } From 29f4ff1ace390d703c77a324466aeddc51ac5e49 Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Tue, 12 May 2026 11:16:43 +0800 Subject: [PATCH 3/6] fix ci Signed-off-by: Ray Yan --- .cargo/config.toml | 2 ++ src/common/security.rs | 18 ++++++++++++------ src/pd/client.rs | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..4a6a1abd --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[resolver] +incompatible-rust-versions = "fallback" diff --git a/src/common/security.rs b/src/common/security.rs index 7fe41328..ae341f00 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -219,19 +219,25 @@ mod tests { let first = mgr.load_tls_materials().unwrap(); let key1 = mgr.connection_cache_key().unwrap(); - File::create(&example_ca).unwrap().write_all(&[9]).unwrap(); + File::create(&example_ca) + .unwrap() + .write_all(&[9, 9]) + .unwrap(); File::create(&example_cert) .unwrap() - .write_all(&[8]) + .write_all(&[8, 8, 8]) + .unwrap(); + File::create(&example_pem) + .unwrap() + .write_all(&[7, 7, 7, 7]) .unwrap(); - File::create(&example_pem).unwrap().write_all(&[7]).unwrap(); let second = mgr.load_tls_materials().unwrap(); let key2 = mgr.connection_cache_key().unwrap(); assert_ne!(first, second); - assert_eq!(second.0, vec![9]); - assert_eq!(second.1, vec![8]); - assert_eq!(second.2, vec![7]); + assert_eq!(second.0, vec![9, 9]); + assert_eq!(second.1, vec![8, 8, 8]); + assert_eq!(second.2, vec![7, 7, 7, 7]); assert_ne!(key1, key2); } } diff --git a/src/pd/client.rs b/src/pd/client.rs index 10d6b3a8..27fef794 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -349,7 +349,7 @@ impl PdRpcClient { if cached.cache_key == cache_key { return Ok(cached.client.clone()); } - }; + } info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { From 7c798eb55405e0616fc678258800bb5c98bd7a1c Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Tue, 12 May 2026 18:07:08 +0800 Subject: [PATCH 4/6] fix findings Signed-off-by: Ray Yan --- Cargo.toml | 2 +- src/common/security.rs | 28 +++++++++++++++++----------- src/pd/client.rs | 30 +++++++++++++++++------------- src/store/client.rs | 6 +++--- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0b029bca..bf2e80b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ serde_derive = "1.0" serde_json = "1" take_mut = "0.2.2" thiserror = "1" -tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] } +tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros", "fs"] } tonic = { version = "0.10", features = ["tls", "gzip"] } [dev-dependencies] diff --git a/src/common/security.rs b/src/common/security.rs index ae341f00..8309602f 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -1,7 +1,6 @@ // Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. use std::collections::hash_map::DefaultHasher; -use std::fs; use std::fs::File; use std::hash::Hash; use std::hash::Hasher; @@ -80,15 +79,21 @@ impl SecurityManager { self.ca_path.is_some() } - pub(crate) fn connection_cache_key(&self) -> Result> { + pub(crate) async fn connection_cache_key(&self) -> Result> { if !self.tls_configured() { return Ok(None); } let mut hasher = DefaultHasher::new(); - file_signature(self.ca_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); - file_signature(self.cert_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); - file_signature(self.key_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + file_signature(self.ca_path.as_ref().expect("tls_configured checked")) + .await? + .hash(&mut hasher); + file_signature(self.cert_path.as_ref().expect("tls_configured checked")) + .await? + .hash(&mut hasher); + file_signature(self.key_path.as_ref().expect("tls_configured checked")) + .await? + .hash(&mut hasher); Ok(Some(hasher.finish())) } @@ -158,8 +163,9 @@ impl SecurityManager { } } -fn file_signature(path: &Path) -> Result<(u64, Option)> { - let metadata = fs::metadata(path) +async fn file_signature(path: &Path) -> Result<(u64, Option)> { + let metadata = tokio::fs::metadata(path) + .await .map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?; let modified = metadata.modified().ok().and_then(|t: SystemTime| { t.duration_since(SystemTime::UNIX_EPOCH) @@ -202,8 +208,8 @@ mod tests { assert_eq!(key, vec![2]); } - #[test] - fn test_security_reload() { + #[tokio::test] + async fn test_security_reload() { let temp = tempfile::tempdir().unwrap(); let example_ca = temp.path().join("ca"); let example_cert = temp.path().join("cert"); @@ -217,7 +223,7 @@ mod tests { let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); let first = mgr.load_tls_materials().unwrap(); - let key1 = mgr.connection_cache_key().unwrap(); + let key1 = mgr.connection_cache_key().await.unwrap(); File::create(&example_ca) .unwrap() @@ -233,7 +239,7 @@ mod tests { .unwrap(); let second = mgr.load_tls_materials().unwrap(); - let key2 = mgr.connection_cache_key().unwrap(); + let key2 = mgr.connection_cache_key().await.unwrap(); assert_ne!(first, second); assert_eq!(second.0, vec![9, 9]); assert_eq!(second.1, vec![8, 8, 8]); diff --git a/src/pd/client.rs b/src/pd/client.rs index 27fef794..f98505a6 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -344,22 +344,26 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - let cache_key = self.kv_connect.connection_cache_key()?; - if let Some(cached) = self.kv_client_cache.read().await.get(address) { - if cached.cache_key == cache_key { - return Ok(cached.client.clone()); + let cache_key = self.kv_connect.connection_cache_key().await; + if let Ok(cache_key) = cache_key { + if let Some(cached) = self.kv_client_cache.read().await.get(address) { + if cached.cache_key == cache_key { + return Ok(cached.client.clone()); + } } } info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - self.kv_client_cache.write().await.insert( - address.to_owned(), - CachedKvClient { - cache_key, - client: client.clone(), - }, - ); + if let Ok(cache_key) = cache_key { + self.kv_client_cache.write().await.insert( + address.to_owned(), + CachedKvClient { + cache_key, + client: client.clone(), + }, + ); + } Ok(client) } Err(e) => Err(e), @@ -421,7 +425,7 @@ pub mod test { Ok(client) } - fn connection_cache_key(&self) -> Result> { + async fn connection_cache_key(&self) -> Result> { Ok(Some(0)) } } @@ -471,7 +475,7 @@ pub mod test { Ok(client) } - fn connection_cache_key(&self) -> Result> { + async fn connection_cache_key(&self) -> Result> { Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64)) } } diff --git a/src/store/client.rs b/src/store/client.rs index 5a3163f8..5d9c2b58 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -21,7 +21,7 @@ pub trait KvConnect: Sized + Send + Sync + 'static { async fn connect(&self, address: &str) -> Result; - fn connection_cache_key(&self) -> Result> { + async fn connection_cache_key(&self) -> Result> { Ok(None) } } @@ -48,8 +48,8 @@ impl KvConnect for TikvConnect { .map(|c| KvRpcClient::new(c, self.timeout)) } - fn connection_cache_key(&self) -> Result> { - self.security_mgr.connection_cache_key() + async fn connection_cache_key(&self) -> Result> { + self.security_mgr.connection_cache_key().await } } From 7a24227f5e4855cdb80436447584544a10c4a96d Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Fri, 15 May 2026 15:58:09 +0800 Subject: [PATCH 5/6] optimize kv_client_cache Signed-off-by: Ray Yan --- Cargo.toml | 2 +- src/common/security.rs | 41 +--------- src/pd/client.rs | 66 +++++----------- src/region_cache.rs | 4 +- src/request/mod.rs | 174 +++++++++++++++++++++++++++++++++++++++++ src/store/client.rs | 8 -- 6 files changed, 200 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf2e80b4..0b029bca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ serde_derive = "1.0" serde_json = "1" take_mut = "0.2.2" thiserror = "1" -tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros", "fs"] } +tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] } tonic = { version = "0.10", features = ["tls", "gzip"] } [dev-dependencies] diff --git a/src/common/security.rs b/src/common/security.rs index 8309602f..bf52ebf7 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -1,14 +1,10 @@ // Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. -use std::collections::hash_map::DefaultHasher; use std::fs::File; -use std::hash::Hash; -use std::hash::Hasher; use std::io::Read; use std::path::Path; use std::path::PathBuf; use std::time::Duration; -use std::time::SystemTime; use log::info; use regex::Regex; @@ -79,24 +75,6 @@ impl SecurityManager { self.ca_path.is_some() } - pub(crate) async fn connection_cache_key(&self) -> Result> { - if !self.tls_configured() { - return Ok(None); - } - - let mut hasher = DefaultHasher::new(); - file_signature(self.ca_path.as_ref().expect("tls_configured checked")) - .await? - .hash(&mut hasher); - file_signature(self.cert_path.as_ref().expect("tls_configured checked")) - .await? - .hash(&mut hasher); - file_signature(self.key_path.as_ref().expect("tls_configured checked")) - .await? - .hash(&mut hasher); - Ok(Some(hasher.finish())) - } - /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection. pub async fn connect( &self, @@ -163,18 +141,6 @@ impl SecurityManager { } } -async fn file_signature(path: &Path) -> Result<(u64, Option)> { - let metadata = tokio::fs::metadata(path) - .await - .map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?; - let modified = metadata.modified().ok().and_then(|t: SystemTime| { - t.duration_since(SystemTime::UNIX_EPOCH) - .ok() - .map(|d| d.as_nanos()) - }); - Ok((metadata.len(), modified)) -} - #[cfg(test)] mod tests { use std::fs::File; @@ -208,8 +174,8 @@ mod tests { assert_eq!(key, vec![2]); } - #[tokio::test] - async fn test_security_reload() { + #[test] + fn test_security_reload() { let temp = tempfile::tempdir().unwrap(); let example_ca = temp.path().join("ca"); let example_cert = temp.path().join("cert"); @@ -223,7 +189,6 @@ mod tests { let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); let first = mgr.load_tls_materials().unwrap(); - let key1 = mgr.connection_cache_key().await.unwrap(); File::create(&example_ca) .unwrap() @@ -239,11 +204,9 @@ mod tests { .unwrap(); let second = mgr.load_tls_materials().unwrap(); - let key2 = mgr.connection_cache_key().await.unwrap(); assert_ne!(first, second); assert_eq!(second.0, vec![9, 9]); assert_eq!(second.1, vec![8, 8, 8]); assert_eq!(second.2, vec![7, 7, 7, 7]); - assert_ne!(key1, key2); } } diff --git a/src/pd/client.rs b/src/pd/client.rs index f98505a6..757991c0 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -214,17 +214,11 @@ pub trait PdClient: Send + Sync + 'static { pub struct PdRpcClient { pd: Arc>, kv_connect: KvC, - kv_client_cache: Arc>>>, + kv_client_cache: Arc>>, enable_codec: bool, region_cache: RegionCache>, } -#[derive(Clone)] -struct CachedKvClient { - cache_key: Option, - client: Client, -} - #[async_trait] impl PdClient for PdRpcClient { type KvClient = KvC::KvClient; @@ -280,7 +274,10 @@ impl PdClient for PdRpcClient { } async fn invalidate_store_cache(&self, store_id: StoreId) { - self.region_cache.invalidate_store_cache(store_id).await + let store = self.region_cache.invalidate_store_cache(store_id).await; + if let Some(store) = store { + self.invalidate_kv_client_cache(&store.address).await; + } } async fn load_keyspace(&self, keyspace: &str) -> Result { @@ -344,31 +341,25 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - let cache_key = self.kv_connect.connection_cache_key().await; - if let Ok(cache_key) = cache_key { - if let Some(cached) = self.kv_client_cache.read().await.get(address) { - if cached.cache_key == cache_key { - return Ok(cached.client.clone()); - } - } + if let Some(cached) = self.kv_client_cache.read().await.get(address) { + return Ok(cached.clone()); } info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - if let Ok(cache_key) = cache_key { - self.kv_client_cache.write().await.insert( - address.to_owned(), - CachedKvClient { - cache_key, - client: client.clone(), - }, - ); - } + self.kv_client_cache.write().await.insert( + address.to_owned(), + client.clone(), + ); Ok(client) } Err(e) => Err(e), } } + + async fn invalidate_kv_client_cache(&self, address: &str) { + self.kv_client_cache.write().await.remove(address); + } } fn make_key_range(start_key: Vec, end_key: Vec) -> kvrpcpb::KeyRange { @@ -397,18 +388,15 @@ pub mod test { async fn test_kv_client_caching() { let client = block_on(pd_rpc_client()); - let addr1 = "foo"; - let addr2 = "bar"; - - let kv1 = client.kv_client(addr1).await.unwrap(); - let kv2 = client.kv_client(addr2).await.unwrap(); - let kv3 = client.kv_client(addr2).await.unwrap(); + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("bar").await.unwrap(); + let kv3 = client.kv_client("bar").await.unwrap(); assert!(kv1.addr != kv2.addr); assert_eq!(kv2.addr, kv3.addr); } #[tokio::test] - async fn test_kv_client_cache_hits_when_key_is_stable() { + async fn test_kv_client_cache_hits_lazily() { #[derive(Clone)] struct CountingConnect { connects: Arc, @@ -424,10 +412,6 @@ pub mod test { client.addr = address.to_owned(); Ok(client) } - - async fn connection_cache_key(&self) -> Result> { - Ok(Some(0)) - } } let connects = Arc::new(AtomicUsize::new(0)); @@ -457,11 +441,10 @@ pub mod test { } #[tokio::test] - async fn test_kv_client_cache_invalidate_on_key_change() { + async fn test_kv_client_cache_reconnects_after_invalidation() { #[derive(Clone)] struct CountingConnect { connects: Arc, - cache_key: Arc, } #[async_trait] @@ -474,21 +457,14 @@ pub mod test { client.addr = address.to_owned(); Ok(client) } - - async fn connection_cache_key(&self) -> Result> { - Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64)) - } } let connects = Arc::new(AtomicUsize::new(0)); - let cache_key = Arc::new(AtomicUsize::new(1)); let connects_clone = connects.clone(); - let cache_key_clone = cache_key.clone(); let client = PdRpcClient::new( Config::default(), move |_| CountingConnect { connects: connects_clone.clone(), - cache_key: cache_key_clone.clone(), }, |sm| async move { Ok(RetryClient::new_with_cluster( @@ -503,7 +479,7 @@ pub mod test { .unwrap(); let kv1 = client.kv_client("foo").await.unwrap(); - cache_key.store(2, Ordering::SeqCst); + client.invalidate_kv_client_cache("foo").await; let kv2 = client.kv_client("foo").await.unwrap(); assert_eq!(kv1.addr, "foo"); assert_eq!(kv2.addr, "foo"); diff --git a/src/region_cache.rs b/src/region_cache.rs index 8837de38..33d1729c 100644 --- a/src/region_cache.rs +++ b/src/region_cache.rs @@ -233,9 +233,9 @@ impl RegionCache { } } - pub async fn invalidate_store_cache(&self, store_id: StoreId) { + pub async fn invalidate_store_cache(&self, store_id: StoreId) -> Option { let mut cache = self.store_cache.write().await; - cache.remove(&store_id); + cache.remove(&store_id) } pub async fn read_through_all_stores(&self) -> Result> { diff --git a/src/request/mod.rs b/src/request/mod.rs index 0cde241b..e521cd3c 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -428,4 +428,178 @@ mod test { .plan(); assert!(plan.execute().await.is_err()); } + + #[tokio::test] + async fn test_grpc_error_invalidates_store_cache() { + #[derive(Debug, Clone)] + struct MockOkResponse; + + impl HasKeyErrors for MockOkResponse { + fn key_errors(&mut self) -> Option> { + None + } + } + + impl HasRegionError for MockOkResponse { + fn region_error(&mut self) -> Option { + None + } + } + + impl HasLocks for MockOkResponse {} + + struct InvalidationTrackingPdClient { + client: MockKvClient, + invalidate_region_count: AtomicUsize, + invalidate_store_count: AtomicUsize, + } + + impl InvalidationTrackingPdClient { + fn region() -> RegionWithLeader { + let mut region = RegionWithLeader::default(); + region.region.id = 1; + region.region.start_key = vec![]; + region.region.end_key = vec![]; + region.region.region_epoch = Some(RegionEpoch { + conf_ver: 0, + version: 0, + }); + region.leader = Some(metapb::Peer { + store_id: 41, + ..Default::default() + }); + region + } + } + + #[async_trait] + impl crate::pd::PdClient for InvalidationTrackingPdClient { + type KvClient = MockKvClient; + + async fn map_region_to_store( + self: Arc, + region: RegionWithLeader, + ) -> Result { + Ok(RegionStore::new(region, Arc::new(self.client.clone()))) + } + + async fn region_for_key(&self, _: &Key) -> Result { + Ok(Self::region()) + } + + async fn region_for_id(&self, id: RegionId) -> Result { + match id { + 1 => Ok(Self::region()), + _ => Err(Error::RegionNotFoundInResponse { region_id: id }), + } + } + + async fn all_stores(&self) -> Result> { + Ok(vec![Store::new(Arc::new(self.client.clone()))]) + } + + async fn get_timestamp(self: Arc) -> Result { + Ok(Timestamp::default()) + } + + async fn update_safepoint(self: Arc, _safepoint: u64) -> Result { + unimplemented!() + } + + async fn load_keyspace(&self, _keyspace: &str) -> Result { + unimplemented!() + } + + async fn update_leader( + &self, + _ver_id: RegionVerId, + _leader: metapb::Peer, + ) -> Result<()> { + Ok(()) + } + + async fn invalidate_region_cache(&self, _ver_id: RegionVerId) { + self.invalidate_region_count.fetch_add(1, Ordering::SeqCst); + } + + async fn invalidate_store_cache(&self, _store_id: StoreId) { + self.invalidate_store_count.fetch_add(1, Ordering::SeqCst); + } + } + + #[derive(Clone)] + struct MockKvRequest; + + #[async_trait] + impl Request for MockKvRequest { + async fn dispatch(&self, _: &TikvClient, _: Duration) -> Result> { + Ok(Box::new(MockOkResponse)) + } + + fn label(&self) -> &'static str { + "mock" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn set_leader(&mut self, _: &RegionWithLeader) -> Result<()> { + Ok(()) + } + + fn set_api_version(&mut self, _: kvrpcpb::ApiVersion) {} + } + + #[async_trait] + impl KvRequest for MockKvRequest { + type Response = MockOkResponse; + } + + impl Shardable for MockKvRequest { + type Shard = Vec>; + + fn shards( + &self, + pd_client: &Arc, + ) -> futures::stream::BoxStream< + 'static, + crate::Result<(Self::Shard, crate::region::RegionWithLeader)>, + > { + region_stream_for_keys( + Some(Key::from("mock_key".to_owned())).into_iter(), + pd_client.clone(), + ) + } + + fn apply_shard(&mut self, _shard: Self::Shard) {} + + fn apply_store(&mut self, _store: &crate::store::RegionStore) -> crate::Result<()> { + Ok(()) + } + } + + let fail_first_dispatch = Arc::new(AtomicBool::new(true)); + let pd_client = Arc::new(InvalidationTrackingPdClient { + client: MockKvClient::with_dispatch_hook(move |_: &dyn Any| { + if fail_first_dispatch.swap(false, Ordering::SeqCst) { + Err(Error::GrpcAPI(tonic::Status::unavailable( + "transient failure", + ))) + } else { + Ok(Box::new(MockOkResponse) as Box) + } + }), + invalidate_region_count: AtomicUsize::new(0), + invalidate_store_count: AtomicUsize::new(0), + }); + + let plan = crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, MockKvRequest) + .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 1)) + .plan(); + let response = plan.execute().await; + assert!(response.is_ok()); + assert_eq!(pd_client.invalidate_region_count.load(Ordering::SeqCst), 1); + assert_eq!(pd_client.invalidate_store_count.load(Ordering::SeqCst), 1); + } } diff --git a/src/store/client.rs b/src/store/client.rs index 5d9c2b58..1c873285 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -20,10 +20,6 @@ pub trait KvConnect: Sized + Send + Sync + 'static { type KvClient: KvClient + Clone + Send + Sync + 'static; async fn connect(&self, address: &str) -> Result; - - async fn connection_cache_key(&self) -> Result> { - Ok(None) - } } #[derive(new, Clone)] @@ -47,10 +43,6 @@ impl KvConnect for TikvConnect { .await .map(|c| KvRpcClient::new(c, self.timeout)) } - - async fn connection_cache_key(&self) -> Result> { - self.security_mgr.connection_cache_key().await - } } #[async_trait] From b8d876f386838359331794d256364896d9215f33 Mon Sep 17 00:00:00 2001 From: Ray Yan Date: Mon, 18 May 2026 11:15:09 +0800 Subject: [PATCH 6/6] polish Signed-off-by: Ray Yan --- src/common/security.rs | 39 ++++++++++++++++++++++----------------- src/pd/client.rs | 8 ++++---- src/request/mod.rs | 3 ++- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/common/security.rs b/src/common/security.rs index bf52ebf7..89d60060 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -97,7 +97,7 @@ impl SecurityManager { } async fn tls_channel(&self, addr: &str) -> Result { - let (ca, cert, key) = self.load_tls_materials()?; + let (ca, cert, key) = self.load_tls_materials().await?; let addr = "https://".to_string() + &SCHEME_REG.replace(addr, ""); let builder = self.endpoint(addr.to_string())?; let tls = ClientTlsConfig::new() @@ -107,25 +107,30 @@ impl SecurityManager { Ok(builder) } - fn load_tls_materials(&self) -> Result<(Vec, Vec, Vec)> { + async fn load_tls_materials(&self) -> Result<(Vec, Vec, Vec)> { let ca_path = self .ca_path - .as_ref() + .clone() .ok_or_else(|| internal_err!("TLS is not configured"))?; let cert_path = self .cert_path - .as_ref() + .clone() .ok_or_else(|| internal_err!("TLS is not configured"))?; let key_path = self .key_path - .as_ref() + .clone() .ok_or_else(|| internal_err!("TLS is not configured"))?; - Ok(( - load_pem_file("ca", ca_path)?, - load_pem_file("certificate", cert_path)?, - load_pem_file("private key", key_path)?, - )) + let materials = + tokio::task::spawn_blocking(move || -> Result<(Vec, Vec, Vec)> { + Ok(( + load_pem_file("ca", &ca_path)?, + load_pem_file("certificate", &cert_path)?, + load_pem_file("private key", &key_path)?, + )) + }) + .await??; + Ok(materials) } async fn default_channel(&self, addr: &str) -> Result { @@ -151,8 +156,8 @@ mod tests { use super::*; - #[test] - fn test_security() { + #[tokio::test] + async fn test_security() { let temp = tempfile::tempdir().unwrap(); let example_ca = temp.path().join("ca"); let example_cert = temp.path().join("cert"); @@ -168,14 +173,14 @@ mod tests { let ca_path: PathBuf = format!("{}", example_ca.display()).into(); let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap(); assert!(mgr.tls_configured()); - let (ca, cert, key) = mgr.load_tls_materials().unwrap(); + let (ca, cert, key) = mgr.load_tls_materials().await.unwrap(); assert_eq!(ca, vec![0]); assert_eq!(cert, vec![1]); assert_eq!(key, vec![2]); } - #[test] - fn test_security_reload() { + #[tokio::test] + async fn test_security_reload() { let temp = tempfile::tempdir().unwrap(); let example_ca = temp.path().join("ca"); let example_cert = temp.path().join("cert"); @@ -188,7 +193,7 @@ mod tests { } let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); - let first = mgr.load_tls_materials().unwrap(); + let first = mgr.load_tls_materials().await.unwrap(); File::create(&example_ca) .unwrap() @@ -203,7 +208,7 @@ mod tests { .write_all(&[7, 7, 7, 7]) .unwrap(); - let second = mgr.load_tls_materials().unwrap(); + let second = mgr.load_tls_materials().await.unwrap(); assert_ne!(first, second); assert_eq!(second.0, vec![9, 9]); assert_eq!(second.1, vec![8, 8, 8]); diff --git a/src/pd/client.rs b/src/pd/client.rs index 757991c0..b8965875 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -347,10 +347,10 @@ impl PdRpcClient { info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - self.kv_client_cache.write().await.insert( - address.to_owned(), - client.clone(), - ); + self.kv_client_cache + .write() + .await + .insert(address.to_owned(), client.clone()); Ok(client) } Err(e) => Err(e), diff --git a/src/request/mod.rs b/src/request/mod.rs index e521cd3c..4b8415c8 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -594,7 +594,8 @@ mod test { invalidate_store_count: AtomicUsize::new(0), }); - let plan = crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, MockKvRequest) + let plan = + crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, MockKvRequest) .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 1)) .plan(); let response = plan.execute().await;