Skip to content

Commit 1ad18eb

Browse files
committed
implement client builder for native-tls
1 parent cfc155a commit 1ad18eb

6 files changed

Lines changed: 160 additions & 28 deletions

File tree

bitreq/src/client.rs

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,47 @@
99
use std::collections::{hash_map, HashMap, VecDeque};
1010
use std::sync::{Arc, Mutex};
1111

12-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
12+
#[cfg(any(
13+
all(feature = "native-tls", feature = "tokio-native-tls"),
14+
all(feature = "rustls", feature = "tokio-rustls")
15+
))]
1316
use crate::connection::certificates::{Certificates, CertificatesBuilder};
1417
use crate::connection::AsyncConnection;
1518
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
1619
use crate::{Error, Request, Response};
1720

1821
#[derive(Clone)]
1922
pub(crate) struct ClientConfig {
20-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
23+
#[cfg(any(
24+
all(feature = "native-tls", feature = "tokio-native-tls"),
25+
all(feature = "rustls", feature = "tokio-rustls")
26+
))]
2127
pub(crate) tls: Option<TlsConfig>,
2228
}
2329

24-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
30+
#[cfg(any(
31+
all(feature = "native-tls", feature = "tokio-native-tls"),
32+
all(feature = "rustls", feature = "tokio-rustls")
33+
))]
2534
#[derive(Clone)]
2635
pub(crate) struct TlsConfig {
2736
pub(crate) certificates: Certificates,
2837
}
2938

30-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
39+
#[cfg(any(
40+
all(feature = "native-tls", feature = "tokio-native-tls"),
41+
all(feature = "rustls", feature = "tokio-rustls")
42+
))]
3143
impl TlsConfig {
3244
fn new(certificates: Certificates) -> Self { Self { certificates } }
3345
}
3446

3547
pub struct ClientBuilder {
3648
capacity: usize,
37-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
49+
#[cfg(any(
50+
all(feature = "native-tls", feature = "tokio-native-tls"),
51+
all(feature = "rustls", feature = "tokio-rustls")
52+
))]
3853
certificates: Option<CertificatesBuilder>,
3954
}
4055

@@ -56,11 +71,17 @@ pub struct ClientBuilder {
5671
/// ```
5772
impl ClientBuilder {
5873
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
59-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
74+
#[cfg(any(
75+
all(feature = "native-tls", feature = "tokio-native-tls"),
76+
all(feature = "rustls", feature = "tokio-rustls")
77+
))]
6078
pub fn new() -> Self { Self { capacity: 10, certificates: None } }
6179

6280
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
63-
#[cfg(not(all(feature = "rustls", feature = "tokio-rustls")))]
81+
#[cfg(not(any(
82+
all(feature = "native-tls", feature = "tokio-native-tls"),
83+
all(feature = "rustls", feature = "tokio-rustls")
84+
)))]
6485
pub fn new() -> Self { Self { capacity: 10 } }
6586

6687
/// Sets the maximum number of connections to keep in the pool.
@@ -69,7 +90,10 @@ impl ClientBuilder {
6990
self
7091
}
7192

72-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
93+
#[cfg(any(
94+
all(feature = "native-tls", feature = "tokio-native-tls"),
95+
all(feature = "rustls", feature = "tokio-rustls")
96+
))]
7397
/// Builds the `Client` with the configured settings.
7498
pub fn build(self) -> Result<Client, Error> {
7599
let build_config = if let Some(builder) = self.certificates {
@@ -92,7 +116,10 @@ impl ClientBuilder {
92116
}
93117

94118
/// Builds the `Client` with the configured settings.
95-
#[cfg(not(any(all(feature = "rustls", feature = "tokio-rustls"))))]
119+
#[cfg(not(any(
120+
all(feature = "native-tls", feature = "tokio-native-tls"),
121+
all(feature = "rustls", feature = "tokio-rustls")
122+
)))]
96123
pub fn build(self) -> Result<Client, Error> {
97124
Ok(Client {
98125
r#async: Arc::new(Mutex::new(ClientImpl {
@@ -122,7 +149,10 @@ impl ClientBuilder {
122149
/// # Ok(())
123150
/// # }
124151
/// ```
125-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
152+
#[cfg(any(
153+
all(feature = "native-tls", feature = "tokio-native-tls"),
154+
all(feature = "rustls", feature = "tokio-rustls")
155+
))]
126156
pub fn with_root_certificate<T: Into<Vec<u8>>>(mut self, cert_der: T) -> Result<Self, Error> {
127157
let cert_der = cert_der.into();
128158
if let Some(ref mut certificates) = self.certificates {
@@ -137,7 +167,10 @@ impl ClientBuilder {
137167

138168
/// Disables default root certificates for TLS connections.
139169
/// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured.
140-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
170+
#[cfg(any(
171+
all(feature = "native-tls", feature = "tokio-native-tls"),
172+
all(feature = "rustls", feature = "tokio-rustls")
173+
))]
141174
pub fn disable_default_certificates(mut self) -> Result<Self, Error> {
142175
match self.certificates {
143176
Some(ref mut certificates) => certificates.disable_default()?,

bitreq/src/connection.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ use crate::{Error, Method, ResponseLazy};
3131

3232
type UnsecuredStream = TcpStream;
3333

34-
35-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
34+
#[cfg(any(
35+
all(feature = "native-tls", feature = "tokio-native-tls"),
36+
all(feature = "rustls", feature = "tokio-rustls")
37+
))]
3638
pub(crate) mod certificates;
3739
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3840
mod rustls_stream;

bitreq/src/connection/certificates.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#[cfg(any(feature = "rustls", feature = "native-tls"))]
22
use std::sync::Arc;
33

4+
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
5+
use native_tls::{Certificate, TlsConnector, TlsConnectorBuilder};
46
#[cfg(feature = "rustls")]
57
use rustls::RootCertStore;
8+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
9+
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
610
#[cfg(feature = "rustls-webpki")]
711
use webpki_roots::TLS_SERVER_ROOTS;
812

@@ -14,6 +18,11 @@ pub(crate) struct CertificatesBuilder {
1418
pub(crate) disable_default: bool,
1519
}
1620

21+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
22+
pub(crate) struct CertificatesBuilder {
23+
pub(crate) inner: TlsConnectorBuilder,
24+
}
25+
1726
impl CertificatesBuilder {
1827
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
1928
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
@@ -26,13 +35,41 @@ impl CertificatesBuilder {
2635
Ok(certificates)
2736
}
2837

38+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
39+
pub(crate) fn new(cert_der: Option<Vec<u8>>) -> Result<Self, Error> {
40+
let builder = TlsConnector::builder();
41+
let mut certificates = Self { inner: builder };
42+
43+
if let Some(cert_der) = cert_der {
44+
certificates.append_certificate(cert_der)?;
45+
}
46+
47+
Ok(certificates)
48+
}
49+
2950
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
3051
pub(crate) fn append_certificate(&mut self, cert_der: Vec<u8>) -> Result<&mut Self, Error> {
3152
self.inner.add(&rustls::Certificate(cert_der)).map_err(Error::RustlsAppendCert)?;
3253

3354
Ok(self)
3455
}
3556

57+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
58+
pub(crate) fn append_certificate(&mut self, cert_der: Vec<u8>) -> Result<&mut Self, Error> {
59+
let certificate = Certificate::from_der(&cert_der)?;
60+
self.inner.add_root_certificate(certificate);
61+
62+
Ok(self)
63+
}
64+
65+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
66+
pub(crate) fn build(self) -> Result<Certificates, Error> {
67+
let connector = self.inner.build()?;
68+
let async_connector = AsyncTlsConnector::from(connector);
69+
70+
Ok(Certificates(Arc::new(async_connector)))
71+
}
72+
3673
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
3774
pub(crate) fn build(mut self) -> Result<Certificates, Error> {
3875
if !self.disable_default {
@@ -74,8 +111,18 @@ impl CertificatesBuilder {
74111
self.disable_default = true;
75112
Ok(self)
76113
}
114+
115+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
116+
pub(crate) fn disable_default(&mut self) -> Result<&mut Self, Error> {
117+
self.inner.disable_built_in_roots(true);
118+
Ok(self)
119+
}
77120
}
78121

79122
#[derive(Clone)]
80123
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
81124
pub(crate) struct Certificates(pub(crate) Arc<RootCertStore>);
125+
126+
#[derive(Clone)]
127+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
128+
pub(crate) struct Certificates(pub(crate) Arc<AsyncTlsConnector>);

bitreq/src/connection/rustls_stream.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ use super::HttpStream;
2828
all(feature = "rustls", feature = "tokio-rustls")
2929
))]
3030
use super::{AsyncHttpStream, AsyncTcpStream};
31-
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
31+
#[cfg(any(
32+
all(feature = "native-tls", feature = "tokio-native-tls"),
33+
all(feature = "rustls", feature = "tokio-rustls")
34+
))]
3235
use crate::connection::certificates::Certificates;
3336
use crate::Error;
3437

@@ -217,21 +220,21 @@ pub(super) async fn wrap_async_stream(
217220
Ok(AsyncHttpStream::Secured(Box::new(tls)))
218221
}
219222

220-
// #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
221-
// pub(super) async fn wrap_async_stream_with_configs(
222-
// tcp: AsyncTcpStream,
223-
// host: &str,
224-
// client_configs: Certificates,
225-
// ) -> Result<AsyncHttpStream, Error> {
226-
// #[cfg(feature = "log")]
227-
// log::trace!("Setting up TLS parameters for {host}.");
223+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
224+
pub(super) async fn wrap_async_stream_with_configs(
225+
tcp: AsyncTcpStream,
226+
host: &str,
227+
client_configs: Certificates,
228+
) -> Result<AsyncHttpStream, Error> {
229+
#[cfg(feature = "log")]
230+
log::trace!("Setting up TLS parameters for {host}.");
228231

229-
// let async_connector = client_configs.0;
232+
let async_connector = client_configs.0;
230233

231-
// #[cfg(feature = "log")]
232-
// log::trace!("Establishing TLS session to {host}.");
234+
#[cfg(feature = "log")]
235+
log::trace!("Establishing TLS session to {host}.");
233236

234-
// let tls = async_connector.connect(host, tcp).await?;
237+
let tls = async_connector.connect(host, tcp).await?;
235238

236-
// Ok(AsyncHttpStream::Secured(Box::new(tls)))
237-
// }
239+
Ok(AsyncHttpStream::Secured(Box::new(tls)))
240+
}

bitreq/src/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ pub enum Error {
2828
#[cfg(feature = "native-tls")]
2929
/// Ran into a native-tls error while creating the connection.
3030
NativeTlsCreateConnection(native_tls::Error),
31+
#[cfg(feature = "native-tls")]
32+
/// Ran into a native-tls error while appending a certificate.
33+
NativeTlsAppendCert,
3134
#[cfg(any(feature = "rustls", feature = "native-tls"))]
3235
/// The current TLS configuration is invalid.
3336
InvalidTlsConfig,
@@ -114,6 +117,8 @@ impl fmt::Display for Error {
114117
RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err),
115118
#[cfg(feature = "native-tls")]
116119
NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err),
120+
#[cfg(feature = "native-tls")]
121+
NativeTlsAppendCert => write!(f, "error appending certificate"),
117122
#[cfg(any(feature = "rustls", feature = "native-tls"))]
118123
InvalidTlsConfig => write!(f, "error disabling default certificates. Must have custom cert."),
119124
MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"),

bitreq/tests/main.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ async fn test_https_with_client_builder() {
5252
assert_eq!(response.status_code, 200);
5353
}
5454

55+
#[tokio::test]
56+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
57+
async fn test_https_with_client_builder() {
58+
setup();
59+
let client = bitreq::Client::builder().build().unwrap();
60+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
61+
assert_eq!(response.status_code, 200);
62+
}
63+
5564
#[tokio::test]
5665
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
5766
async fn test_https_with_client_builder_and_cert() {
@@ -66,6 +75,39 @@ async fn test_https_with_client_builder_and_cert() {
6675
assert_eq!(response.status_code, 200);
6776
}
6877

78+
#[tokio::test]
79+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
80+
async fn test_https_with_client_builder_and_cert() {
81+
setup();
82+
let cert_der = include_bytes!("test_cert.der");
83+
let client = bitreq::Client::builder()
84+
.with_root_certificate(cert_der.as_slice())
85+
.unwrap()
86+
.build()
87+
.unwrap();
88+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
89+
assert_eq!(response.status_code, 200);
90+
}
91+
92+
#[tokio::test]
93+
#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))]
94+
async fn test_https_with_multiple_certs() {
95+
setup();
96+
let cert_der = include_bytes!("test_cert.der");
97+
let ca_der = include_bytes!("ca_cert.der");
98+
99+
let client = bitreq::Client::builder()
100+
.with_root_certificate(cert_der.as_slice())
101+
.unwrap()
102+
.with_root_certificate(ca_der.as_slice())
103+
.unwrap()
104+
.build()
105+
.unwrap();
106+
107+
let response = client.send_async(bitreq::get("https://example.com")).await.unwrap();
108+
assert_eq!(response.status_code, 200);
109+
}
110+
69111
#[tokio::test]
70112
#[cfg(all(feature = "rustls", feature = "tokio-rustls"))]
71113
async fn test_https_with_multiple_certs() {

0 commit comments

Comments
 (0)