Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ jobs:
with:
components: rustfmt, clippy
- uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt-get install protobuf-compiler
- name: Install Protoc
uses: arduino/setup-protoc@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Check code formatting
run: cargo fmt --all -- --check
- name: Check cargo clippy warnings
Expand Down
28 changes: 22 additions & 6 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,42 @@
use tonic::metadata::MetadataKey;
use tonic::service::Interceptor;
use tonic::{Request, Status};

pub struct TokenInterceptor {
/// Header name used for API key / token authentication.
pub const API_KEY_HEADER: &str = "api-key";

pub struct MetadataInterceptor {
api_key: Option<String>,
custom_headers: Vec<(String, String)>,
}

impl TokenInterceptor {
pub fn new(api_key: Option<String>) -> Self {
Self { api_key }
impl MetadataInterceptor {
pub fn new(api_key: Option<String>, custom_headers: Vec<(String, String)>) -> Self {
Self {
api_key,
custom_headers,
}
}
}

impl Interceptor for TokenInterceptor {
impl Interceptor for MetadataInterceptor {
fn call(&mut self, mut req: Request<()>) -> anyhow::Result<Request<()>, Status> {
if let Some(api_key) = &self.api_key {
req.metadata_mut().insert(
"api-key",
API_KEY_HEADER,
api_key.parse().map_err(|_| {
Status::invalid_argument(format!("Malformed API key or token: {api_key}"))
})?,
);
}
for (key, value) in &self.custom_headers {
let key = MetadataKey::from_bytes(key.as_bytes())
.map_err(|_| Status::invalid_argument(format!("Malformed header name: {key}")))?;
let value = value.parse().map_err(|_| {
Status::invalid_argument(format!("Malformed header value for {key}: {value}"))
})?;
req.metadata_mut().insert(key, value);
}
Ok(req)
}
}
4 changes: 2 additions & 2 deletions src/qdrant_client/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService;
use tonic::transport::Channel;
use tonic::Status;

use crate::auth::TokenInterceptor;
use crate::auth::MetadataInterceptor;
use crate::qdrant::collections_client::CollectionsClient;
use crate::qdrant::{
alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest,
Expand All @@ -25,7 +25,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
impl Qdrant {
pub(super) async fn with_collections_client<T, O: Future<Output = Result<T, Status>>>(
&self,
f: impl Fn(CollectionsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
f: impl Fn(CollectionsClient<InterceptedService<Channel, MetadataInterceptor>>) -> O,
) -> QdrantResult<T> {
let result = self
.channel
Expand Down
25 changes: 25 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ pub struct QdrantConfig {
/// Amount of concurrent connections.
/// If set to 0 or 1, connection pools will be disabled.
pub pool_size: usize,

/// Optional custom headers to send with every request (both gRPC and REST).
pub custom_headers: Vec<(String, String)>,
}

impl QdrantConfig {
Expand All @@ -56,10 +59,31 @@ impl QdrantConfig {
pub fn from_url(url: &str) -> Self {
QdrantConfig {
uri: url.to_string(),
custom_headers: Vec::new(),
..Self::default()
}
}

/// Add a custom header to send with every request.
///
/// Can be called multiple times to add multiple headers. The same header name can be
/// set multiple times; all values will be sent.
///
/// # Examples
///
/// ```rust,no_run
/// use qdrant_client::Qdrant;
///
/// let client = Qdrant::from_url("http://localhost:6334")
/// .header("x-custom-id", "my-client")
/// .header("x-request-source", "batch-job")
/// .build();
/// ```
pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.custom_headers.push((key.into(), value.into()));
self
}

/// Set an optional API key
///
/// # Examples
Expand Down Expand Up @@ -204,6 +228,7 @@ impl Default for QdrantConfig {
compression: None,
check_compatibility: true,
pool_size: 3,
custom_headers: Vec::new(),
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tonic::codegen::InterceptedService;
use tonic::transport::{Channel, Uri};
use tonic::Status;

use crate::auth::TokenInterceptor;
use crate::auth::MetadataInterceptor;
use crate::channel_pool::ChannelPool;
use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest};
use crate::qdrant_client::config::QdrantConfig;
Expand Down Expand Up @@ -178,16 +178,19 @@ impl Qdrant {
QdrantBuilder::from_url(url)
}

/// Wraps a channel with a token interceptor
fn with_api_key(&self, channel: Channel) -> InterceptedService<Channel, TokenInterceptor> {
let interceptor = TokenInterceptor::new(self.config.api_key.clone());
/// Wraps a channel with a metadata interceptor (api key + custom headers)
fn with_api_key(&self, channel: Channel) -> InterceptedService<Channel, MetadataInterceptor> {
let interceptor = MetadataInterceptor::new(
self.config.api_key.clone(),
self.config.custom_headers.clone(),
);
InterceptedService::new(channel, interceptor)
}

// Access to raw root qdrant API
async fn with_root_qdrant_client<T, O: Future<Output = Result<T, Status>>>(
&self,
f: impl Fn(qdrant_client::QdrantClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
f: impl Fn(qdrant_client::QdrantClient<InterceptedService<Channel, MetadataInterceptor>>) -> O,
) -> QdrantResult<T> {
let result = self
.channel
Expand Down
4 changes: 2 additions & 2 deletions src/qdrant_client/points.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService;
use tonic::transport::Channel;
use tonic::Status;

use crate::auth::TokenInterceptor;
use crate::auth::MetadataInterceptor;
use crate::qdrant::points_client::PointsClient;
use crate::qdrant::{
CountPoints, CountResponse, DeletePointVectors, DeletePoints, FacetCounts, FacetResponse,
Expand All @@ -22,7 +22,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
impl Qdrant {
pub(crate) async fn with_points_client<T, O: Future<Output = Result<T, Status>>>(
&self,
f: impl Fn(PointsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
f: impl Fn(PointsClient<InterceptedService<Channel, MetadataInterceptor>>) -> O,
) -> QdrantResult<T> {
let result = self
.channel
Expand Down
20 changes: 14 additions & 6 deletions src/qdrant_client/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService;
use tonic::transport::Channel;
use tonic::Status;

use crate::auth::TokenInterceptor;
use crate::auth::{MetadataInterceptor, API_KEY_HEADER};
use crate::qdrant::snapshots_client::SnapshotsClient;
use crate::qdrant::{
CreateFullSnapshotRequest, CreateSnapshotRequest, CreateSnapshotResponse,
Expand All @@ -21,7 +21,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
impl Qdrant {
async fn with_snapshot_client<T, O: Future<Output = Result<T, Status>>>(
&self,
f: impl Fn(SnapshotsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
f: impl Fn(SnapshotsClient<InterceptedService<Channel, MetadataInterceptor>>) -> O,
) -> QdrantResult<T> {
let result = self
.channel
Expand Down Expand Up @@ -154,17 +154,25 @@ impl Qdrant {
},
};

let mut stream = reqwest::get(format!(
let url = format!(
"{}/collections/{}/snapshots/{snapshot_name}",
options
.rest_api_uri
.as_ref()
.map(|uri| uri.to_string())
.unwrap_or_else(|| String::from("http://localhost:6333")),
options.collection_name,
))
.await?
.bytes_stream();
);

let client = reqwest::Client::new();
let mut request = client.get(&url);
if let Some(api_key) = &self.config.api_key {
request = request.header(API_KEY_HEADER, api_key.as_str());
}
for (key, value) in &self.config.custom_headers {
request = request.header(key.as_str(), value.as_str());
}
let mut stream = request.send().await?.bytes_stream();

let _ = std::fs::remove_file(&options.out_path);
let mut file = std::fs::OpenOptions::new()
Expand Down
4 changes: 3 additions & 1 deletion tests/snippet_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod test_batch_update;
mod test_clear_payload;
mod test_collection_exists;
mod test_config_headers;
mod test_count_points;
mod test_create_collection;
mod test_create_collection_with_bq;
Expand All @@ -20,6 +21,7 @@ mod test_delete_snapshot;
mod test_delete_vectors;
mod test_discover_batch_points;
mod test_discover_points;
mod test_external_api_keys;
mod test_facets;
mod test_get_collection;
mod test_get_collection_aliases;
Expand Down Expand Up @@ -56,4 +58,4 @@ mod test_upsert_image;
mod test_upsert_points;
mod test_upsert_points_fallback_shard_key;
mod test_upsert_points_insert_only;
mod test_upsert_points_with_condition;
mod test_upsert_points_with_condition;
74 changes: 74 additions & 0 deletions tests/snippet_tests/test_config_headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use qdrant_client::config::{CompressionEncoding, QdrantConfig};

#[test]
fn header_adds_single_header() {
let config = QdrantConfig::from_url("http://localhost:6334").header("x-custom-id", "my-client");

assert_eq!(
config.custom_headers,
vec![("x-custom-id".to_string(), "my-client".to_string())]
);
}

#[test]
fn header_chain_preserves_order() {
let config = QdrantConfig::from_url("http://localhost:6334")
.header("x-a", "1")
.header("x-b", "2")
.header("x-a", "3");

assert_eq!(
config.custom_headers,
vec![
("x-a".to_string(), "1".to_string()),
("x-b".to_string(), "2".to_string()),
("x-a".to_string(), "3".to_string()),
]
);
}

#[test]
fn header_allows_duplicate_keys() {
let config = QdrantConfig::from_url("http://localhost:6334")
.header("openai-api-key", "k1")
.header("openai-api-key", "k2");

assert_eq!(config.custom_headers.len(), 2);
assert_eq!(
config.custom_headers,
vec![
("openai-api-key".to_string(), "k1".to_string()),
("openai-api-key".to_string(), "k2".to_string()),
]
);
}

#[test]
fn header_does_not_mutate_other_config() {
let base = QdrantConfig::from_url("http://localhost:6334")
.api_key("secret")
.timeout(10u64)
.connect_timeout(20u64)
.compression(Some(CompressionEncoding::Gzip))
.skip_compatibility_check();

let with_header = base.clone().header("x-feature", "on");

assert_eq!(with_header.uri, base.uri);
assert_eq!(with_header.timeout, base.timeout);
assert_eq!(with_header.connect_timeout, base.connect_timeout);
assert_eq!(
with_header.keep_alive_while_idle,
base.keep_alive_while_idle
);
assert_eq!(with_header.api_key, base.api_key);
assert_eq!(with_header.compression, base.compression);
assert_eq!(with_header.check_compatibility, base.check_compatibility);
assert_eq!(with_header.pool_size, base.pool_size);

assert_eq!(
with_header.custom_headers,
vec![("x-feature".to_string(), "on".to_string())]
);
assert!(base.custom_headers.is_empty());
}
Loading