Skip to content

Commit d349fd5

Browse files
committed
chore: enhance tool call handling and fragment aggregation in tools.py
1 parent 63880e3 commit d349fd5

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

apps/application/flow/tools.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -290,26 +290,60 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
290290
agent = create_react_agent(chat_model, tools)
291291
response = agent.astream({"messages": message_list}, stream_mode='messages')
292292

293-
# 用于存储工具调用信息
293+
# 用于存储工具调用信息(按 tool_id)以及按 index 聚合分片
294294
tool_calls_info = {}
295+
_tool_fragments = {} # index -> {'id':..., 'name':..., 'arguments':...}
295296

296297
async for chunk in response:
297298
if isinstance(chunk[0], AIMessageChunk):
298299
tool_calls = chunk[0].additional_kwargs.get('tool_calls', [])
299300
for tool_call in tool_calls:
300-
tool_id = tool_call.get('id', '')
301-
if tool_id:
302-
# 保存工具调用的输入
303-
tool_calls_info[tool_id] = {
304-
'name': tool_call.get('function', {}).get('name', ''),
305-
'input': tool_call.get('function', {}).get('arguments', '')
306-
}
301+
idx = tool_call.get('index')
302+
if idx is None:
303+
continue
304+
entry = _tool_fragments.setdefault(idx, {'id': '', 'name': '', 'arguments': ''})
305+
306+
# 更新 id 与 name(如果有)
307+
if tool_call.get('id'):
308+
entry['id'] = tool_call.get('id')
309+
310+
func = tool_call.get('function', {})
311+
# arguments 可能在 function.arguments 或顶层 arguments
312+
part_args = ''
313+
if isinstance(func, dict) and 'arguments' in func:
314+
part_args = func.get('arguments') or ''
315+
if func.get('name'):
316+
entry['name'] = func.get('name')
317+
else:
318+
part_args = tool_call.get('arguments', '') or ''
319+
320+
# 统一为字符串
321+
if not isinstance(part_args, str):
322+
try:
323+
part_args = json.dumps(part_args, ensure_ascii=False)
324+
except Exception:
325+
part_args = str(part_args)
326+
327+
entry['arguments'] += part_args
328+
329+
# 尝试判断 JSON 是否完整(若 arguments 是 JSON),完整则提交到 tool_calls_info
330+
try:
331+
json.loads(entry['arguments'])
332+
if entry['id']:
333+
tool_calls_info[entry['id']] = {
334+
'name': entry.get('name', ''),
335+
'input': entry['arguments']
336+
}
337+
_tool_fragments.pop(idx, None)
338+
except Exception:
339+
# 如果不是完整 JSON,继续等待后续片段
340+
pass
341+
307342
yield chunk[0]
308343

309344
if mcp_output_enable and isinstance(chunk[0], ToolMessage):
310345
tool_id = chunk[0].tool_call_id
311346
if tool_id in tool_calls_info:
312-
# 合并输入和输出
313347
tool_info = tool_calls_info[tool_id]
314348
content = generate_tool_message_complete(
315349
tool_info['name'],
@@ -335,6 +369,7 @@ def get_real_error(exc):
335369
raise RuntimeError(error_msg) from None
336370

337371

372+
338373
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
339374
"""使用全局事件循环,不创建新实例"""
340375
result_queue = queue.Queue()

0 commit comments

Comments
 (0)