diff --git a/examples/tool-call-example.ro b/examples/tool-call-example.ro index 6627f1e..384c615 100644 --- a/examples/tool-call-example.ro +++ b/examples/tool-call-example.ro @@ -1,23 +1,25 @@ -use weatherTool - function buildWeatherPrompt(city): String { - promptText = "What's the weather like in " + city + "?" + promptText = "What's the weather like in " + city + "? Use the weatherTool to get the current weather." return promptText } city = "San Francisco" promptText = buildWeatherPrompt(city) +print "Prompt: " + promptText + prompt promptText tools: [ { name: "weatherTool", - description: "Provides forecast data" + description: "Provides forecast data for a given city. Returns temperature, conditions, and forecast.", + parameters: { + city: "String" + } } ] model: "gpt-4" temperature: 0.2 maxTokens: 200 -call weatherTool(city) - + \ No newline at end of file diff --git a/plugins/weather-tool/plugin.toml b/plugins/weather-tool/plugin.toml new file mode 100644 index 0000000..04cf90a --- /dev/null +++ b/plugins/weather-tool/plugin.toml @@ -0,0 +1,8 @@ +name = "weatherTool" +type = "subprocess" +description = "Weather tool that provides forecast data for a given city" +version = "1.0.0" +path = "weatherTool.py" +interpreter = "python3" +args = [] + diff --git a/plugins/weather-tool/weatherTool.py b/plugins/weather-tool/weatherTool.py new file mode 100755 index 0000000..1c3058c --- /dev/null +++ b/plugins/weather-tool/weatherTool.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +import json +import sys +from typing import Any, Dict, List, Union +import random + +def parse_value(value: Any) -> Dict[str, Any]: + """Convert Python value to Pronto Value format""" + if value is None: + return {"Null": None} + elif isinstance(value, str): + return {"String": value} + elif isinstance(value, (int, float)): + return {"Number": float(value)} + elif isinstance(value, bool): + return {"Boolean": value} + elif isinstance(value, list): + return {"Array": [parse_value(v) for v in value]} + elif isinstance(value, dict): + return {"Object": {k: parse_value(v) for k, v in value.items()}} + else: + return {"Null": None} + +def extract_value(value: Dict[str, Any]) -> Any: + """Extract Python value from Pronto Value format""" + if "String" in value: + return value["String"] + elif "Number" in value: + return value["Number"] + elif "Boolean" in value: + return value["Boolean"] + elif "Null" in value: + return None + elif "Array" in value: + return [extract_value(v) for v in value["Array"]] + elif "Object" in value: + return {k: extract_value(v) for k, v in value["Object"].items()} + else: + return None + +def get_weather(city: str) -> Dict[str, Any]: + """Get weather information for a given city""" + conditions = ["Sunny", "Cloudy", "Partly Cloudy", "Rainy", "Clear"] + temperatures = { + "San Francisco": (15, 20), + "New York": (10, 18), + "London": (8, 15), + "Tokyo": (12, 22), + "Paris": (10, 16), + } + + temp_range = temperatures.get(city, (10, 25)) + temp = round(random.uniform(temp_range[0], temp_range[1]), 1) + condition = random.choice(conditions) + + return { + "city": city, + "temperature": temp, + "condition": condition, + "unit": "Celsius", + "forecast": f"{condition} with a temperature of {temp}°C" + } + +def main(): + try: + input_data = sys.stdin.read() + request = json.loads(input_data) + + if request.get("tool") != "weatherTool": + response = { + "success": False, + "error": f"Unknown tool: {request.get('tool')}", + } + print(json.dumps(response)) + sys.exit(1) + + args = request.get("arguments", []) + + city = None + + if len(args) == 0: + response = { + "success": False, + "error": "weatherTool requires a city argument", + } + print(json.dumps(response)) + sys.exit(1) + + if len(args) == 1 and isinstance(args[0], dict) and "Object" in args[0]: + obj = extract_value(args[0]) + if isinstance(obj, dict) and "city" in obj: + city = obj["city"] + elif isinstance(obj, dict) and len(obj) == 1: + city = list(obj.values())[0] + elif len(args) >= 1: + city = extract_value(args[0]) + + if not city or not isinstance(city, str): + response = { + "success": False, + "error": "weatherTool requires a city argument (String)", + } + print(json.dumps(response)) + sys.exit(1) + + weather_data = get_weather(city) + + response = { + "success": True, + "result": parse_value(weather_data), + } + + print(json.dumps(response)) + + except Exception as e: + response = { + "success": False, + "error": f"Plugin error: {str(e)}", + } + print(json.dumps(response)) + sys.exit(1) + +if __name__ == "__main__": + main() + diff --git a/rohas-core/src/parser.rs b/rohas-core/src/parser.rs index a03ac3e..c9ff1e1 100644 --- a/rohas-core/src/parser.rs +++ b/rohas-core/src/parser.rs @@ -464,10 +464,28 @@ impl Parser { } }) .ok_or(ParseError::InvalidExpression)?; + let parameters = properties + .get("parameters") + .and_then(|e| { + if let Expression::ObjectLiteral { properties: param_props } = e { + let mut params = HashMap::new(); + for (param_name, param_value) in param_props { + if let Expression::Literal(Literal::String(type_str)) = param_value { + if let Ok(param_type) = Self::parse_type_from_string(&type_str) { + params.insert(param_name.clone(), param_type); + } + } + } + Some(params) + } else { + None + } + }) + .unwrap_or_else(HashMap::new); tool_defs.push(ToolDefinition { name, description, - parameters: HashMap::new(), + parameters, }); } else { return Err(ParseError::InvalidExpression); @@ -603,7 +621,6 @@ impl Parser { "maxTokens" => *max_tokens = Some(value), "stream" => *stream = Some(value), "tools" => { - if let Expression::ArrayLiteral { elements } = value { let mut tool_defs = Vec::new(); for elem in elements { @@ -628,10 +645,28 @@ impl Parser { } }) .ok_or(ParseError::InvalidExpression)?; + let parameters = properties + .get("parameters") + .and_then(|e| { + if let Expression::ObjectLiteral { properties: param_props } = e { + let mut params = HashMap::new(); + for (param_name, param_value) in param_props { + if let Expression::Literal(Literal::String(type_str)) = param_value { + if let Ok(param_type) = Self::parse_type_from_string(&type_str) { + params.insert(param_name.clone(), param_type); + } + } + } + Some(params) + } else { + None + } + }) + .unwrap_or_else(HashMap::new); tool_defs.push(ToolDefinition { name, description, - parameters: HashMap::new(), + parameters, }); } else { return Err(ParseError::InvalidExpression); @@ -1222,6 +1257,16 @@ impl Parser { } } + fn parse_type_from_string(type_str: &str) -> Result { + match type_str { + "String" | "string" => Ok(Type::String), + "Number" | "number" => Ok(Type::Number), + "Boolean" | "boolean" => Ok(Type::Boolean), + "Any" => Ok(Type::Any), + _ => Ok(Type::Named(type_str.to_string())), + } + } + fn peek(&self) -> Option<&Token> { self.tokens.get(self.current).map(|t| &t.token) } diff --git a/rohas-llm/src/providers/openai.rs b/rohas-llm/src/providers/openai.rs index 9482125..430841d 100644 --- a/rohas-llm/src/providers/openai.rs +++ b/rohas-llm/src/providers/openai.rs @@ -45,7 +45,41 @@ impl Provider for OpenAIProvider { } if let Some(tools) = request.tools { - body["tools"] = json!(tools); + let openai_tools: Vec = tools + .into_iter() + .map(|tool| { + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); + for (param_name, param) in tool.parameters { + let mut prop = serde_json::Map::new(); + prop.insert("type".to_string(), json!(param.param_type)); + if let Some(desc) = param.description { + prop.insert("description".to_string(), json!(desc)); + } + properties.insert(param_name.clone(), json!(prop)); + if param.required.unwrap_or(false) { + required.push(param_name); + } + } + + let parameters = json!({ + "type": "object", + "properties": properties, + "required": required + }); + + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": parameters + } + }) + }) + .collect(); + + body["tools"] = json!(openai_tools); body["tool_choice"] = json!("auto"); } diff --git a/rohas-runtime/src/executor.rs b/rohas-runtime/src/executor.rs index a6e1852..cfe41cd 100644 --- a/rohas-runtime/src/executor.rs +++ b/rohas-runtime/src/executor.rs @@ -8,6 +8,7 @@ use rohas_llm::{LLMRequest, LLMRuntime, ProviderConfig}; use std::collections::HashMap; use std::io::{self, Write}; use tokio::runtime::Runtime; +use serde_json; type ToolFunction = Box Result + Send + Sync>; @@ -280,31 +281,44 @@ impl Executor { }); let llm_request = LLMRequest { - prompt: prompt_str, - model: model_str, + prompt: prompt_str.clone(), + model: model_str.clone(), temperature: temp, max_tokens: max_toks, stream: None, tools: tools.clone().map(|tools| { tools .into_iter() - .map(|t| rohas_llm::types::ToolDefinition { - name: t.name, - description: t.description, - parameters: t - .parameters - .into_iter() - .map(|(k, v)| { - ( - k, - rohas_llm::types::ToolParameter { - param_type: format!("{:?}", v), - description: None, - required: Some(true), - }, - ) - }) - .collect(), + .map(|t| { + let mut properties = std::collections::HashMap::new(); + let mut required = Vec::new(); + + for (param_name, param_type) in &t.parameters { + let type_str = match param_type { + rohas_core::ast::Type::String => "string", + rohas_core::ast::Type::Number => "number", + rohas_core::ast::Type::Boolean => "boolean", + rohas_core::ast::Type::Array(_) => "array", + rohas_core::ast::Type::Object(_) => "object", + _ => "string", + }; + + properties.insert( + param_name.clone(), + rohas_llm::types::ToolParameter { + param_type: type_str.to_string(), + description: None, + required: Some(true), + }, + ); + required.push(param_name.clone()); + } + + rohas_llm::types::ToolDefinition { + name: t.name, + description: t.description, + parameters: properties, + } }) .collect() }), @@ -315,11 +329,85 @@ impl Executor { .as_ref() .context("LLM runtime not configured")?; - let response = self + let mut response = self .tokio_runtime .block_on(llm_runtime.execute(llm_request))?; - Ok(Some(Value::String(response.content))) + if let Some(tool_calls) = &response.tool_calls { + let mut tool_results_for_llm = Vec::new(); + + for tool_call in tool_calls { + let args = if let serde_json::Value::Object(obj) = &tool_call.arguments { + let mut arg_map = HashMap::new(); + for (key, value) in obj { + arg_map.insert(key.clone(), self.json_to_value(value)?); + } + vec![Value::Object(arg_map)] + } else { + vec![self.json_to_value(&tool_call.arguments)?] + }; + + let result = if let Ok(result) = self.tools.call(&tool_call.name, &args) { + result + } else if let Ok(result) = self.plugins.call(&tool_call.name, &args) { + result + } else { + Value::String(format!("Tool '{}' not found or failed", tool_call.name)) + }; + + let result_json = result.to_json(); + + tool_results_for_llm.push(serde_json::json!({ + "tool_call_id": tool_call.id, + "name": tool_call.name, + "content": serde_json::to_string(&result_json).unwrap_or_else(|_| result.to_string()) + })); + } + + if !tool_results_for_llm.is_empty() { + let follow_up_prompt = if response.content.is_empty() { + "Tool execution completed. Please provide a response based on the tool results.".to_string() + } else { + format!("{}\n\nTool execution completed. Please provide a final response based on the tool results.", response.content) + }; + + let tool_results_text: Vec = tool_results_for_llm + .iter() + .map(|tr| { + format!( + "Tool {} returned: {}", + tr["name"].as_str().unwrap_or("unknown"), + tr["content"].as_str().unwrap_or("") + ) + }) + .collect(); + + let final_prompt = format!( + "{}\n\nTool Results:\n{}", + follow_up_prompt, + tool_results_text.join("\n") + ); + + let follow_up_request = LLMRequest { + prompt: final_prompt, + model: model_str.clone(), + temperature: temp, + max_tokens: max_toks, + stream: None, + tools: None, // Don't send tools again in follow-up + }; + + let final_response = self + .tokio_runtime + .block_on(llm_runtime.execute(follow_up_request))?; + + Ok(Some(Value::String(final_response.content))) + } else { + Ok(Some(Value::String(response.content))) + } + } else { + Ok(Some(Value::String(response.content))) + } } Statement::ToolCallStatement { tool_name, @@ -624,4 +712,26 @@ impl Executor { }, } } + + fn json_to_value(&self, json: &serde_json::Value) -> Result { + match json { + serde_json::Value::Null => Ok(Value::Null), + serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)), + serde_json::Value::Number(n) => Ok(Value::Number( + n.as_f64().ok_or_else(|| anyhow::anyhow!("Invalid number"))?, + )), + serde_json::Value::String(s) => Ok(Value::String(s.clone())), + serde_json::Value::Array(arr) => { + let values: Result> = arr.iter().map(|v| self.json_to_value(v)).collect(); + Ok(Value::Array(values?)) + } + serde_json::Value::Object(obj) => { + let mut map = HashMap::new(); + for (k, v) in obj { + map.insert(k.clone(), self.json_to_value(v)?); + } + Ok(Value::Object(map)) + } + } + } }