Skip to content

Commit 9a9e264

Browse files
authored
Add client api (#78)
* add client Signed-off-by: kerthcet <kerthcet@gmail.com> * rename Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 1c1775f commit 9a9e264

14 files changed

Lines changed: 309 additions & 63 deletions

File tree

.env.integration-test

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
AMRS_API_KEY=your_amrs_api_key_here
2+
OPENAI_API_KEY=your_openai_api_key_here
3+
FAKE_API_KEY=your_fake_api_key_here

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ lazy_static = "1.5.0"
1212
rand = "0.9.2"
1313
reqwest = "0.12.26"
1414
serde = "1.0.228"
15+
tokio = "1.48.0"

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,55 @@
44

55
The Adaptive Model Routing System (AMRS) is a framework designed to select the best-fit model for exploration and exploitation. (still under development)
66

7+
Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on top of it to provide adaptive model routing capabilities.
8+
9+
## Features
10+
11+
- Flexible routing strategies, including:
12+
- **Random**: Randomly selects a model from the available models.
13+
- **WRR**: Weighted Round Robin selects models based on predefined weights.
14+
- **UCB**: Upper Confidence Bound based model selection (coming soon).
15+
- **Adaptive**: Dynamically selects models based on performance metrics (coming soon).
16+
17+
18+
## How to use
19+
20+
Here's a simple example with random routing mode:
21+
22+
23+
```rust
24+
// Before running the code, make sure to set your OpenAI API key in the environment variable:
25+
// export OPENAI_API_KEY="your_openai_api_key"
26+
27+
use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};
28+
29+
let config = Config::builder()
30+
.provider("openai")
31+
.routing_mode(RoutingMode::Random)
32+
.model(
33+
ModelConfig::builder()
34+
.id("gpt-3.5-turbo")
35+
.build()
36+
.unwrap(),
37+
)
38+
.model(
39+
ModelConfig::builder()
40+
.id("gpt-4")
41+
.build()
42+
.unwrap(),
43+
)
44+
.build()
45+
.unwrap();
46+
47+
let mut client = Client::new(config);
48+
let request = CreateResponseArgs::default()
49+
.input("give me a poem about nature")
50+
.build()
51+
.unwrap();
52+
53+
let response = client.create_response(request).await.unwrap();
54+
```
55+
756
## Contributing
857

958
🚀 All kinds of contributions are welcome! Please follow [Contributing](/CONTRIBUTING.md).

src/client/client.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ impl Client {
1717
let providers = cfg
1818
.models
1919
.iter()
20-
.map(|m| (m.id.clone(), provider::construct_provider(m)))
20+
.map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
2121
.collect();
2222

2323
Self {
@@ -28,8 +28,8 @@ impl Client {
2828

2929
pub async fn create_response(
3030
&mut self,
31-
request: provider::ResponseRequest,
32-
) -> Result<provider::ResponseResult, provider::APIError> {
31+
request: provider::CreateResponseReq,
32+
) -> Result<provider::CreateResponseRes, provider::APIError> {
3333
let model_id = self.router.sample(&request);
3434
let provider = self.providers.get(&model_id).unwrap();
3535
provider.create_response(request).await

src/config.rs

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ lazy_static! {
1414
m.insert("OPENAI", "https://api.openai.com/v1");
1515
m.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai");
1616
m.insert("OPENROUTER", "https://openrouter.ai/api/v1");
17+
18+
m.insert("FAKE", "http://localhost:8080"); // test only
1719
// TODO: support more providers here...
1820
m
1921
};
@@ -34,20 +36,34 @@ pub type ModelId = String;
3436
pub struct ModelConfig {
3537
// model-specific configs, will override global configs if provided
3638
#[builder(default = "None")]
37-
pub base_url: Option<String>,
38-
#[builder(default = "None")]
39-
pub provider: Option<ProviderName>,
39+
pub(crate) base_url: Option<String>,
40+
#[builder(default = "None", setter(custom))]
41+
pub(crate) provider: Option<ProviderName>,
4042
#[builder(default = "None")]
41-
pub temperature: Option<f32>,
43+
pub(crate) temperature: Option<f32>,
4244
#[builder(default = "None")]
43-
pub max_output_tokens: Option<usize>,
45+
pub(crate) max_output_tokens: Option<usize>,
4446

45-
pub id: ModelId,
47+
#[builder(setter(custom))]
48+
pub(crate) id: ModelId,
4649
#[builder(default=-1)]
47-
pub weight: i32,
50+
pub(crate) weight: i32,
4851
}
4952

5053
impl ModelConfigBuilder {
54+
pub fn id<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
55+
self.id = Some(name.as_ref().to_string());
56+
self
57+
}
58+
59+
pub fn provider<S>(&mut self, name: Option<S>) -> &mut Self
60+
where
61+
S: AsRef<str>,
62+
{
63+
self.provider = Some(name.map(|s| s.as_ref().to_string().to_uppercase()));
64+
self
65+
}
66+
5167
fn validate(&self) -> Result<(), String> {
5268
if self.id.is_none() {
5369
return Err("Model id must be provided.".to_string());
@@ -69,7 +85,7 @@ pub struct Config {
6985
// global configs for models, will be overridden by model-specific configs
7086
#[builder(default = "https://api.openai.com/v1".to_string())]
7187
pub(crate) base_url: String,
72-
#[builder(default = "ProviderName::from(OPENAI_PROVIDER)")]
88+
#[builder(default = "ProviderName::from(OPENAI_PROVIDER)", setter(custom))]
7389
pub(crate) provider: ProviderName,
7490
#[builder(default = "0.8")]
7591
pub(crate) temperature: f32,
@@ -124,6 +140,11 @@ impl ConfigBuilder {
124140
self
125141
}
126142

143+
pub fn provider<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
144+
self.provider = Some(name.as_ref().to_string().to_uppercase());
145+
self
146+
}
147+
127148
fn validate(&self) -> Result<(), String> {
128149
if self.models.is_none() || self.models.as_ref().unwrap().is_empty() {
129150
return Err("At least one model must be configured.".to_string());
@@ -258,7 +279,7 @@ mod tests {
258279
.build()
259280
.unwrap(),
260281
)
261-
.provider("unknown_provider".to_string())
282+
.provider("unknown_provider")
262283
.build();
263284
assert!(invalid_cfg_with_no_api_key.is_err());
264285

@@ -269,8 +290,8 @@ mod tests {
269290
.max_output_tokens(2048)
270291
.model(
271292
ModelConfig::builder()
272-
.id("custom-model".to_string())
273-
.provider(Some("AMRS".to_string()))
293+
.id("custom-model")
294+
.provider(Some("AMRS"))
274295
.build()
275296
.unwrap(),
276297
)
@@ -317,12 +338,7 @@ mod tests {
317338
let mut valid_specified_cfg = Config::builder()
318339
.provider("AMRS".to_string())
319340
.base_url("http://custom-api.ai".to_string())
320-
.model(
321-
ModelConfig::builder()
322-
.id("model-2".to_string())
323-
.build()
324-
.unwrap(),
325-
)
341+
.model(ModelConfig::builder().id("model-2").build().unwrap())
326342
.build();
327343
valid_specified_cfg.as_mut().unwrap().populate();
328344

src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,18 @@ mod router {
44
pub mod stats;
55
mod wrr;
66
}
7+
mod config;
78
mod client {
89
pub mod client;
910
}
1011
mod provider {
12+
mod fake;
1113
mod openai;
1214
pub mod provider;
1315
}
1416

15-
pub mod config;
1617
pub use crate::client::client::Client;
18+
pub use crate::config::{Config, ModelConfig, RoutingMode};
19+
pub use crate::provider::provider::{
20+
APIError, CreateResponseArgs, CreateResponseReq, CreateResponseRes,
21+
};

src/provider/fake.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
use std::str::FromStr;
2+
3+
use async_openai::types::responses::{
4+
AssistantRole, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
5+
OutputTextContent, Status,
6+
};
7+
use async_openai::{Client, config::OpenAIConfig};
8+
use async_trait::async_trait;
9+
use reqwest::header::HeaderName;
10+
11+
use crate::config::{ModelConfig, ModelId};
12+
use crate::provider::provider::{
13+
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
14+
};
15+
16+
pub struct FakeProvider {
17+
model: ModelId,
18+
}
19+
20+
impl FakeProvider {
21+
pub fn new(config: ModelConfig) -> Self {
22+
Self {
23+
model: config.id.clone(),
24+
}
25+
}
26+
}
27+
28+
#[async_trait]
29+
impl Provider for FakeProvider {
30+
fn name(&self) -> &'static str {
31+
"FakeProvider"
32+
}
33+
34+
async fn create_response(
35+
&self,
36+
request: CreateResponseReq,
37+
) -> Result<CreateResponseRes, APIError> {
38+
validate_request(&request)?;
39+
40+
Ok(CreateResponseRes {
41+
id: "fake-response-id".to_string(),
42+
object: "text_completion".to_string(),
43+
model: self.model.clone(),
44+
usage: None,
45+
output: vec![OutputItem::Message(OutputMessage {
46+
id: "fake-message-id".to_string(),
47+
status: OutputStatus::Completed,
48+
role: AssistantRole::Assistant,
49+
content: vec![OutputMessageContent::OutputText(OutputTextContent {
50+
annotations: vec![],
51+
logprobs: None,
52+
text: "This is a fake response.".to_string(),
53+
})],
54+
})],
55+
created_at: 1_600_000_000,
56+
background: None,
57+
billing: None,
58+
conversation: None,
59+
error: None,
60+
incomplete_details: None,
61+
instructions: None,
62+
max_output_tokens: None,
63+
metadata: None,
64+
prompt: None,
65+
parallel_tool_calls: None,
66+
previous_response_id: None,
67+
prompt_cache_key: None,
68+
prompt_cache_retention: None,
69+
reasoning: None,
70+
safety_identifier: None,
71+
service_tier: None,
72+
status: Status::Completed,
73+
temperature: None,
74+
text: None,
75+
top_p: None,
76+
tools: None,
77+
tool_choice: None,
78+
top_logprobs: None,
79+
truncation: None,
80+
})
81+
}
82+
}

src/provider/openai.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1-
use std::str::FromStr;
2-
31
use async_openai::{Client, config::OpenAIConfig};
42
use async_trait::async_trait;
5-
use reqwest::header::HeaderName;
3+
use derive_builder::Builder;
64

75
use crate::config::{ModelConfig, ModelId};
8-
use crate::provider::provider::{APIError, Provider, ResponseRequest, ResponseResult};
6+
use crate::provider::provider::{
7+
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
8+
};
99

10+
#[derive(Debug, Clone, Builder)]
11+
#[builder(pattern = "mutable", build_fn(skip))]
1012
pub struct OpenAIProvider {
1113
model: ModelId,
1214
config: OpenAIConfig,
13-
client: Option<Client<OpenAIConfig>>,
15+
client: Client<OpenAIConfig>,
1416
}
1517

1618
impl OpenAIProvider {
17-
pub fn new(config: &ModelConfig) -> Self {
19+
pub fn builder(config: ModelConfig) -> OpenAIProviderBuilder {
1820
let api_key_var = format!(
1921
"{}_API_KEY",
2022
config.provider.as_ref().unwrap().to_uppercase()
@@ -25,26 +27,21 @@ impl OpenAIProvider {
2527
.with_api_base(config.base_url.clone().unwrap())
2628
.with_api_key(api_key);
2729

28-
OpenAIProvider {
29-
model: config.id.clone(),
30-
config: openai_config,
30+
OpenAIProviderBuilder {
31+
model: Some(config.id.clone()),
32+
config: Some(openai_config),
3133
client: None,
3234
}
3335
}
36+
}
3437

35-
pub fn header(mut self, key: &str, value: &str) -> Result<Self, APIError> {
36-
let name = HeaderName::from_str(key)
37-
.map_err(|e| APIError::InvalidArgument(format!("Invalid header name: {}", e)))?;
38-
39-
self.config = self.config.with_header(name, value)?;
40-
Ok(self)
41-
}
42-
43-
pub fn build(mut self) -> Self {
44-
if self.client.is_none() {
45-
self.client = Some(Client::with_config(self.config.clone()));
38+
impl OpenAIProviderBuilder {
39+
pub fn build(&mut self) -> OpenAIProvider {
40+
OpenAIProvider {
41+
model: self.model.clone().unwrap(),
42+
config: self.config.clone().unwrap(),
43+
client: Client::with_config(self.config.as_ref().unwrap().clone()),
4644
}
47-
self
4845
}
4946
}
5047

@@ -54,8 +51,11 @@ impl Provider for OpenAIProvider {
5451
"OpenAIProvider"
5552
}
5653

57-
async fn create_response(&self, request: ResponseRequest) -> Result<ResponseResult, APIError> {
58-
let client = self.client.as_ref().unwrap();
59-
client.responses().create(request).await
54+
async fn create_response(
55+
&self,
56+
request: CreateResponseReq,
57+
) -> Result<CreateResponseRes, APIError> {
58+
validate_request(&request)?;
59+
self.client.responses().create(request).await
6060
}
6161
}

0 commit comments

Comments
 (0)