diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..4a6a1abd --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[resolver] +incompatible-rust-versions = "fallback" diff --git a/src/common/security.rs b/src/common/security.rs index 89e074b3..89d60060 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -43,12 +43,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result> { /// Manages the TLS protocol #[derive(Default)] pub struct SecurityManager { - /// The PEM encoding of the server’s CA certificates. - ca: Vec, - /// The PEM encoding of the server’s certificate chain. - cert: Vec, + /// The path to the PEM encoding of the server’s CA certificates. + ca_path: Option, + /// The path to the PEM encoding of the server’s certificate chain. + cert_path: Option, /// The path to the file that contains the PEM encoding of the server’s private key. - key: PathBuf, + key_path: Option, } impl SecurityManager { @@ -58,15 +58,23 @@ impl SecurityManager { cert_path: impl AsRef, key_path: impl Into, ) -> Result { + let ca_path = ca_path.as_ref().to_path_buf(); + let cert_path = cert_path.as_ref().to_path_buf(); let key_path = key_path.into(); + check_pem_file("ca", &ca_path)?; + check_pem_file("certificate", &cert_path)?; check_pem_file("private key", &key_path)?; Ok(SecurityManager { - ca: load_pem_file("ca", ca_path.as_ref())?, - cert: load_pem_file("certificate", cert_path.as_ref())?, - key: key_path, + ca_path: Some(ca_path), + cert_path: Some(cert_path), + key_path: Some(key_path), }) } + pub(crate) fn tls_configured(&self) -> bool { + self.ca_path.is_some() + } + /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection. pub async fn connect( &self, @@ -78,7 +86,7 @@ impl SecurityManager { Factory: FnOnce(Channel) -> Client, { info!("connect to rpc server at endpoint: {:?}", addr); - let channel = if !self.ca.is_empty() { + let channel = if self.tls_configured() { self.tls_channel(addr).await? } else { self.default_channel(addr).await? @@ -89,18 +97,42 @@ impl SecurityManager { } async fn tls_channel(&self, addr: &str) -> Result { + let (ca, cert, key) = self.load_tls_materials().await?; let addr = "https://".to_string() + &SCHEME_REG.replace(addr, ""); let builder = self.endpoint(addr.to_string())?; let tls = ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(&self.ca)) - .identity(Identity::from_pem( - &self.cert, - load_pem_file("private key", &self.key)?, - )); + .ca_certificate(Certificate::from_pem(ca)) + .identity(Identity::from_pem(cert, key)); let builder = builder.tls_config(tls)?; Ok(builder) } + async fn load_tls_materials(&self) -> Result<(Vec, Vec, Vec)> { + let ca_path = self + .ca_path + .clone() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let cert_path = self + .cert_path + .clone() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let key_path = self + .key_path + .clone() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + + let materials = + tokio::task::spawn_blocking(move || -> Result<(Vec, Vec, Vec)> { + Ok(( + load_pem_file("ca", &ca_path)?, + load_pem_file("certificate", &cert_path)?, + load_pem_file("private key", &key_path)?, + )) + }) + .await??; + Ok(materials) + } + async fn default_channel(&self, addr: &str) -> Result { let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); self.endpoint(addr) @@ -124,8 +156,8 @@ mod tests { use super::*; - #[test] - fn test_security() { + #[tokio::test] + async fn test_security() { let temp = tempfile::tempdir().unwrap(); let example_ca = temp.path().join("ca"); let example_cert = temp.path().join("cert"); @@ -140,9 +172,46 @@ mod tests { let key_path: PathBuf = format!("{}", example_pem.display()).into(); let ca_path: PathBuf = format!("{}", example_ca.display()).into(); let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap(); - assert_eq!(mgr.ca, vec![0]); - assert_eq!(mgr.cert, vec![1]); - let key = load_pem_file("private key", &key_path).unwrap(); + assert!(mgr.tls_configured()); + let (ca, cert, key) = mgr.load_tls_materials().await.unwrap(); + assert_eq!(ca, vec![0]); + assert_eq!(cert, vec![1]); assert_eq!(key, vec![2]); } + + #[tokio::test] + async fn test_security_reload() { + let temp = tempfile::tempdir().unwrap(); + let example_ca = temp.path().join("ca"); + let example_cert = temp.path().join("cert"); + let example_pem = temp.path().join("key"); + for (id, f) in [&example_ca, &example_cert, &example_pem] + .iter() + .enumerate() + { + File::create(f).unwrap().write_all(&[id as u8]).unwrap(); + } + + let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); + let first = mgr.load_tls_materials().await.unwrap(); + + File::create(&example_ca) + .unwrap() + .write_all(&[9, 9]) + .unwrap(); + File::create(&example_cert) + .unwrap() + .write_all(&[8, 8, 8]) + .unwrap(); + File::create(&example_pem) + .unwrap() + .write_all(&[7, 7, 7, 7]) + .unwrap(); + + let second = mgr.load_tls_materials().await.unwrap(); + assert_ne!(first, second); + assert_eq!(second.0, vec![9, 9]); + assert_eq!(second.1, vec![8, 8, 8]); + assert_eq!(second.2, vec![7, 7, 7, 7]); + } } diff --git a/src/pd/client.rs b/src/pd/client.rs index 05b9c07c..b8965875 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -274,7 +274,10 @@ impl PdClient for PdRpcClient { } async fn invalidate_store_cache(&self, store_id: StoreId) { - self.region_cache.invalidate_store_cache(store_id).await + let store = self.region_cache.invalidate_store_cache(store_id).await; + if let Some(store) = store { + self.invalidate_kv_client_cache(&store.address).await; + } } async fn load_keyspace(&self, keyspace: &str) -> Result { @@ -338,9 +341,9 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - if let Some(client) = self.kv_client_cache.read().await.get(address) { - return Ok(client.clone()); - }; + if let Some(cached) = self.kv_client_cache.read().await.get(address) { + return Ok(cached.clone()); + } info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { @@ -353,6 +356,10 @@ impl PdRpcClient { Err(e) => Err(e), } } + + async fn invalidate_kv_client_cache(&self, address: &str) { + self.kv_client_cache.write().await.remove(address); + } } fn make_key_range(start_key: Vec, end_key: Vec) -> kvrpcpb::KeyRange { @@ -364,26 +371,121 @@ fn make_key_range(start_key: Vec, end_key: Vec) -> kvrpcpb::KeyRange { #[cfg(test)] pub mod test { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use async_trait::async_trait; use futures::executor; use futures::executor::block_on; use super::*; use crate::mock::*; + use crate::pd::RetryClient; + use crate::store::KvConnect; + use crate::Config; #[tokio::test] async fn test_kv_client_caching() { let client = block_on(pd_rpc_client()); - let addr1 = "foo"; - let addr2 = "bar"; - - let kv1 = client.kv_client(addr1).await.unwrap(); - let kv2 = client.kv_client(addr2).await.unwrap(); - let kv3 = client.kv_client(addr2).await.unwrap(); + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("bar").await.unwrap(); + let kv3 = client.kv_client("bar").await.unwrap(); assert!(kv1.addr != kv2.addr); assert_eq!(kv2.addr, kv3.addr); } + #[tokio::test] + async fn test_kv_client_cache_hits_lazily() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let connects_clone = connects.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_kv_client_cache_reconnects_after_invalidation() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let connects_clone = connects.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + client.invalidate_kv_client_cache("foo").await; + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 2); + } + #[test] fn test_group_keys_by_region() { let client = MockPdClient::default(); diff --git a/src/region_cache.rs b/src/region_cache.rs index 8837de38..33d1729c 100644 --- a/src/region_cache.rs +++ b/src/region_cache.rs @@ -233,9 +233,9 @@ impl RegionCache { } } - pub async fn invalidate_store_cache(&self, store_id: StoreId) { + pub async fn invalidate_store_cache(&self, store_id: StoreId) -> Option { let mut cache = self.store_cache.write().await; - cache.remove(&store_id); + cache.remove(&store_id) } pub async fn read_through_all_stores(&self) -> Result> { diff --git a/src/request/mod.rs b/src/request/mod.rs index c8fd07be..4b8415c8 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -90,21 +90,25 @@ impl RetryOptions { mod test { use std::any::Any; use std::iter; - use std::sync::atomic::AtomicUsize; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; + use async_trait::async_trait; use tonic::transport::Channel; use super::*; use crate::mock::MockKvClient; use crate::mock::MockPdClient; + use crate::proto::keyspacepb; use crate::proto::kvrpcpb; + use crate::proto::metapb::{self, RegionEpoch}; use crate::proto::pdpb::Timestamp; use crate::proto::tikvpb::tikv_client::TikvClient; - use crate::region::RegionWithLeader; + use crate::region::{RegionId, RegionVerId, RegionWithLeader, StoreId}; use crate::store::region_stream_for_keys; use crate::store::HasRegionError; + use crate::store::{RegionStore, Store}; use crate::transaction::lowering::new_commit_request; use crate::Error; use crate::Key; @@ -206,6 +210,196 @@ mod test { assert_eq!(invoking_count.load(std::sync::atomic::Ordering::SeqCst), 4); } + #[tokio::test] + async fn test_region_store_mapping_retry() { + #[derive(Debug, Clone)] + struct MockOkResponse; + + impl HasKeyErrors for MockOkResponse { + fn key_errors(&mut self) -> Option> { + None + } + } + + impl HasRegionError for MockOkResponse { + fn region_error(&mut self) -> Option { + None + } + } + + impl HasLocks for MockOkResponse {} + + struct FlakyStoreMappingPdClient { + client: MockKvClient, + invalidated: AtomicBool, + invalidation_count: AtomicUsize, + } + + impl FlakyStoreMappingPdClient { + fn region(store_id: StoreId) -> RegionWithLeader { + let mut region = RegionWithLeader::default(); + region.region.id = 1; + region.region.start_key = vec![]; + region.region.end_key = vec![]; + region.region.region_epoch = Some(RegionEpoch { + conf_ver: 0, + version: 0, + }); + region.leader = Some(metapb::Peer { + store_id, + ..Default::default() + }); + region + } + } + + #[async_trait] + impl crate::pd::PdClient for FlakyStoreMappingPdClient { + type KvClient = MockKvClient; + + async fn map_region_to_store( + self: Arc, + region: RegionWithLeader, + ) -> Result { + match region.get_store_id()? { + 41 => Err(Error::InternalError { + message: "invalid store ID 41, not found".to_owned(), + }), + _ => Ok(RegionStore::new(region, Arc::new(self.client.clone()))), + } + } + + async fn region_for_key(&self, _: &Key) -> Result { + let store_id = if self.invalidated.load(Ordering::SeqCst) { + 42 + } else { + 41 + }; + Ok(Self::region(store_id)) + } + + async fn region_for_id(&self, id: RegionId) -> Result { + match id { + 1 => self.region_for_key(&Key::EMPTY).await, + _ => Err(Error::RegionNotFoundInResponse { region_id: id }), + } + } + + async fn all_stores(&self) -> Result> { + Ok(vec![Store::new(Arc::new(self.client.clone()))]) + } + + async fn get_timestamp(self: Arc) -> Result { + Ok(Timestamp::default()) + } + + async fn update_safepoint(self: Arc, _safepoint: u64) -> Result { + unimplemented!() + } + + async fn load_keyspace(&self, _keyspace: &str) -> Result { + unimplemented!() + } + + async fn update_leader( + &self, + _ver_id: RegionVerId, + _leader: metapb::Peer, + ) -> Result<()> { + Ok(()) + } + + async fn invalidate_region_cache(&self, _ver_id: RegionVerId) { + self.invalidated.store(true, Ordering::SeqCst); + self.invalidation_count.fetch_add(1, Ordering::SeqCst); + } + + async fn invalidate_store_cache(&self, _store_id: StoreId) {} + } + + #[derive(Clone)] + struct MockKvRequest { + shard_invoking_count: Arc, + } + + #[async_trait] + impl Request for MockKvRequest { + async fn dispatch(&self, _: &TikvClient, _: Duration) -> Result> { + Ok(Box::new(MockOkResponse)) + } + + fn label(&self) -> &'static str { + "mock" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn set_leader(&mut self, _: &RegionWithLeader) -> Result<()> { + Ok(()) + } + + fn set_api_version(&mut self, _: kvrpcpb::ApiVersion) {} + } + + #[async_trait] + impl KvRequest for MockKvRequest { + type Response = MockOkResponse; + } + + impl Shardable for MockKvRequest { + type Shard = Vec>; + + fn shards( + &self, + pd_client: &Arc, + ) -> futures::stream::BoxStream< + 'static, + crate::Result<(Self::Shard, crate::region::RegionWithLeader)>, + > { + self.shard_invoking_count.fetch_add(1, Ordering::SeqCst); + region_stream_for_keys( + Some(Key::from("mock_key".to_owned())).into_iter(), + pd_client.clone(), + ) + } + + fn apply_shard(&mut self, _shard: Self::Shard) {} + + fn apply_store(&mut self, _store: &crate::store::RegionStore) -> crate::Result<()> { + Ok(()) + } + } + + let dispatch_count = Arc::new(AtomicUsize::new(0)); + let shard_invoking_count = Arc::new(AtomicUsize::new(0)); + let dispatch_count_clone = dispatch_count.clone(); + + let pd_client = Arc::new(FlakyStoreMappingPdClient { + client: MockKvClient::with_dispatch_hook(move |_: &dyn Any| { + dispatch_count_clone.fetch_add(1, Ordering::SeqCst); + Ok(Box::new(MockOkResponse) as Box) + }), + invalidated: AtomicBool::new(false), + invalidation_count: AtomicUsize::new(0), + }); + + let request = MockKvRequest { + shard_invoking_count: shard_invoking_count.clone(), + }; + + let plan = crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, request) + .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 3)) + .plan(); + + let response = plan.execute().await; + assert!(response.is_ok()); + assert_eq!(dispatch_count.load(Ordering::SeqCst), 1); + assert_eq!(shard_invoking_count.load(Ordering::SeqCst), 2); + assert_eq!(pd_client.invalidation_count.load(Ordering::SeqCst), 1); + } + #[tokio::test] async fn test_extract_error() { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( @@ -234,4 +428,179 @@ mod test { .plan(); assert!(plan.execute().await.is_err()); } + + #[tokio::test] + async fn test_grpc_error_invalidates_store_cache() { + #[derive(Debug, Clone)] + struct MockOkResponse; + + impl HasKeyErrors for MockOkResponse { + fn key_errors(&mut self) -> Option> { + None + } + } + + impl HasRegionError for MockOkResponse { + fn region_error(&mut self) -> Option { + None + } + } + + impl HasLocks for MockOkResponse {} + + struct InvalidationTrackingPdClient { + client: MockKvClient, + invalidate_region_count: AtomicUsize, + invalidate_store_count: AtomicUsize, + } + + impl InvalidationTrackingPdClient { + fn region() -> RegionWithLeader { + let mut region = RegionWithLeader::default(); + region.region.id = 1; + region.region.start_key = vec![]; + region.region.end_key = vec![]; + region.region.region_epoch = Some(RegionEpoch { + conf_ver: 0, + version: 0, + }); + region.leader = Some(metapb::Peer { + store_id: 41, + ..Default::default() + }); + region + } + } + + #[async_trait] + impl crate::pd::PdClient for InvalidationTrackingPdClient { + type KvClient = MockKvClient; + + async fn map_region_to_store( + self: Arc, + region: RegionWithLeader, + ) -> Result { + Ok(RegionStore::new(region, Arc::new(self.client.clone()))) + } + + async fn region_for_key(&self, _: &Key) -> Result { + Ok(Self::region()) + } + + async fn region_for_id(&self, id: RegionId) -> Result { + match id { + 1 => Ok(Self::region()), + _ => Err(Error::RegionNotFoundInResponse { region_id: id }), + } + } + + async fn all_stores(&self) -> Result> { + Ok(vec![Store::new(Arc::new(self.client.clone()))]) + } + + async fn get_timestamp(self: Arc) -> Result { + Ok(Timestamp::default()) + } + + async fn update_safepoint(self: Arc, _safepoint: u64) -> Result { + unimplemented!() + } + + async fn load_keyspace(&self, _keyspace: &str) -> Result { + unimplemented!() + } + + async fn update_leader( + &self, + _ver_id: RegionVerId, + _leader: metapb::Peer, + ) -> Result<()> { + Ok(()) + } + + async fn invalidate_region_cache(&self, _ver_id: RegionVerId) { + self.invalidate_region_count.fetch_add(1, Ordering::SeqCst); + } + + async fn invalidate_store_cache(&self, _store_id: StoreId) { + self.invalidate_store_count.fetch_add(1, Ordering::SeqCst); + } + } + + #[derive(Clone)] + struct MockKvRequest; + + #[async_trait] + impl Request for MockKvRequest { + async fn dispatch(&self, _: &TikvClient, _: Duration) -> Result> { + Ok(Box::new(MockOkResponse)) + } + + fn label(&self) -> &'static str { + "mock" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn set_leader(&mut self, _: &RegionWithLeader) -> Result<()> { + Ok(()) + } + + fn set_api_version(&mut self, _: kvrpcpb::ApiVersion) {} + } + + #[async_trait] + impl KvRequest for MockKvRequest { + type Response = MockOkResponse; + } + + impl Shardable for MockKvRequest { + type Shard = Vec>; + + fn shards( + &self, + pd_client: &Arc, + ) -> futures::stream::BoxStream< + 'static, + crate::Result<(Self::Shard, crate::region::RegionWithLeader)>, + > { + region_stream_for_keys( + Some(Key::from("mock_key".to_owned())).into_iter(), + pd_client.clone(), + ) + } + + fn apply_shard(&mut self, _shard: Self::Shard) {} + + fn apply_store(&mut self, _store: &crate::store::RegionStore) -> crate::Result<()> { + Ok(()) + } + } + + let fail_first_dispatch = Arc::new(AtomicBool::new(true)); + let pd_client = Arc::new(InvalidationTrackingPdClient { + client: MockKvClient::with_dispatch_hook(move |_: &dyn Any| { + if fail_first_dispatch.swap(false, Ordering::SeqCst) { + Err(Error::GrpcAPI(tonic::Status::unavailable( + "transient failure", + ))) + } else { + Ok(Box::new(MockOkResponse) as Box) + } + }), + invalidate_region_count: AtomicUsize::new(0), + invalidate_store_count: AtomicUsize::new(0), + }); + + let plan = + crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, MockKvRequest) + .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 1)) + .plan(); + let response = plan.execute().await; + assert!(response.is_ok()); + assert_eq!(pd_client.invalidate_region_count.load(Ordering::SeqCst), 1); + assert_eq!(pd_client.invalidate_store_count.load(Ordering::SeqCst), 1); + } } diff --git a/src/request/plan.rs b/src/request/plan.rs index 8bd15bb5..c8acb5da 100644 --- a/src/request/plan.rs +++ b/src/request/plan.rs @@ -163,6 +163,8 @@ where preserve_region_results: bool, ) -> Result<::Result> { debug!("single_shard_handler"); + let region_ver_id = region.ver_id(); + let store_id = region.get_store_id().ok(); let region_store = match pd_client .clone() .map_region_to_store(region) @@ -172,27 +174,20 @@ where Ok(region_store) }) { Ok(region_store) => region_store, - Err(Error::LeaderNotFound { region }) => { - debug!( - "single_shard_handler::sharding: leader not found: {:?}", - region - ); + Err(err) => { + debug!("single_shard_handler::sharding, error: {:?}", err); return Self::handle_other_error( pd_client, plan, - region.clone(), - None, + region_ver_id, + store_id, backoff, permits, preserve_region_results, - Error::LeaderNotFound { region }, + err, ) .await; } - Err(err) => { - debug!("single_shard_handler::sharding, error: {:?}", err); - return Err(err); - } }; // limit concurrent requests