-
Notifications
You must be signed in to change notification settings - Fork 510
Expand file tree
/
Copy pathcommon.rs
More file actions
289 lines (248 loc) · 9.02 KB
/
common.rs
File metadata and controls
289 lines (248 loc) · 9.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
//! Common utilities shared between tool and prompt handlers
use std::{any::TypeId, collections::HashMap, sync::Arc};
use schemars::JsonSchema;
use crate::{
RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext,
};
/// Generates a JSON schema for a type, returned as a JSON object.
///
/// The schema is generated with [`SchemaSettings::draft2020_12`] and memoized in a
/// thread-local cache keyed by [`TypeId`], so repeated calls for the same `T` on the
/// same thread return clones of the same [`Arc`]. Each thread builds and caches its
/// own copy.
///
/// # Panics
///
/// Panics if the cache lock is poisoned, if the generated schema fails to serialize,
/// or if it serializes to a JSON value that is not an object.
pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
    thread_local! {
        static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
    };
    CACHE_FOR_TYPE.with(|cache| {
        let type_id = TypeId::of::<T>();
        // Finish the lookup in its own statement so the read guard is dropped before
        // `write()` is taken below. When the lookup lives in an `if let ... else`
        // scrutinee, the guard's lifetime depends on the edition (Rust 2021 keeps it
        // alive through the `else` branch), which would make the write self-deadlock
        // or panic.
        let cached = cache
            .read()
            .expect("schema cache lock poisoned")
            .get(&type_id)
            .cloned();
        if let Some(schema) = cached {
            return schema;
        }
        // explicitly to align json schema version to official specifications.
        // refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
        let settings = SchemaSettings::draft2020_12();
        // Note: AddNullable is intentionally NOT used here because the `nullable` keyword
        // is an OpenAPI 3.0 extension, not part of JSON Schema 2020-12. Using it would
        // cause validation failures with strict JSON Schema validators.
        let generator = settings.into_generator();
        let schema = generator.into_root_schema_for::<T>();
        let object = serde_json::to_value(schema).expect("failed to serialize schema");
        let object = match object {
            serde_json::Value::Object(object) => object,
            _ => panic!(
                "Schema serialization produced non-object value: expected JSON object but got {:?}",
                object
            ),
        };
        let schema = Arc::new(object);
        cache
            .write()
            .expect("schema cache lock poisoned")
            .insert(type_id, Arc::clone(&schema));
        schema
    })
}
// TODO: should be updated according to the new specifications
/// Schema used when input is empty: `{"type": "object", "properties": {}}`.
///
/// Nothing is cached; every call allocates a fresh [`Arc`].
pub fn schema_for_empty_input() -> Arc<JsonObject> {
    let empty_object_schema = serde_json::json!({
        "type": "object",
        "properties": {}
    });
    match empty_object_schema {
        // Take ownership of the map directly instead of borrowing and cloning it.
        serde_json::Value::Object(map) => Arc::new(map),
        _ => unreachable!("a `json!` object literal always produces Value::Object"),
    }
}
/// Generates the JSON schema for `T` and checks that it is usable as a tool
/// `outputSchema` (the root `type` must be `"object"`).
///
/// Both outcomes, the schema or the validation message, are memoized in a
/// thread-local cache keyed by [`TypeId`]. A successful result shares its [`Arc`]
/// with [`schema_for_type`]'s cache.
///
/// # Errors
///
/// Returns a human-readable message when the root `type` is missing, is not a
/// string, or names a type other than `"object"`.
///
/// # Panics
///
/// Panics if the cache lock is poisoned, or if [`schema_for_type`] panics.
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
    thread_local! {
        static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
    };
    CACHE_FOR_OUTPUT.with(|cache| {
        let key = TypeId::of::<T>();

        let hit = cache
            .read()
            .expect("output schema cache lock poisoned")
            .get(&key)
            .cloned();
        if let Some(outcome) = hit {
            return outcome;
        }

        // Not cached yet: build the schema and validate its root type.
        let schema = schema_for_type::<T>();
        let outcome = match schema.get("type") {
            None => Err(
                "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
            ),
            Some(root_type) => match root_type.as_str() {
                Some("object") => Ok(Arc::clone(&schema)),
                Some(found) => Err(format!(
                    "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
                    found
                )),
                None => Err(format!(
                    "Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
                    root_type
                )),
            },
        };

        // Errors are cached too, so an invalid type is only validated once per thread.
        cache
            .write()
            .expect("output schema cache lock poisoned")
            .insert(key, outcome.clone());
        outcome
    })
}
/// Trait for extracting parts from a context, unifying tool and prompt extraction.
///
/// `C` is the handler-specific context type. The context is borrowed mutably so an
/// extractor may move data out of it instead of cloning it.
///
/// # Errors
///
/// Implementations return [`crate::ErrorData`] when the requested part cannot be
/// produced from the context.
pub trait FromContextPart<C>
where
    Self: Sized,
{
    /// Extracts a value of type `Self` from `context`.
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
}
// Common extractors that can be used by both tool and prompt handlers.

/// Extracts a clone of the whole server-side [`RequestContext`]; never fails.
impl<C: AsRequestContext> FromContextPart<C> for RequestContext<RoleServer> {
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let request_context: &RequestContext<RoleServer> = context.as_request_context();
        Ok(request_context.clone())
    }
}
/// Extracts a clone of the request context's cancellation token (its `ct` field);
/// never fails.
impl<C: AsRequestContext> FromContextPart<C> for tokio_util::sync::CancellationToken {
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let token = &context.as_request_context().ct;
        Ok(token.clone())
    }
}
/// Extracts a clone of the request context's `extensions`; never fails.
impl<C: AsRequestContext> FromContextPart<C> for crate::model::Extensions {
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let extensions = &context.as_request_context().extensions;
        Ok(extensions.clone())
    }
}
/// Extractor for a single value of type `T` stored in the request's extensions.
///
/// Extraction clones the stored value. It fails with an `invalid_params` error
/// naming `T` when no value of that type is present.
#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
pub struct Extension<T>(pub T);

impl<C: AsRequestContext, T> FromContextPart<C> for Extension<T>
where
    T: Clone + Send + Sync + 'static,
{
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let extensions = &context.as_request_context().extensions;
        match extensions.get::<T>() {
            Some(value) => Ok(Extension(value.clone())),
            None => Err(crate::ErrorData::invalid_params(
                format!("missing extension {}", std::any::type_name::<T>()),
                None,
            )),
        }
    }
}
/// Extracts a clone of the server-side peer stored in the request context; never
/// fails.
impl<C: AsRequestContext> FromContextPart<C> for crate::Peer<RoleServer> {
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let peer = &context.as_request_context().peer;
        Ok(peer.clone())
    }
}
impl<C> FromContextPart<C> for crate::model::Meta
where
    C: AsRequestContext,
{
    /// Moves the request context's `meta` out, leaving `Meta::default()` in its place.
    ///
    /// The value is taken rather than cloned, so extracting `Meta` a second time from
    /// the same context yields the default value. Never fails.
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        Ok(std::mem::take(&mut context.as_request_context_mut().meta))
    }
}
/// Extractor yielding a clone of the current request's id
/// (the request context's `id` field).
#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
pub struct RequestId(pub crate::model::RequestId);

impl<C: AsRequestContext> FromContextPart<C> for RequestId {
    fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData> {
        let id = context.as_request_context().id.clone();
        Ok(Self(id))
    }
}
/// Trait for types that can provide access to a server-side [`RequestContext`].
///
/// This is the bound on `C` for the generic [`FromContextPart`] extractor impls in
/// this module. The shared accessor serves extractors that clone data; the mutable
/// accessor serves extractors that move data out.
pub trait AsRequestContext {
    /// Borrows the underlying request context immutably.
    fn as_request_context(&self) -> &RequestContext<RoleServer>;

    /// Borrows the underlying request context mutably.
    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct TestObject {
        value: i32,
    }

    #[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
    struct AnotherTestObject {
        value: i32,
    }

    #[test]
    fn test_schema_for_type_handles_primitive() {
        let schema = schema_for_type::<i32>();
        assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_array() {
        let schema = schema_for_type::<Vec<i32>>();
        assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
        let items = schema
            .get("items")
            .and_then(|v| v.as_object())
            .expect("array schema should have an object-valued `items`");
        assert_eq!(items.get("type"), Some(&serde_json::json!("integer")));
    }

    #[test]
    fn test_schema_for_type_handles_struct() {
        let schema = schema_for_type::<TestObject>();
        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        let properties = schema
            .get("properties")
            .and_then(|v| v.as_object())
            .expect("struct schema should have an object-valued `properties`");
        assert!(properties.contains_key("value"));
    }

    #[test]
    fn test_schema_for_type_uses_draft_2020_12() {
        let schema = schema_for_type::<TestObject>();
        let dialect = schema
            .get("$schema")
            .and_then(|v| v.as_str())
            .expect("root schema should declare `$schema`");
        assert!(dialect.contains("2020-12"), "unexpected dialect: {dialect}");
    }

    #[test]
    fn test_schema_for_type_does_not_emit_nullable() {
        // `nullable` is an OpenAPI 3.0 keyword, not JSON Schema 2020-12.
        let schema = schema_for_type::<Option<i32>>();
        assert!(!schema.contains_key("nullable"));
    }

    #[test]
    fn test_schema_for_type_caches_primitive_types() {
        let schema1 = schema_for_type::<i32>();
        let schema2 = schema_for_type::<i32>();
        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_caches_struct_types() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<TestObject>();
        assert!(Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_different_types_different_schemas() {
        let schema1 = schema_for_type::<TestObject>();
        let schema2 = schema_for_type::<AnotherTestObject>();
        assert!(!Arc::ptr_eq(&schema1, &schema2));
    }

    #[test]
    fn test_schema_for_type_arc_can_be_shared() {
        let schema = schema_for_type::<TestObject>();
        let cloned = schema.clone();
        assert!(Arc::ptr_eq(&schema, &cloned));
    }

    #[test]
    fn test_schema_for_empty_input_is_empty_object() {
        let schema = schema_for_empty_input();
        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
        assert_eq!(schema.get("properties"), Some(&serde_json::json!({})));
        assert_eq!(schema.len(), 2);
    }

    #[test]
    fn test_schema_for_output_rejects_primitive() {
        let err = schema_for_output::<i32>()
            .expect_err("i32 has root type 'integer' and must be rejected");
        assert!(err.contains("'integer'"), "unexpected error message: {err}");
    }

    #[test]
    fn test_schema_for_output_accepts_object() {
        let schema = schema_for_output::<TestObject>()
            .expect("struct schema has root type 'object' and must be accepted");
        assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
    }

    #[test]
    fn test_schema_for_output_caches_and_shares_schema() {
        let first = schema_for_output::<TestObject>().expect("object schema accepted");
        let second = schema_for_output::<TestObject>().expect("object schema accepted");
        assert!(Arc::ptr_eq(&first, &second));
        // The accepted schema is the very Arc cached by `schema_for_type`.
        assert!(Arc::ptr_eq(&first, &schema_for_type::<TestObject>()));
    }
}