-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathollama_chat_automatic_function_calling.py
More file actions
177 lines (142 loc) · 5.13 KB
/
ollama_chat_automatic_function_calling.py
File metadata and controls
177 lines (142 loc) · 5.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from ollama import ChatResponse, chat
def ollama_automatic_function_calling(
    client_fn: "chat",
    messages: list[dict],
    options: dict = None,
    **kwargs,
) -> "tuple[ChatResponse, list[dict]]":
    """Drive a chat loop that automatically executes tool/function calls.

    Repeatedly calls ``client_fn``; whenever the assistant reply contains
    tool calls, the matching Python callables are executed and their results
    are appended as ``role="tool"`` messages, until the model answers without
    tool calls or ``max_turns`` is reached.

    Args:
        client_fn: Callable with the ``ollama.chat`` signature.
        messages: Conversation history; mutated in place as the loop runs.
        options: Optional model options dict forwarded to ``client_fn``.
        **kwargs: Must include ``model`` (str). May also include:
            max_turns: Maximum number of chat iterations (default: 20).
            tools: Tool definitions (callables or dict schemas) given to
                the model.
            tool_map: Mapping of tool name -> callable. Inferred from
                ``tools`` when omitted.
            Any other keys are forwarded verbatim to ``client_fn``.

    Returns:
        (last_response, messages): the final response (``None`` when
        ``max_turns`` is 0) and the accumulated message list.

    Raises:
        ValueError: if ``model`` is missing or ``client_fn`` is not callable.
    """
    import json  # hoisted: previously re-imported inside the tool-call loop

    model = kwargs.get("model")
    if model is None:
        raise ValueError("model must be specified")
    # Fail fast on a bad client before doing any tool-map work.
    if not callable(client_fn):
        raise ValueError("client_fn must be a callable")
    max_turns = kwargs.get("max_turns", 20)
    tools = kwargs.get("tools")
    tool_map = kwargs.get("tool_map")
    last_response = None

    def _to_msg_dict(msg):
        # Normalize a message to a plain dict: pydantic objects expose
        # model_dump(); dicts pass through; anything else is reduced to
        # its role/content attributes.
        if hasattr(msg, "model_dump"):
            return msg.model_dump(exclude_none=True)
        if isinstance(msg, dict):
            return msg
        return {
            "role": getattr(msg, "role", "assistant"),
            "content": getattr(msg, "content", ""),
        }

    def _extract_tool_call_parts(tool_call):
        # Supports both attribute-style (pydantic) and dict-style tool calls.
        function = getattr(tool_call, "function", None)
        if function is None and isinstance(tool_call, dict):
            function = tool_call.get("function", {})
        if isinstance(function, dict):
            name = function.get("name")
            arguments = function.get("arguments") or {}
        else:
            name = getattr(function, "name", None)
            arguments = getattr(function, "arguments", None) or {}
        return name, arguments

    def _infer_tool_name(t):
        # Derive a tool's name from a dict schema or a callable.
        if t is None:
            return None
        if isinstance(t, dict):
            fn = t.get("function") or {}
            if isinstance(fn, dict):
                return fn.get("name")
            return None
        return getattr(t, "name", None) or getattr(t, "__name__", None)

    def _resolve_callable(fn):
        # Unwrap decorated tools so the real implementation is invoked.
        wrapped = getattr(fn, "__wrapped__", None)
        return wrapped if callable(wrapped) else fn

    if tool_map is None and tools:
        # No explicit mapping given: build name -> tool from the tools list.
        tool_map = {}
        for t in tools:
            name = _infer_tool_name(t)
            if name:
                tool_map[name] = t
    # Everything not consumed here is forwarded to the chat client.
    reserved = {"model", "tools", "messages", "options", "max_turns", "tool_map"}
    chat_kwargs = {k: v for k, v in kwargs.items() if k not in reserved}
    for _ in range(max_turns):
        response = client_fn(
            model=model,
            tools=tools,
            messages=messages,
            options=options,
            **chat_kwargs,
        )
        last_response = response
        assistant_msg = response.message
        messages.append(_to_msg_dict(assistant_msg))
        tool_calls = getattr(assistant_msg, "tool_calls", None) or []
        if not tool_calls:
            # Model produced a final answer: the loop is done.
            return last_response, messages
        for tool_call in tool_calls:
            tool_name, tool_args = _extract_tool_call_parts(tool_call)
            tool_fn = tool_map.get(tool_name) if tool_map else None
            if not callable(tool_fn):
                messages.append(
                    {
                        "role": "tool",
                        "content": f"Tool '{tool_name}' not found.",
                        "tool_name": tool_name or "unknown_tool",
                    }
                )
                continue
            if isinstance(tool_args, str):
                # Some backends deliver arguments as a JSON string.
                try:
                    tool_args = json.loads(tool_args)
                except (ValueError, TypeError):
                    tool_args = {}
            try:
                result = _resolve_callable(tool_fn)(**(tool_args or {}))
                content = str(result)
            except Exception as e:
                # Report tool failures back to the model instead of crashing.
                content = f"Tool '{tool_name}' execution failed: {e}"
            messages.append(
                {
                    "role": "tool",
                    "content": content,
                    "tool_name": tool_name,
                }
            )
    # Safety exit: do not raise, just return the last response.
    return last_response, messages
def add(a: float, b: float) -> float:
    """Apply addition"""
    # Sum of the two operands; exposed as a model-callable tool.
    total = a + b
    return total
def sub(a: float, b: float) -> float:
    """Apply subtraction"""
    # Difference a - b; exposed as a model-callable tool.
    difference = a - b
    return difference
def mul(a: float, b: float) -> float:
    """Apply multiplication"""
    # Product of the two operands; exposed as a model-callable tool.
    product = a * b
    return product
def div(a: float, b: float) -> float:
    """Apply division"""
    # True division a / b; a ZeroDivisionError from b == 0 propagates and
    # is reported back to the model by the tool-execution error handler.
    quotient = a / b
    return quotient
if __name__ == "__main__":
    # Demo: let the model solve an arithmetic question using the four tools.
    # `chat` is already imported at module top — no need to re-import here.
    response, messages = ollama_automatic_function_calling(
        chat,
        model="glm-5:cloud",
        tools=[add, sub, mul, div],
        messages=[{"role": "user", "content": "这个数学题:4+7*9/6是多少?"}],
        think="low",  # or True / False
        options={
            "temperature": 1,
            "top_p": 1.0,
            "top_k": 60,
        },
    )
    # Show the final answer; previously the result was computed and discarded.
    print(response.message.content if response is not None else "no response")