Skip to content

Commit 53894c7

Browse files
committed
add support for tools structure for function calling
1 parent 03ac579 commit 53894c7

4 files changed

Lines changed: 497 additions & 18 deletions

File tree

README.md

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,12 @@ Constructor for the ChatSession.
413413
- `client`: An instance of the OpenAI client.
414414
- `opts`: An optional table of options.
415415
- `messages`: An initial array of chat messages
416-
- `functions`: A list of function declarations
416+
- `functions`: A list of function declarations (legacy)
417+
- `tools`: An array of tool definitions (modern tool calling interface)
418+
- `tool_choice`: Controls which tool is called (`"auto"`, `"none"`, or a specific tool)
419+
- `parallel_tool_calls`: Whether the model can make multiple tool calls in a single response
417420
- `temperature`: temperature setting
418-
- `model`: Which chat completion model to use, eg. `gpt-4`, `gpt-3.5-turbo`
421+
- `model`: Which chat completion model to use, e.g. `gpt-4.1`, `gpt-4o-mini`
419422

420423
##### `chat:append_message(m, ...)`
421424

@@ -547,6 +550,70 @@ local status, response = client:create_chat_completion({
547550
The `OpenRouter` client extends `OpenAI` and supports all the same methods
548551
including chat completions, chat sessions, and streaming.
549552

553+
## Tool Calling
554+
555+
OpenAI's [tool calling
556+
API](https://platform.openai.com/docs/guides/function-calling) allows models to
557+
request function calls during a conversation. The chat session manages the
558+
back-and-forth of tool calls and results automatically.
559+
560+
```lua
561+
local openai = require("openai")
562+
local cjson = require("cjson")
563+
local client = openai.new(os.getenv("OPENAI_API_KEY"))
564+
565+
-- Define available tools
566+
local tools = {
567+
{
568+
type = "function",
569+
["function"] = {
570+
name = "get_weather",
571+
description = "Get the current weather for a location",
572+
parameters = {
573+
type = "object",
574+
properties = {
575+
location = { type = "string", description = "City name" }
576+
},
577+
required = {"location"}
578+
}
579+
}
580+
}
581+
}
582+
583+
-- Create a chat session with tools
584+
local chat = client:new_chat_session({
585+
model = "gpt-4.1",
586+
tools = tools,
587+
tool_choice = "auto"
588+
})
589+
590+
-- Send a message that may trigger a tool call
591+
local res = chat:send("What's the weather in Paris?")
592+
593+
-- When a tool call is requested, res is a table with tool_calls
594+
if type(res) == "table" and res.tool_calls then
595+
for _, tool_call in ipairs(res.tool_calls) do
596+
if tool_call["function"].name == "get_weather" then
597+
local args = cjson.decode(tool_call["function"].arguments)
598+
599+
-- Execute the function and send the result back
600+
local result = get_weather(args.location) -- your implementation
601+
local final = chat:send({
602+
role = "tool",
603+
tool_call_id = tool_call.id,
604+
content = cjson.encode(result)
605+
})
606+
607+
print(final) -- The model's response incorporating the tool result
608+
end
609+
end
610+
end
611+
```
612+
613+
Tool calling also works with streaming — tool call deltas are automatically
614+
aggregated across chunks, and the final response includes the complete
615+
`tool_calls` array.
616+
550617
## Appendix
551618

552619
### Chat Session With Functions

openai/chat_completions.lua

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,45 @@ local content_format = types.string + types.array_of(types.one_of({
2020
})
2121
})
2222
}))
23+
local tool_call_shape = types.partial({
24+
id = empty + types.string,
25+
type = empty + types.string,
26+
["function"] = empty + types.partial({
27+
name = types.string,
28+
arguments = types.string
29+
})
30+
})
31+
local tool_calls_list = types.array_of(tool_call_shape)
2332
local test_message = types.one_of({
2433
types.partial({
2534
role = types.one_of({
2635
"system",
27-
"user",
28-
"assistant"
36+
"user"
2937
}),
3038
content = empty + content_format,
39+
name = empty + types.string
40+
}),
41+
types.partial({
42+
role = "assistant",
43+
content = empty + content_format,
3144
name = empty + types.string,
32-
function_call = empty + types.table
45+
function_call = empty + types.table,
46+
tool_calls = empty + tool_calls_list
3347
}),
3448
types.partial({
3549
role = types.one_of({
3650
"function"
3751
}),
3852
name = types.string,
3953
content = empty + types.string
54+
}),
55+
types.partial({
56+
role = types.one_of({
57+
"tool"
58+
}),
59+
tool_call_id = types.string,
60+
content = empty + types.string,
61+
name = empty + types.string
4062
})
4163
})
4264
local test_function = types.shape({
@@ -57,6 +79,11 @@ local parse_chat_response = types.partial({
5779
arguments = types.string
5880
})
5981
}),
82+
types.partial({
83+
role = "assistant",
84+
content = empty + content_format,
85+
tool_calls = tool_calls_list
86+
}),
6087
types.partial({
6188
role = "assistant",
6289
content = types.string:tag("response")
@@ -76,7 +103,16 @@ local parse_completion_chunk = types.partial({
76103
choices = types.shape({
77104
types.partial({
78105
delta = types.partial({
79-
["content"] = types.string:tag("content")
106+
["content"] = empty + types.string:tag("content"),
107+
tool_calls = (empty + types.array_of(types.partial({
108+
id = empty + types.string,
109+
index = empty + types.number,
110+
type = empty + types.string,
111+
["function"] = empty + types.partial({
112+
name = empty + types.string,
113+
arguments = empty + types.string
114+
})
115+
}))):tag("tool_calls")
80116
}),
81117
index = types.number:tag("index")
82118
})
@@ -123,6 +159,9 @@ do
123159
local status, response = self.client:chat(self.messages, {
124160
function_call = self.opts.function_call,
125161
functions = self.functions,
162+
tools = self.tools,
163+
tool_choice = self.opts.tool_choice,
164+
parallel_tool_calls = self.opts.parallel_tool_calls,
126165
model = self.opts.model,
127166
temperature = self.opts.temperature,
128167
stream = stream_callback and true or nil,
@@ -146,23 +185,84 @@ do
146185
if stream_callback then
147186
assert(type(response) == "string", "Expected string response from streaming output")
148187
local parts = { }
188+
local aggregated_tool_calls = { }
149189
local f = create_stream_filter(function(c)
150190
do
151191
local parsed = parse_completion_chunk(c)
152192
if parsed then
153-
return table.insert(parts, parsed.content)
193+
if parsed.content then
194+
table.insert(parts, parsed.content)
195+
end
196+
if parsed.tool_calls then
197+
local _list_0 = parsed.tool_calls
198+
for _index_0 = 1, #_list_0 do
199+
local tool_delta = _list_0[_index_0]
200+
local tool_index = (tool_delta.index or 0) + 1
201+
local dest = aggregated_tool_calls[tool_index]
202+
if not (dest) then
203+
dest = { }
204+
aggregated_tool_calls[tool_index] = dest
205+
end
206+
if tool_delta.id then
207+
dest.id = tool_delta.id
208+
end
209+
if tool_delta.type then
210+
dest.type = tool_delta.type
211+
end
212+
if tool_delta["function"] then
213+
local _update_0 = "function"
214+
dest[_update_0] = dest[_update_0] or { }
215+
local delta_fn = tool_delta["function"]
216+
if delta_fn.name then
217+
dest["function"].name = delta_fn.name
218+
end
219+
if delta_fn.arguments then
220+
local current_args = dest["function"].arguments or ""
221+
dest["function"].arguments = current_args .. delta_fn.arguments
222+
end
223+
end
224+
end
225+
end
154226
end
155227
end
156228
end)
157229
f(response)
158230
local message = {
159-
role = "assistant",
160-
content = table.concat(parts)
231+
role = "assistant"
161232
}
233+
local combined = table.concat(parts)
234+
if #combined > 0 then
235+
message.content = combined
236+
end
237+
if next(aggregated_tool_calls) then
238+
message.tool_calls = { }
239+
for _index_0 = 1, #aggregated_tool_calls do
240+
local tool = aggregated_tool_calls[_index_0]
241+
if tool then
242+
tool.type = tool.type or "function"
243+
local tool_entry = {
244+
type = tool.type
245+
}
246+
if tool.id then
247+
tool_entry.id = tool.id
248+
end
249+
if tool["function"] then
250+
tool_entry["function"] = { }
251+
if tool["function"].name then
252+
tool_entry["function"].name = tool["function"].name
253+
end
254+
if tool["function"].arguments then
255+
tool_entry["function"].arguments = tool["function"].arguments
256+
end
257+
end
258+
table.insert(message.tool_calls, tool_entry)
259+
end
260+
end
261+
end
162262
if append_response then
163263
self:append_message(message)
164264
end
165-
return message.content
265+
return message.content or message
166266
end
167267
local out, err = parse_chat_response(response)
168268
if not (out) then
@@ -177,6 +277,9 @@ do
177277
if out.message.function_call then
178278
message.function_call = out.message.function_call
179279
end
280+
if out.message.tool_calls then
281+
message.tool_calls = out.message.tool_calls
282+
end
180283
self:append_message(message)
181284
end
182285
return out.response or out.message
@@ -202,6 +305,14 @@ do
202305
table.insert(self.functions, func)
203306
end
204307
end
308+
if type(self.opts.tools) == "table" then
309+
self.tools = { }
310+
local _list_0 = self.opts.tools
311+
for _index_0 = 1, #_list_0 do
312+
local tool = _list_0[_index_0]
313+
table.insert(self.tools, tool)
314+
end
315+
end
205316
end,
206317
__base = _base_0,
207318
__name = "ChatSession"
@@ -219,5 +330,6 @@ end
219330
return {
220331
ChatSession = ChatSession,
221332
test_message = test_message,
222-
parse_completion_chunk = parse_completion_chunk
333+
parse_completion_chunk = parse_completion_chunk,
334+
tool_call_shape = tool_call_shape
223335
}

0 commit comments

Comments
 (0)