@@ -20,23 +20,45 @@ local content_format = types.string + types.array_of(types.one_of({
2020 })
2121 })
2222}))
23+ local tool_call_shape = types .partial ({
24+ id = empty + types .string ,
25+ type = empty + types .string ,
26+ [" function" ] = empty + types .partial ({
27+ name = types .string ,
28+ arguments = types .string
29+ })
30+ })
31+ local tool_calls_list = types .array_of (tool_call_shape )
2332local test_message = types .one_of ({
2433 types .partial ({
2534 role = types .one_of ({
2635 " system" ,
27- " user" ,
28- " assistant"
36+ " user"
2937 }),
3038 content = empty + content_format ,
39+ name = empty + types .string
40+ }),
41+ types .partial ({
42+ role = " assistant" ,
43+ content = empty + content_format ,
3144 name = empty + types .string ,
32- function_call = empty + types .table
45+ function_call = empty + types .table ,
46+ tool_calls = empty + tool_calls_list
3347 }),
3448 types .partial ({
3549 role = types .one_of ({
3650 " function"
3751 }),
3852 name = types .string ,
3953 content = empty + types .string
54+ }),
55+ types .partial ({
56+ role = types .one_of ({
57+ " tool"
58+ }),
59+ tool_call_id = types .string ,
60+ content = empty + types .string ,
61+ name = empty + types .string
4062 })
4163})
4264local test_function = types .shape ({
@@ -57,6 +79,11 @@ local parse_chat_response = types.partial({
5779 arguments = types .string
5880 })
5981 }),
82+ types .partial ({
83+ role = " assistant" ,
84+ content = empty + content_format ,
85+ tool_calls = tool_calls_list
86+ }),
6087 types .partial ({
6188 role = " assistant" ,
6289 content = types .string :tag (" response" )
@@ -76,7 +103,16 @@ local parse_completion_chunk = types.partial({
76103 choices = types .shape ({
77104 types .partial ({
78105 delta = types .partial ({
79- [" content" ] = types .string :tag (" content" )
106+ [" content" ] = empty + types .string :tag (" content" ),
107+ tool_calls = (empty + types .array_of (types .partial ({
108+ id = empty + types .string ,
109+ index = empty + types .number ,
110+ type = empty + types .string ,
111+ [" function" ] = empty + types .partial ({
112+ name = empty + types .string ,
113+ arguments = empty + types .string
114+ })
115+ }))):tag (" tool_calls" )
80116 }),
81117 index = types .number :tag (" index" )
82118 })
123159 local status , response = self .client :chat (self .messages , {
124160 function_call = self .opts .function_call ,
125161 functions = self .functions ,
162+ tools = self .tools ,
163+ tool_choice = self .opts .tool_choice ,
164+ parallel_tool_calls = self .opts .parallel_tool_calls ,
126165 model = self .opts .model ,
127166 temperature = self .opts .temperature ,
128167 stream = stream_callback and true or nil ,
146185 if stream_callback then
147186 assert (type (response ) == " string" , " Expected string response from streaming output" )
148187 local parts = { }
188+ local aggregated_tool_calls = { }
149189 local f = create_stream_filter (function (c )
150190 do
151191 local parsed = parse_completion_chunk (c )
152192 if parsed then
153- return table.insert (parts , parsed .content )
193+ if parsed .content then
194+ table.insert (parts , parsed .content )
195+ end
196+ if parsed .tool_calls then
197+ local _list_0 = parsed .tool_calls
198+ for _index_0 = 1 , # _list_0 do
199+ local tool_delta = _list_0 [_index_0 ]
200+ local tool_index = (tool_delta .index or 0 ) + 1
201+ local dest = aggregated_tool_calls [tool_index ]
202+ if not (dest ) then
203+ dest = { }
204+ aggregated_tool_calls [tool_index ] = dest
205+ end
206+ if tool_delta .id then
207+ dest .id = tool_delta .id
208+ end
209+ if tool_delta .type then
210+ dest .type = tool_delta .type
211+ end
212+ if tool_delta [" function" ] then
213+ local _update_0 = " function"
214+ dest [_update_0 ] = dest [_update_0 ] or { }
215+ local delta_fn = tool_delta [" function" ]
216+ if delta_fn .name then
217+ dest [" function" ].name = delta_fn .name
218+ end
219+ if delta_fn .arguments then
220+ local current_args = dest [" function" ].arguments or " "
221+ dest [" function" ].arguments = current_args .. delta_fn .arguments
222+ end
223+ end
224+ end
225+ end
154226 end
155227 end
156228 end )
157229 f (response )
158230 local message = {
159- role = " assistant" ,
160- content = table.concat (parts )
231+ role = " assistant"
161232 }
233+ local combined = table.concat (parts )
234+ if # combined > 0 then
235+ message .content = combined
236+ end
237+ if next (aggregated_tool_calls ) then
238+ message .tool_calls = { }
239+ for _index_0 = 1 , # aggregated_tool_calls do
240+ local tool = aggregated_tool_calls [_index_0 ]
241+ if tool then
242+ tool .type = tool .type or " function"
243+ local tool_entry = {
244+ type = tool .type
245+ }
246+ if tool .id then
247+ tool_entry .id = tool .id
248+ end
249+ if tool [" function" ] then
250+ tool_entry [" function" ] = { }
251+ if tool [" function" ].name then
252+ tool_entry [" function" ].name = tool [" function" ].name
253+ end
254+ if tool [" function" ].arguments then
255+ tool_entry [" function" ].arguments = tool [" function" ].arguments
256+ end
257+ end
258+ table.insert (message .tool_calls , tool_entry )
259+ end
260+ end
261+ end
162262 if append_response then
163263 self :append_message (message )
164264 end
165- return message .content
265+ return message .content or message
166266 end
167267 local out , err = parse_chat_response (response )
168268 if not (out ) then
177277 if out .message .function_call then
178278 message .function_call = out .message .function_call
179279 end
280+ if out .message .tool_calls then
281+ message .tool_calls = out .message .tool_calls
282+ end
180283 self :append_message (message )
181284 end
182285 return out .response or out .message
202305 table.insert (self .functions , func )
203306 end
204307 end
308+ if type (self .opts .tools ) == " table" then
309+ self .tools = { }
310+ local _list_0 = self .opts .tools
311+ for _index_0 = 1 , # _list_0 do
312+ local tool = _list_0 [_index_0 ]
313+ table.insert (self .tools , tool )
314+ end
315+ end
205316 end ,
206317 __base = _base_0 ,
207318 __name = " ChatSession"
219330return {
220331 ChatSession = ChatSession ,
221332 test_message = test_message ,
222- parse_completion_chunk = parse_completion_chunk
333+ parse_completion_chunk = parse_completion_chunk ,
334+ tool_call_shape = tool_call_shape
223335}
0 commit comments