Commit 2aac521

Disable parallel tool calls final answer (#1539)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
1 parent 464612a commit 2aac521

2 files changed: 92 additions & 25 deletions

src/smolagents/agents.py

Lines changed: 5 additions & 0 deletions
@@ -1301,6 +1301,11 @@ def _step_stream(
                 yield output
                 if isinstance(output, ToolOutput):
                     if output.is_final_answer:
+                        if len(chat_message.tool_calls) > 1:
+                            raise AgentExecutionError(
+                                "If you want to return an answer, please do not perform any other tool calls than the final answer tool call!",
+                                self.logger,
+                            )
                         if got_final_answer:
                             raise AgentToolExecutionError(
                                 "You returned multiple final answers. Please return only one single final answer!",
tests/test_agents.py

Lines changed: 87 additions & 25 deletions
@@ -439,9 +439,7 @@ def fake_image_generation_tool(prompt: str) -> PIL.Image.Image:
 
             return PIL.Image.open(shared_datadir / "000000039769.png")
 
-        agent = ToolCallingAgent(
-            tools=[fake_image_generation_tool], model=FakeToolCallModelImage(), verbosity_level=10
-        )
+        agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage())
         output = agent.run("Make me an image.")
         assert isinstance(output, AgentImage)
         assert isinstance(agent.state["image.png"], PIL.Image.Image)
@@ -567,7 +565,6 @@ def test_function_persistence_across_steps(self):
             model=FakeCodeModelFunctionDef(),
             max_steps=2,
             additional_authorized_imports=["numpy"],
-            verbosity_level=100,
         )
         res = agent.run("ok")
         assert res[0] == 0.5
@@ -652,7 +649,7 @@ def weather_api(location: str, celsius: str = "") -> str:
             device_map="auto",
             do_sample=False,
         )
-        agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1, verbosity_level=10)
+        agent = ToolCallingAgent(model=model, tools=[weather_api], max_steps=1)
         task = "What is the weather in Paris? "
         agent.run(task)
         assert agent.memory.steps[0].task == task
@@ -679,7 +676,6 @@ def check_always_fails(final_answer, agent_memory):
             model=FakeCodeModel(),
             tools=[],
             final_answer_checks=[lambda x, y: x == 7.2904],
-            verbosity_level=1000,
         )
         output = agent.run("Dummy task.")
         assert output == 7.2904  # Check that output is correct
@@ -1526,29 +1522,29 @@ def weather_api(location: str, date: str) -> str:
         assert agent.memory.steps[1].observations == "The weather in Paris on date:today is sunny."
 
     @patch("openai.OpenAI")
-    def test_toolcalling_agent_stream_outputs_multiple_tool_calls(self, mock_openai_client, test_tool):
-        """Test that ToolCallingAgent with stream_outputs=True returns the first final_answer when multiple are called."""
+    def test_toolcalling_agent_stream_logs_multiple_tool_calls_observations(self, mock_openai_client, test_tool):
+        """Test that ToolCallingAgent with stream_outputs=True logs the observations of all tool calls when multiple are called."""
         mock_client = mock_openai_client.return_value
         from smolagents import OpenAIServerModel
 
-        # Mock streaming response with multiple final_answer calls
+        # Mock streaming response with multiple tool calls
         mock_deltas = [
             ChoiceDelta(role=MessageRole.ASSISTANT),
             ChoiceDelta(
                 tool_calls=[
                     ChoiceDeltaToolCall(
                         index=0,
                         id="call_1",
-                        function=ChoiceDeltaToolCallFunction(name="final_answer"),
+                        function=ChoiceDeltaToolCallFunction(name="test_tool"),
                         type="function",
                     )
                 ]
             ),
             ChoiceDelta(
-                tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"an'))]
+                tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"in'))]
             ),
             ChoiceDelta(
-                tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='swer"'))]
+                tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='put"'))]
             ),
             ChoiceDelta(
                 tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments=': "out'))]
@@ -1605,13 +1601,85 @@ def __init__(self, delta):
         model = OpenAIServerModel(model_id="fakemodel")
 
         agent = ToolCallingAgent(model=model, tools=[test_tool], max_steps=1, stream_outputs=True)
-        result = agent.run("Make 2 calls to final answer: return both 'output1' and 'output2'")
-        assert len(agent.memory.steps[-1].model_output_message.tool_calls) == 2
-        assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
-        assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
+        agent.run("Dummy task")
+        assert agent.memory.steps[1].model_output_message.tool_calls[0].function.name == "test_tool"
+        assert agent.memory.steps[1].model_output_message.tool_calls[1].function.name == "test_tool"
+        assert agent.memory.steps[1].observations == "Processed: output1\nProcessed: output2"
+
+    @patch("openai.OpenAI")
+    def test_toolcalling_agent_final_answer_cannot_be_called_with_parallel_tool_calls(
+        self, mock_openai_client, test_tool
+    ):
"""Test that ToolCallingAgent with stream_outputs=True returns the all tool calls when multiple are called."""
+        mock_client = mock_openai_client.return_value
+
+        from smolagents import OpenAIServerModel
 
-        # The agent should return the final answer call
-        assert result == "output1"
+        class ExtendedChatMessage(ChatMessage):
+            def __init__(self, *args, usage, **kwargs):
+                super().__init__(*args, **kwargs)
+
+            def model_dump(self, include=None):
+                return super().model_dump_json()
+
+        class MockChoice:
+            def __init__(self, chat_message):
+                self.message = chat_message
+
+        class MockChatCompletion:
+            def __init__(self, chat_message):
+                self.choices = [MockChoice(chat_message)]
+                self.usage = MockTokenUsage(prompt_tokens=10, completion_tokens=20)
+
+        class MockTokenUsage:
+            def __init__(self, prompt_tokens, completion_tokens):
+                self.prompt_tokens = prompt_tokens
+                self.completion_tokens = completion_tokens
+
+        from dataclasses import asdict
+
+        class ExtendedChatCompletionOutputMessage(ChatCompletionOutputMessage):
+            def __init__(self, *args, usage, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.usage = usage
+
+            def model_dump(self, include=None):
+                print("TOOL CALLS", self.tool_calls)
+                return {
+                    "role": self.role,
+                    "content": self.content,
+                    "tool_calls": [asdict(tc) for tc in self.tool_calls],
+                }
+
+        mock_client.chat.completions.create.return_value = MockChatCompletion(
+            ExtendedChatCompletionOutputMessage(
+                role=MessageRole.ASSISTANT,
+                content=None,
+                tool_calls=[
+                    ChatMessageToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "out1"}),
+                    ),
+                    ChatMessageToolCall(
+                        id="1",
+                        type="function",
+                        function=ChatMessageToolCallFunction(name="final_answer", arguments={"answer": "out1"}),
+                    ),
+                ],
+                usage=MockTokenUsage(prompt_tokens=10, completion_tokens=20),
+            )
+        )
+
+        model = OpenAIServerModel(model_id="fakemodel")
+
+        agent = ToolCallingAgent(model=model, tools=[test_tool], max_steps=1)
+        agent.run("Dummy task")
+        assert agent.memory.steps[1].error is not None
+        assert (
+            "do not perform any other tool calls than the final answer tool call!"
+            in agent.memory.steps[1].error.message
+        )
 
     @patch("huggingface_hub.InferenceClient")
     def test_toolcalling_agent_api_misformatted_output(self, mock_inference_client):
@@ -1690,18 +1758,12 @@ def forward(self, answer1: str, answer2: str) -> str:
                         name="final_answer", arguments={"answer1": "1", "answer2": "2"}
                     ),
                 ),
-                ChatMessageToolCall(
-                    id="call_1",
-                    type="function",
-                    function=ChatMessageToolCallFunction(name="test_tool", arguments={"input": "3"}),
-                ),
             ],
         )
         agent = ToolCallingAgent(tools=[test_tool, CustomFinalAnswerToolWithCustomInputs()], model=model)
         answer = agent.run("Fake task.")
         assert answer == "1 and 2"
         assert agent.memory.steps[-1].model_output_message.tool_calls[0].function.name == "final_answer"
-        assert agent.memory.steps[-1].model_output_message.tool_calls[1].function.name == "test_tool"
 
     @pytest.mark.parametrize(
         "test_case",
@@ -1932,7 +1994,7 @@ def test_errors_show_offending_line_and_error(self):
         assert "ValueError" in str(agent.memory.steps)
 
     def test_error_saves_previous_print_outputs(self):
-        agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError(), verbosity_level=10)
+        agent = CodeAgent(tools=[PythonInterpreterTool()], model=FakeCodeModelError())
         agent.run("What is 2 multiplied by 3.6452?")
         assert "Flag!" in str(agent.memory.steps[1].observations)
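The renamed streaming test also pins down a companion behavior: parallel calls to ordinary tools still all execute, and their observations are logged together, one line per call. A minimal plain-Python sketch of that expectation (assuming, as the asserted output suggests, that the test_tool fixture returns "Processed: <input>"; the helpers below are illustrative, not smolagents API):

# Illustrative sketch of the observation logging asserted by
# test_toolcalling_agent_stream_logs_multiple_tool_calls_observations;
# not smolagents internals.
def fake_test_tool(input: str) -> str:
    # Assumed fixture behavior, inferred from the asserted observations.
    return f"Processed: {input}"

def observations_for(tool_calls: list[tuple[str, str]]) -> str:
    # Every parallel tool call runs and its result is logged on its own line.
    return "\n".join(fake_test_tool(arg) for _, arg in tool_calls)

assert observations_for([("test_tool", "output1"), ("test_tool", "output2")]) == (
    "Processed: output1\nProcessed: output2"
)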