diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 33a02c40..0db4706f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,8 +13,10 @@ jobs: with: components: rustfmt, clippy - uses: actions/checkout@v3 - - name: Install dependencies - run: sudo apt-get install protobuf-compiler + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Check code formatting run: cargo fmt --all -- --check - name: Check cargo clippy warnings diff --git a/src/auth.rs b/src/auth.rs index 9ccf9edb..214da36b 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,26 +1,42 @@ +use tonic::metadata::MetadataKey; use tonic::service::Interceptor; use tonic::{Request, Status}; -pub struct TokenInterceptor { +/// Header name used for API key / token authentication. +pub const API_KEY_HEADER: &str = "api-key"; + +pub struct MetadataInterceptor { api_key: Option, + custom_headers: Vec<(String, String)>, } -impl TokenInterceptor { - pub fn new(api_key: Option) -> Self { - Self { api_key } +impl MetadataInterceptor { + pub fn new(api_key: Option, custom_headers: Vec<(String, String)>) -> Self { + Self { + api_key, + custom_headers, + } } } -impl Interceptor for TokenInterceptor { +impl Interceptor for MetadataInterceptor { fn call(&mut self, mut req: Request<()>) -> anyhow::Result, Status> { if let Some(api_key) = &self.api_key { req.metadata_mut().insert( - "api-key", + API_KEY_HEADER, api_key.parse().map_err(|_| { Status::invalid_argument(format!("Malformed API key or token: {api_key}")) })?, ); } + for (key, value) in &self.custom_headers { + let key = MetadataKey::from_bytes(key.as_bytes()) + .map_err(|_| Status::invalid_argument(format!("Malformed header name: {key}")))?; + let value = value.parse().map_err(|_| { + Status::invalid_argument(format!("Malformed header value for {key}: {value}")) + })?; + req.metadata_mut().insert(key, value); + } Ok(req) } } diff --git a/src/qdrant_client/collection.rs b/src/qdrant_client/collection.rs index 1bb1527c..879abfe7 100644 --- a/src/qdrant_client/collection.rs +++ b/src/qdrant_client/collection.rs @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService; use tonic::transport::Channel; use tonic::Status; -use crate::auth::TokenInterceptor; +use crate::auth::MetadataInterceptor; use crate::qdrant::collections_client::CollectionsClient; use crate::qdrant::{ alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest, @@ -25,7 +25,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult}; impl Qdrant { pub(super) async fn with_collections_client>>( &self, - f: impl Fn(CollectionsClient>) -> O, + f: impl Fn(CollectionsClient>) -> O, ) -> QdrantResult { let result = self .channel diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index 6843df32..f68467dc 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -42,6 +42,9 @@ pub struct QdrantConfig { /// Amount of concurrent connections. /// If set to 0 or 1, connection pools will be disabled. pub pool_size: usize, + + /// Optional custom headers to send with every request (both gRPC and REST). + pub custom_headers: Vec<(String, String)>, } impl QdrantConfig { @@ -56,10 +59,31 @@ impl QdrantConfig { pub fn from_url(url: &str) -> Self { QdrantConfig { uri: url.to_string(), + custom_headers: Vec::new(), ..Self::default() } } + /// Add a custom header to send with every request. + /// + /// Can be called multiple times to add multiple headers. The same header name can be + /// set multiple times; all values will be sent. + /// + /// # Examples + /// + /// ```rust,no_run + /// use qdrant_client::Qdrant; + /// + /// let client = Qdrant::from_url("http://localhost:6334") + /// .header("x-custom-id", "my-client") + /// .header("x-request-source", "batch-job") + /// .build(); + /// ``` + pub fn header(mut self, key: impl Into, value: impl Into) -> Self { + self.custom_headers.push((key.into(), value.into())); + self + } + /// Set an optional API key /// /// # Examples @@ -204,6 +228,7 @@ impl Default for QdrantConfig { compression: None, check_compatibility: true, pool_size: 3, + custom_headers: Vec::new(), } } } diff --git a/src/qdrant_client/mod.rs b/src/qdrant_client/mod.rs index 9b121d23..ae131beb 100644 --- a/src/qdrant_client/mod.rs +++ b/src/qdrant_client/mod.rs @@ -20,7 +20,7 @@ use tonic::codegen::InterceptedService; use tonic::transport::{Channel, Uri}; use tonic::Status; -use crate::auth::TokenInterceptor; +use crate::auth::MetadataInterceptor; use crate::channel_pool::ChannelPool; use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest}; use crate::qdrant_client::config::QdrantConfig; @@ -178,16 +178,19 @@ impl Qdrant { QdrantBuilder::from_url(url) } - /// Wraps a channel with a token interceptor - fn with_api_key(&self, channel: Channel) -> InterceptedService { - let interceptor = TokenInterceptor::new(self.config.api_key.clone()); + /// Wraps a channel with a metadata interceptor (api key + custom headers) + fn with_api_key(&self, channel: Channel) -> InterceptedService { + let interceptor = MetadataInterceptor::new( + self.config.api_key.clone(), + self.config.custom_headers.clone(), + ); InterceptedService::new(channel, interceptor) } // Access to raw root qdrant API async fn with_root_qdrant_client>>( &self, - f: impl Fn(qdrant_client::QdrantClient>) -> O, + f: impl Fn(qdrant_client::QdrantClient>) -> O, ) -> QdrantResult { let result = self .channel diff --git a/src/qdrant_client/points.rs b/src/qdrant_client/points.rs index 5da8bb4d..914fcf0b 100644 --- a/src/qdrant_client/points.rs +++ b/src/qdrant_client/points.rs @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService; use tonic::transport::Channel; use tonic::Status; -use crate::auth::TokenInterceptor; +use crate::auth::MetadataInterceptor; use crate::qdrant::points_client::PointsClient; use crate::qdrant::{ CountPoints, CountResponse, DeletePointVectors, DeletePoints, FacetCounts, FacetResponse, @@ -22,7 +22,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult}; impl Qdrant { pub(crate) async fn with_points_client>>( &self, - f: impl Fn(PointsClient>) -> O, + f: impl Fn(PointsClient>) -> O, ) -> QdrantResult { let result = self .channel diff --git a/src/qdrant_client/snapshot.rs b/src/qdrant_client/snapshot.rs index 05b32a86..aeb60fa9 100644 --- a/src/qdrant_client/snapshot.rs +++ b/src/qdrant_client/snapshot.rs @@ -4,7 +4,7 @@ use tonic::codegen::InterceptedService; use tonic::transport::Channel; use tonic::Status; -use crate::auth::TokenInterceptor; +use crate::auth::{MetadataInterceptor, API_KEY_HEADER}; use crate::qdrant::snapshots_client::SnapshotsClient; use crate::qdrant::{ CreateFullSnapshotRequest, CreateSnapshotRequest, CreateSnapshotResponse, @@ -21,7 +21,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult}; impl Qdrant { async fn with_snapshot_client>>( &self, - f: impl Fn(SnapshotsClient>) -> O, + f: impl Fn(SnapshotsClient>) -> O, ) -> QdrantResult { let result = self .channel @@ -154,7 +154,7 @@ impl Qdrant { }, }; - let mut stream = reqwest::get(format!( + let url = format!( "{}/collections/{}/snapshots/{snapshot_name}", options .rest_api_uri @@ -162,9 +162,17 @@ impl Qdrant { .map(|uri| uri.to_string()) .unwrap_or_else(|| String::from("http://localhost:6333")), options.collection_name, - )) - .await? - .bytes_stream(); + ); + + let client = reqwest::Client::new(); + let mut request = client.get(&url); + if let Some(api_key) = &self.config.api_key { + request = request.header(API_KEY_HEADER, api_key.as_str()); + } + for (key, value) in &self.config.custom_headers { + request = request.header(key.as_str(), value.as_str()); + } + let mut stream = request.send().await?.bytes_stream(); let _ = std::fs::remove_file(&options.out_path); let mut file = std::fs::OpenOptions::new() diff --git a/tests/snippet_tests/mod.rs b/tests/snippet_tests/mod.rs index decfd244..221bc81d 100644 --- a/tests/snippet_tests/mod.rs +++ b/tests/snippet_tests/mod.rs @@ -1,6 +1,7 @@ mod test_batch_update; mod test_clear_payload; mod test_collection_exists; +mod test_config_headers; mod test_count_points; mod test_create_collection; mod test_create_collection_with_bq; @@ -20,6 +21,7 @@ mod test_delete_snapshot; mod test_delete_vectors; mod test_discover_batch_points; mod test_discover_points; +mod test_external_api_keys; mod test_facets; mod test_get_collection; mod test_get_collection_aliases; @@ -56,4 +58,4 @@ mod test_upsert_image; mod test_upsert_points; mod test_upsert_points_fallback_shard_key; mod test_upsert_points_insert_only; -mod test_upsert_points_with_condition; \ No newline at end of file +mod test_upsert_points_with_condition; diff --git a/tests/snippet_tests/test_config_headers.rs b/tests/snippet_tests/test_config_headers.rs new file mode 100644 index 00000000..a25b9e63 --- /dev/null +++ b/tests/snippet_tests/test_config_headers.rs @@ -0,0 +1,74 @@ +use qdrant_client::config::{CompressionEncoding, QdrantConfig}; + +#[test] +fn header_adds_single_header() { + let config = QdrantConfig::from_url("http://localhost:6334").header("x-custom-id", "my-client"); + + assert_eq!( + config.custom_headers, + vec![("x-custom-id".to_string(), "my-client".to_string())] + ); +} + +#[test] +fn header_chain_preserves_order() { + let config = QdrantConfig::from_url("http://localhost:6334") + .header("x-a", "1") + .header("x-b", "2") + .header("x-a", "3"); + + assert_eq!( + config.custom_headers, + vec![ + ("x-a".to_string(), "1".to_string()), + ("x-b".to_string(), "2".to_string()), + ("x-a".to_string(), "3".to_string()), + ] + ); +} + +#[test] +fn header_allows_duplicate_keys() { + let config = QdrantConfig::from_url("http://localhost:6334") + .header("openai-api-key", "k1") + .header("openai-api-key", "k2"); + + assert_eq!(config.custom_headers.len(), 2); + assert_eq!( + config.custom_headers, + vec![ + ("openai-api-key".to_string(), "k1".to_string()), + ("openai-api-key".to_string(), "k2".to_string()), + ] + ); +} + +#[test] +fn header_does_not_mutate_other_config() { + let base = QdrantConfig::from_url("http://localhost:6334") + .api_key("secret") + .timeout(10u64) + .connect_timeout(20u64) + .compression(Some(CompressionEncoding::Gzip)) + .skip_compatibility_check(); + + let with_header = base.clone().header("x-feature", "on"); + + assert_eq!(with_header.uri, base.uri); + assert_eq!(with_header.timeout, base.timeout); + assert_eq!(with_header.connect_timeout, base.connect_timeout); + assert_eq!( + with_header.keep_alive_while_idle, + base.keep_alive_while_idle + ); + assert_eq!(with_header.api_key, base.api_key); + assert_eq!(with_header.compression, base.compression); + assert_eq!(with_header.check_compatibility, base.check_compatibility); + assert_eq!(with_header.pool_size, base.pool_size); + + assert_eq!( + with_header.custom_headers, + vec![("x-feature".to_string(), "on".to_string())] + ); + assert!(base.custom_headers.is_empty()); +} diff --git a/tests/snippet_tests/test_external_api_keys.rs b/tests/snippet_tests/test_external_api_keys.rs new file mode 100644 index 00000000..ebb4473c --- /dev/null +++ b/tests/snippet_tests/test_external_api_keys.rs @@ -0,0 +1,248 @@ +use std::collections::HashMap; + +use qdrant_client::qdrant::{ + CreateCollectionBuilder, Distance, Document, PointStruct, Query, QueryPointsBuilder, + UpsertPointsBuilder, VectorParamsBuilder, +}; +use qdrant_client::{Payload, Qdrant}; +use serde_json::json; + +const PROXY_URL: &str = "http://localhost:6334"; +const UPSERT_COLLECTION_NAME: &str = "test_external_api_keys_upsert"; +const QUERY_COLLECTION_NAME: &str = "test_external_api_keys_query"; +const DUAL_OPENAI_COLLECTION_NAME: &str = "test_external_api_keys_dual_openai"; +const DUAL_COHERE_COLLECTION_NAME: &str = "test_external_api_keys_dual_cohere"; +const OPENAI_MODEL: &str = "openai/text-embedding-3-small"; +const OPENAI_VECTOR_SIZE: u64 = 1536; +const COHERE_MODEL: &str = "cohere/embed-english-v3.0"; +const COHERE_VECTOR_SIZE: u64 = 1024; + +fn create_client_with_external_keys(external_api_keys: HashMap) -> Qdrant { + let mut builder = Qdrant::from_url(PROXY_URL) + .skip_compatibility_check() + .api_key("1234") + .timeout(30u64); + for (key, value) in external_api_keys { + builder = builder.header(key, value); + } + builder.build().expect("Failed to build client") +} + +async fn setup_collection(client: &Qdrant, collection_name: &str, vector_size: u64) { + let _ = client.delete_collection(collection_name).await; + + client + .create_collection( + CreateCollectionBuilder::new(collection_name) + .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)), + ) + .await + .expect("Failed to create collection"); +} + +fn cohere_document(text: impl Into, input_type: &'static str) -> Document { + Document { + text: text.into(), + model: COHERE_MODEL.to_string(), + options: HashMap::from([("input_type".to_string(), input_type.into())]), + } +} + +#[tokio::test] +async fn test_upsert_with_external_api_keys() { + let Some(openai_api_key) = std::env::var("OPENAI_API_KEY").ok() else { + eprintln!("Skipping test_upsert_with_external_api_keys: OPENAI_API_KEY is not set"); + return; + }; + let collection_name = UPSERT_COLLECTION_NAME; + let client = create_client_with_external_keys(HashMap::from([( + "openai-api-key".to_string(), + openai_api_key, + )])); + setup_collection(&client, collection_name, OPENAI_VECTOR_SIZE).await; + + let doc = Document::new("Qdrant is a vector search engine", OPENAI_MODEL); + + let result = client + .upsert_points( + UpsertPointsBuilder::new( + collection_name, + vec![PointStruct::new( + 1, + doc, + Payload::try_from(json!({"source": "test"})).unwrap(), + )], + ) + .wait(true), + ) + .await; + + assert!( + result.is_ok(), + "Upsert with external API keys failed: {result:?}" + ); + + let _ = client.delete_collection(collection_name).await; +} + +#[tokio::test] +async fn test_query_with_external_api_keys() { + let Some(openai_api_key) = std::env::var("OPENAI_API_KEY").ok() else { + eprintln!("Skipping test_query_with_external_api_keys: OPENAI_API_KEY is not set"); + return; + }; + let collection_name = QUERY_COLLECTION_NAME; + let client = create_client_with_external_keys(HashMap::from([( + "openai-api-key".to_string(), + openai_api_key, + )])); + setup_collection(&client, collection_name, OPENAI_VECTOR_SIZE).await; + + // Upsert a point first + let doc = Document::new("Qdrant is a vector search engine", OPENAI_MODEL); + client + .upsert_points( + UpsertPointsBuilder::new( + collection_name, + vec![PointStruct::new( + 1, + doc, + Payload::try_from(json!({"source": "test"})).unwrap(), + )], + ) + .wait(true), + ) + .await + .expect("Upsert failed"); + + // Query with a document (server-side inference) + let query_doc = Document::new("vector database", OPENAI_MODEL); + + let result = client + .query( + QueryPointsBuilder::new(collection_name) + .query(Query::new_nearest(query_doc)) + .limit(1) + .with_payload(true), + ) + .await; + + assert!( + result.is_ok(), + "Query with external API keys failed: {result:?}" + ); + + let response = result.unwrap(); + assert_eq!(response.result.len(), 1); + assert!(response.result[0].payload.contains_key("source")); + + let _ = client.delete_collection(collection_name).await; +} + +#[tokio::test] +async fn test_query_with_two_external_api_providers() { + let Some(openai_api_key) = std::env::var("OPENAI_API_KEY").ok() else { + eprintln!("Skipping test_query_with_two_external_api_providers: OPENAI_API_KEY is not set"); + return; + }; + let Some(cohere_api_key) = std::env::var("COHERE_API_KEY").ok() else { + eprintln!("Skipping test_query_with_two_external_api_providers: COHERE_API_KEY is not set"); + return; + }; + + let client = create_client_with_external_keys(HashMap::from([ + ("openai-api-key".to_string(), openai_api_key), + ("cohere-api-key".to_string(), cohere_api_key), + ])); + + setup_collection(&client, DUAL_OPENAI_COLLECTION_NAME, OPENAI_VECTOR_SIZE).await; + setup_collection(&client, DUAL_COHERE_COLLECTION_NAME, COHERE_VECTOR_SIZE).await; + + let openai_doc = Document::new("OpenAI provider document", OPENAI_MODEL); + let cohere_doc = cohere_document("Cohere provider document", "search_document"); + + let openai_upsert = client + .upsert_points( + UpsertPointsBuilder::new( + DUAL_OPENAI_COLLECTION_NAME, + vec![PointStruct::new( + 1, + openai_doc, + Payload::try_from(json!({"provider": "openai"})).unwrap(), + )], + ) + .wait(true), + ) + .await; + assert!( + openai_upsert.is_ok(), + "OpenAI upsert with external API keys failed: {openai_upsert:?}" + ); + + let cohere_upsert = client + .upsert_points( + UpsertPointsBuilder::new( + DUAL_COHERE_COLLECTION_NAME, + vec![PointStruct::new( + 1, + cohere_doc, + Payload::try_from(json!({"provider": "cohere"})).unwrap(), + )], + ) + .wait(true), + ) + .await; + assert!( + cohere_upsert.is_ok(), + "Cohere upsert with external API keys failed: {cohere_upsert:?}" + ); + + let openai_query = client + .query( + QueryPointsBuilder::new(DUAL_OPENAI_COLLECTION_NAME) + .query(Query::new_nearest(Document::new( + "OpenAI provider query", + OPENAI_MODEL, + ))) + .limit(1) + .with_payload(true), + ) + .await; + assert!( + openai_query.is_ok(), + "OpenAI query with external API keys failed: {openai_query:?}" + ); + + let cohere_query = client + .query( + QueryPointsBuilder::new(DUAL_COHERE_COLLECTION_NAME) + .query(Query::new_nearest(cohere_document( + "Cohere provider query", + "search_query", + ))) + .limit(1) + .with_payload(true), + ) + .await; + assert!( + cohere_query.is_ok(), + "Cohere query with external API keys failed: {cohere_query:?}" + ); + + let openai_response = openai_query.unwrap(); + assert_eq!(openai_response.result.len(), 1); + assert_eq!( + openai_response.result[0].payload["provider"], + "openai".into() + ); + + let cohere_response = cohere_query.unwrap(); + assert_eq!(cohere_response.result.len(), 1); + assert_eq!( + cohere_response.result[0].payload["provider"], + "cohere".into() + ); + + let _ = client.delete_collection(DUAL_OPENAI_COLLECTION_NAME).await; + let _ = client.delete_collection(DUAL_COHERE_COLLECTION_NAME).await; +}