Skip to content

Commit 5f80ad6

Browse files
authored
fix: chat completion with parallel tool call (#7634)
1 parent e91bb6b commit 5f80ad6

File tree

1 file changed

+163
-41
lines changed

1 file changed

+163
-41
lines changed

codex-rs/codex-api/src/sse/chat.rs

Lines changed: 163 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use eventsource_stream::Eventsource;
1010
use futures::Stream;
1111
use futures::StreamExt;
1212
use std::collections::HashMap;
13+
use std::collections::HashSet;
1314
use std::time::Duration;
1415
use tokio::sync::mpsc;
1516
use tokio::time::Instant;
@@ -41,12 +42,17 @@ pub async fn process_chat_sse<S>(
4142

4243
#[derive(Default, Debug)]
4344
struct ToolCallState {
45+
id: Option<String>,
4446
name: Option<String>,
4547
arguments: String,
4648
}
4749

48-
let mut tool_calls: HashMap<String, ToolCallState> = HashMap::new();
49-
let mut tool_call_order: Vec<String> = Vec::new();
50+
let mut tool_calls: HashMap<usize, ToolCallState> = HashMap::new();
51+
let mut tool_call_order: Vec<usize> = Vec::new();
52+
let mut tool_call_order_seen: HashSet<usize> = HashSet::new();
53+
let mut tool_call_index_by_id: HashMap<String, usize> = HashMap::new();
54+
let mut next_tool_call_index = 0usize;
55+
let mut last_tool_call_index: Option<usize> = None;
5056
let mut assistant_item: Option<ResponseItem> = None;
5157
let mut reasoning_item: Option<ResponseItem> = None;
5258
let mut completed_sent = false;
@@ -149,15 +155,40 @@ pub async fn process_chat_sse<S>(
149155

150156
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
151157
for tool_call in tool_call_values {
152-
let id = tool_call
153-
.get("id")
154-
.and_then(|i| i.as_str())
155-
.map(str::to_string)
156-
.unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len()));
157-
158-
let call_state = tool_calls.entry(id.clone()).or_default();
159-
if !tool_call_order.contains(&id) {
160-
tool_call_order.push(id.clone());
158+
let mut index = tool_call
159+
.get("index")
160+
.and_then(serde_json::Value::as_u64)
161+
.map(|i| i as usize);
162+
163+
let mut call_id_for_lookup = None;
164+
if let Some(call_id) = tool_call.get("id").and_then(|i| i.as_str()) {
165+
call_id_for_lookup = Some(call_id.to_string());
166+
if let Some(existing) = tool_call_index_by_id.get(call_id) {
167+
index = Some(*existing);
168+
}
169+
}
170+
171+
if index.is_none() && call_id_for_lookup.is_none() {
172+
index = last_tool_call_index;
173+
}
174+
175+
let index = index.unwrap_or_else(|| {
176+
while tool_calls.contains_key(&next_tool_call_index) {
177+
next_tool_call_index += 1;
178+
}
179+
let idx = next_tool_call_index;
180+
next_tool_call_index += 1;
181+
idx
182+
});
183+
184+
let call_state = tool_calls.entry(index).or_default();
185+
if tool_call_order_seen.insert(index) {
186+
tool_call_order.push(index);
187+
}
188+
189+
if let Some(id) = tool_call.get("id").and_then(|i| i.as_str()) {
190+
call_state.id.get_or_insert_with(|| id.to_string());
191+
tool_call_index_by_id.entry(id.to_string()).or_insert(index);
161192
}
162193

163194
if let Some(func) = tool_call.get("function") {
@@ -171,6 +202,8 @@ pub async fn process_chat_sse<S>(
171202
call_state.arguments.push_str(arguments);
172203
}
173204
}
205+
206+
last_tool_call_index = Some(index);
174207
}
175208
}
176209
}
@@ -224,13 +257,25 @@ pub async fn process_chat_sse<S>(
224257
.await;
225258
}
226259

227-
for call_id in tool_call_order.drain(..) {
228-
let state = tool_calls.remove(&call_id).unwrap_or_default();
260+
for index in tool_call_order.drain(..) {
261+
let Some(state) = tool_calls.remove(&index) else {
262+
continue;
263+
};
264+
tool_call_order_seen.remove(&index);
265+
let ToolCallState {
266+
id,
267+
name,
268+
arguments,
269+
} = state;
270+
let Some(name) = name else {
271+
debug!("Skipping tool call at index {index} because name is missing");
272+
continue;
273+
};
229274
let item = ResponseItem::FunctionCall {
230275
id: None,
231-
name: state.name.unwrap_or_default(),
232-
arguments: state.arguments,
233-
call_id: call_id.clone(),
276+
name,
277+
arguments,
278+
call_id: id.unwrap_or_else(|| format!("tool-call-{index}")),
234279
};
235280
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
236281
}
@@ -335,6 +380,59 @@ mod tests {
335380
out
336381
}
337382

383+
#[tokio::test]
384+
async fn concatenates_tool_call_arguments_across_deltas() {
385+
let delta_name = json!({
386+
"choices": [{
387+
"delta": {
388+
"tool_calls": [{
389+
"id": "call_a",
390+
"index": 0,
391+
"function": { "name": "do_a" }
392+
}]
393+
}
394+
}]
395+
});
396+
397+
let delta_args_1 = json!({
398+
"choices": [{
399+
"delta": {
400+
"tool_calls": [{
401+
"index": 0,
402+
"function": { "arguments": "{ \"foo\":" }
403+
}]
404+
}
405+
}]
406+
});
407+
408+
let delta_args_2 = json!({
409+
"choices": [{
410+
"delta": {
411+
"tool_calls": [{
412+
"index": 0,
413+
"function": { "arguments": "1}" }
414+
}]
415+
}
416+
}]
417+
});
418+
419+
let finish = json!({
420+
"choices": [{
421+
"finish_reason": "tool_calls"
422+
}]
423+
});
424+
425+
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
426+
let events = collect_events(&body).await;
427+
assert_matches!(
428+
&events[..],
429+
[
430+
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
431+
ResponseEvent::Completed { .. }
432+
] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
433+
);
434+
}
435+
338436
#[tokio::test]
339437
async fn emits_multiple_tool_calls() {
340438
let delta_a = json!({
@@ -367,50 +465,74 @@ mod tests {
367465

368466
let body = build_body(&[delta_a, delta_b, finish]);
369467
let events = collect_events(&body).await;
370-
assert_eq!(events.len(), 3);
371-
372468
assert_matches!(
373-
&events[0],
374-
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
375-
if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}"
376-
);
377-
assert_matches!(
378-
&events[1],
379-
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
380-
if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}"
469+
&events[..],
470+
[
471+
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }),
472+
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }),
473+
ResponseEvent::Completed { .. }
474+
] if call_a == "call_a" && name_a == "do_a" && args_a == "{\"foo\":1}" && call_b == "call_b" && name_b == "do_b" && args_b == "{\"bar\":2}"
381475
);
382-
assert_matches!(events[2], ResponseEvent::Completed { .. });
383476
}
384477

385478
#[tokio::test]
386-
async fn concatenates_tool_call_arguments_across_deltas() {
387-
let delta_name = json!({
388-
"choices": [{
389-
"delta": {
390-
"tool_calls": [{
391-
"id": "call_a",
392-
"function": { "name": "do_a" }
393-
}]
479+
async fn emits_tool_calls_for_multiple_choices() {
480+
let payload = json!({
481+
"choices": [
482+
{
483+
"delta": {
484+
"tool_calls": [{
485+
"id": "call_a",
486+
"index": 0,
487+
"function": { "name": "do_a", "arguments": "{}" }
488+
}]
489+
},
490+
"finish_reason": "tool_calls"
491+
},
492+
{
493+
"delta": {
494+
"tool_calls": [{
495+
"id": "call_b",
496+
"index": 0,
497+
"function": { "name": "do_b", "arguments": "{}" }
498+
}]
499+
},
500+
"finish_reason": "tool_calls"
394501
}
395-
}]
502+
]
396503
});
397504

398-
let delta_args_1 = json!({
505+
let body = build_body(&[payload]);
506+
let events = collect_events(&body).await;
507+
assert_matches!(
508+
&events[..],
509+
[
510+
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }),
511+
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }),
512+
ResponseEvent::Completed { .. }
513+
] if call_a == "call_a" && name_a == "do_a" && args_a == "{}" && call_b == "call_b" && name_b == "do_b" && args_b == "{}"
514+
);
515+
}
516+
517+
#[tokio::test]
518+
async fn merges_tool_calls_by_index_when_id_missing_on_subsequent_deltas() {
519+
let delta_with_id = json!({
399520
"choices": [{
400521
"delta": {
401522
"tool_calls": [{
523+
"index": 0,
402524
"id": "call_a",
403-
"function": { "arguments": "{ \"foo\":" }
525+
"function": { "name": "do_a", "arguments": "{ \"foo\":" }
404526
}]
405527
}
406528
}]
407529
});
408530

409-
let delta_args_2 = json!({
531+
let delta_without_id = json!({
410532
"choices": [{
411533
"delta": {
412534
"tool_calls": [{
413-
"id": "call_a",
535+
"index": 0,
414536
"function": { "arguments": "1}" }
415537
}]
416538
}
@@ -423,7 +545,7 @@ mod tests {
423545
}]
424546
});
425547

426-
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
548+
let body = build_body(&[delta_with_id, delta_without_id, finish]);
427549
let events = collect_events(&body).await;
428550
assert_matches!(
429551
&events[..],

0 commit comments

Comments
 (0)