@@ -10,6 +10,7 @@ use eventsource_stream::Eventsource;
1010use futures:: Stream ;
1111use futures:: StreamExt ;
1212use std:: collections:: HashMap ;
13+ use std:: collections:: HashSet ;
1314use std:: time:: Duration ;
1415use tokio:: sync:: mpsc;
1516use tokio:: time:: Instant ;
@@ -41,12 +42,17 @@ pub async fn process_chat_sse<S>(
4142
4243 #[ derive( Default , Debug ) ]
4344 struct ToolCallState {
45+ id : Option < String > ,
4446 name : Option < String > ,
4547 arguments : String ,
4648 }
4749
48- let mut tool_calls: HashMap < String , ToolCallState > = HashMap :: new ( ) ;
49- let mut tool_call_order: Vec < String > = Vec :: new ( ) ;
50+ let mut tool_calls: HashMap < usize , ToolCallState > = HashMap :: new ( ) ;
51+ let mut tool_call_order: Vec < usize > = Vec :: new ( ) ;
52+ let mut tool_call_order_seen: HashSet < usize > = HashSet :: new ( ) ;
53+ let mut tool_call_index_by_id: HashMap < String , usize > = HashMap :: new ( ) ;
54+ let mut next_tool_call_index = 0usize ;
55+ let mut last_tool_call_index: Option < usize > = None ;
5056 let mut assistant_item: Option < ResponseItem > = None ;
5157 let mut reasoning_item: Option < ResponseItem > = None ;
5258 let mut completed_sent = false ;
@@ -149,15 +155,40 @@ pub async fn process_chat_sse<S>(
149155
150156 if let Some ( tool_call_values) = delta. get ( "tool_calls" ) . and_then ( |c| c. as_array ( ) ) {
151157 for tool_call in tool_call_values {
152- let id = tool_call
153- . get ( "id" )
154- . and_then ( |i| i. as_str ( ) )
155- . map ( str:: to_string)
156- . unwrap_or_else ( || format ! ( "tool-call-{}" , tool_call_order. len( ) ) ) ;
157-
158- let call_state = tool_calls. entry ( id. clone ( ) ) . or_default ( ) ;
159- if !tool_call_order. contains ( & id) {
160- tool_call_order. push ( id. clone ( ) ) ;
158+ let mut index = tool_call
159+ . get ( "index" )
160+ . and_then ( serde_json:: Value :: as_u64)
161+ . map ( |i| i as usize ) ;
162+
163+ let mut call_id_for_lookup = None ;
164+ if let Some ( call_id) = tool_call. get ( "id" ) . and_then ( |i| i. as_str ( ) ) {
165+ call_id_for_lookup = Some ( call_id. to_string ( ) ) ;
166+ if let Some ( existing) = tool_call_index_by_id. get ( call_id) {
167+ index = Some ( * existing) ;
168+ }
169+ }
170+
171+ if index. is_none ( ) && call_id_for_lookup. is_none ( ) {
172+ index = last_tool_call_index;
173+ }
174+
175+ let index = index. unwrap_or_else ( || {
176+ while tool_calls. contains_key ( & next_tool_call_index) {
177+ next_tool_call_index += 1 ;
178+ }
179+ let idx = next_tool_call_index;
180+ next_tool_call_index += 1 ;
181+ idx
182+ } ) ;
183+
184+ let call_state = tool_calls. entry ( index) . or_default ( ) ;
185+ if tool_call_order_seen. insert ( index) {
186+ tool_call_order. push ( index) ;
187+ }
188+
189+ if let Some ( id) = tool_call. get ( "id" ) . and_then ( |i| i. as_str ( ) ) {
190+ call_state. id . get_or_insert_with ( || id. to_string ( ) ) ;
191+ tool_call_index_by_id. entry ( id. to_string ( ) ) . or_insert ( index) ;
161192 }
162193
163194 if let Some ( func) = tool_call. get ( "function" ) {
@@ -171,6 +202,8 @@ pub async fn process_chat_sse<S>(
171202 call_state. arguments . push_str ( arguments) ;
172203 }
173204 }
205+
206+ last_tool_call_index = Some ( index) ;
174207 }
175208 }
176209 }
@@ -224,13 +257,25 @@ pub async fn process_chat_sse<S>(
224257 . await ;
225258 }
226259
227- for call_id in tool_call_order. drain ( ..) {
228- let state = tool_calls. remove ( & call_id) . unwrap_or_default ( ) ;
260+ for index in tool_call_order. drain ( ..) {
261+ let Some ( state) = tool_calls. remove ( & index) else {
262+ continue ;
263+ } ;
264+ tool_call_order_seen. remove ( & index) ;
265+ let ToolCallState {
266+ id,
267+ name,
268+ arguments,
269+ } = state;
270+ let Some ( name) = name else {
271+ debug ! ( "Skipping tool call at index {index} because name is missing" ) ;
272+ continue ;
273+ } ;
229274 let item = ResponseItem :: FunctionCall {
230275 id : None ,
231- name : state . name . unwrap_or_default ( ) ,
232- arguments : state . arguments ,
233- call_id : call_id . clone ( ) ,
276+ name,
277+ arguments,
278+ call_id : id . unwrap_or_else ( || format ! ( "tool-call-{index}" ) ) ,
234279 } ;
235280 let _ = tx_event. send ( Ok ( ResponseEvent :: OutputItemDone ( item) ) ) . await ;
236281 }
@@ -335,6 +380,59 @@ mod tests {
335380 out
336381 }
337382
383+ #[ tokio:: test]
384+ async fn concatenates_tool_call_arguments_across_deltas ( ) {
385+ let delta_name = json ! ( {
386+ "choices" : [ {
387+ "delta" : {
388+ "tool_calls" : [ {
389+ "id" : "call_a" ,
390+ "index" : 0 ,
391+ "function" : { "name" : "do_a" }
392+ } ]
393+ }
394+ } ]
395+ } ) ;
396+
397+ let delta_args_1 = json ! ( {
398+ "choices" : [ {
399+ "delta" : {
400+ "tool_calls" : [ {
401+ "index" : 0 ,
402+ "function" : { "arguments" : "{ \" foo\" :" }
403+ } ]
404+ }
405+ } ]
406+ } ) ;
407+
408+ let delta_args_2 = json ! ( {
409+ "choices" : [ {
410+ "delta" : {
411+ "tool_calls" : [ {
412+ "index" : 0 ,
413+ "function" : { "arguments" : "1}" }
414+ } ]
415+ }
416+ } ]
417+ } ) ;
418+
419+ let finish = json ! ( {
420+ "choices" : [ {
421+ "finish_reason" : "tool_calls"
422+ } ]
423+ } ) ;
424+
425+ let body = build_body ( & [ delta_name, delta_args_1, delta_args_2, finish] ) ;
426+ let events = collect_events ( & body) . await ;
427+ assert_matches ! (
428+ & events[ ..] ,
429+ [
430+ ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id, name, arguments, .. } ) ,
431+ ResponseEvent :: Completed { .. }
432+ ] if call_id == "call_a" && name == "do_a" && arguments == "{ \" foo\" :1}"
433+ ) ;
434+ }
435+
338436 #[ tokio:: test]
339437 async fn emits_multiple_tool_calls ( ) {
340438 let delta_a = json ! ( {
@@ -367,50 +465,74 @@ mod tests {
367465
368466 let body = build_body ( & [ delta_a, delta_b, finish] ) ;
369467 let events = collect_events ( & body) . await ;
370- assert_eq ! ( events. len( ) , 3 ) ;
371-
372468 assert_matches ! (
373- & events[ 0 ] ,
374- ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id, name, arguments, .. } )
375- if call_id == "call_a" && name == "do_a" && arguments == "{\" foo\" :1}"
376- ) ;
377- assert_matches ! (
378- & events[ 1 ] ,
379- ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id, name, arguments, .. } )
380- if call_id == "call_b" && name == "do_b" && arguments == "{\" bar\" :2}"
469+ & events[ ..] ,
470+ [
471+ ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. } ) ,
472+ ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. } ) ,
473+ ResponseEvent :: Completed { .. }
474+ ] if call_a == "call_a" && name_a == "do_a" && args_a == "{\" foo\" :1}" && call_b == "call_b" && name_b == "do_b" && args_b == "{\" bar\" :2}"
381475 ) ;
382- assert_matches ! ( events[ 2 ] , ResponseEvent :: Completed { .. } ) ;
383476 }
384477
385478 #[ tokio:: test]
386- async fn concatenates_tool_call_arguments_across_deltas ( ) {
387- let delta_name = json ! ( {
388- "choices" : [ {
389- "delta" : {
390- "tool_calls" : [ {
391- "id" : "call_a" ,
392- "function" : { "name" : "do_a" }
393- } ]
479+ async fn emits_tool_calls_for_multiple_choices ( ) {
480+ let payload = json ! ( {
481+ "choices" : [
482+ {
483+ "delta" : {
484+ "tool_calls" : [ {
485+ "id" : "call_a" ,
486+ "index" : 0 ,
487+ "function" : { "name" : "do_a" , "arguments" : "{}" }
488+ } ]
489+ } ,
490+ "finish_reason" : "tool_calls"
491+ } ,
492+ {
493+ "delta" : {
494+ "tool_calls" : [ {
495+ "id" : "call_b" ,
496+ "index" : 0 ,
497+ "function" : { "name" : "do_b" , "arguments" : "{}" }
498+ } ]
499+ } ,
500+ "finish_reason" : "tool_calls"
394501 }
395- } ]
502+ ]
396503 } ) ;
397504
398- let delta_args_1 = json ! ( {
505+ let body = build_body ( & [ payload] ) ;
506+ let events = collect_events ( & body) . await ;
507+ assert_matches ! (
508+ & events[ ..] ,
509+ [
510+ ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. } ) ,
511+ ResponseEvent :: OutputItemDone ( ResponseItem :: FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. } ) ,
512+ ResponseEvent :: Completed { .. }
513+ ] if call_a == "call_a" && name_a == "do_a" && args_a == "{}" && call_b == "call_b" && name_b == "do_b" && args_b == "{}"
514+ ) ;
515+ }
516+
517+ #[ tokio:: test]
518+ async fn merges_tool_calls_by_index_when_id_missing_on_subsequent_deltas ( ) {
519+ let delta_with_id = json ! ( {
399520 "choices" : [ {
400521 "delta" : {
401522 "tool_calls" : [ {
523+ "index" : 0 ,
402524 "id" : "call_a" ,
403- "function" : { "arguments" : "{ \" foo\" :" }
525+ "function" : { "name" : "do_a" , " arguments": "{ \" foo\" :" }
404526 } ]
405527 }
406528 } ]
407529 } ) ;
408530
409- let delta_args_2 = json ! ( {
531+ let delta_without_id = json ! ( {
410532 "choices" : [ {
411533 "delta" : {
412534 "tool_calls" : [ {
413- "id " : "call_a" ,
535+ "index " : 0 ,
414536 "function" : { "arguments" : "1}" }
415537 } ]
416538 }
@@ -423,7 +545,7 @@ mod tests {
423545 } ]
424546 } ) ;
425547
426- let body = build_body ( & [ delta_name , delta_args_1 , delta_args_2 , finish] ) ;
548+ let body = build_body ( & [ delta_with_id , delta_without_id , finish] ) ;
427549 let events = collect_events ( & body) . await ;
428550 assert_matches ! (
429551 & events[ ..] ,
0 commit comments