Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 129 additions & 8 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
import asyncio


if TYPE_CHECKING:
Expand Down Expand Up @@ -1285,27 +1286,38 @@ def _step_stream(
)
yield event
chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
memory_step.model_output_message = chat_message
output_text = chat_message.content
else:
chat_message: ChatMessage = self.model.generate(
input_messages,
stop_sequences=["Observation:", "Calling tools:"],
tools_to_call_from=self.tools_and_managed_agents,
)
memory_step.model_output_message = chat_message

# Handle cases where chat_message.content might be None - use improved handling from main branch
if chat_message.content is None and chat_message.raw is not None:
log_content = str(chat_message.raw)
output_text = str(chat_message.raw)
else:
log_content = str(chat_message.content) or ""
output_text = chat_message.content

self.logger.log_markdown(
content=log_content,
title="Output message of the LLM:",
level=LogLevel.DEBUG,
)

# Record model output
memory_step.model_output_message = chat_message
memory_step.model_output = chat_message.content
# This adds <end_code> sequence to the history.
# This will nudge ulterior LLM calls to finish with <end_code>, thus efficiently stopping generation.
if output_text and output_text.strip().endswith("```"):
output_text += "<end_code>"
memory_step.model_output_message.content = output_text

memory_step.token_usage = chat_message.token_usage
memory_step.model_output = output_text
except Exception as e:
raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e

Expand Down Expand Up @@ -1456,20 +1468,129 @@ def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) ->
arguments = self._substitute_state_variables(arguments)
is_managed_agent = tool_name in self.managed_agents

# Validate arguments first using main branch approach
try:
validate_tool_arguments(tool, arguments)
except (ValueError, TypeError) as e:
raise AgentToolCallError(str(e), self.logger) from e

try:
# Check if tool is async and handle appropriately
if hasattr(tool, 'is_async') and tool.is_async():
# Handle async tool
if isinstance(arguments, dict):
if is_managed_agent:
coro = tool(**arguments)
else:
coro = tool.acall(**arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, str):
if is_managed_agent:
coro = tool(arguments)
else:
coro = tool.acall(arguments, sanitize_inputs_outputs=True)
else:
raise TypeError(f"Unsupported arguments type: {type(arguments)}")

# Run the coroutine
try:
loop = asyncio.get_running_loop()
# We're in an async context, but this method is sync
# We need to handle this differently - for now we'll use asyncio.run_coroutine_threadsafe
import concurrent.futures
import threading

# Create a new event loop in a thread
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coro)
finally:
new_loop.close()

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()

except RuntimeError as e:
if "no running event loop" in str(e):
# No event loop running, we can use asyncio.run
return asyncio.run(coro)
else:
raise e
else:
# Handle sync tool (existing code)
if isinstance(arguments, dict):
return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, str):
return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)
else:
raise TypeError(f"Unsupported arguments type: {type(arguments)}")

except Exception as e:
error_msg = f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}"
# Handle execution errors
if is_managed_agent:
error_msg = (
f"Error executing request to team member '{tool_name}' with arguments {str(arguments)}: {e}\n"
"Please try again or request to another team member"
)
else:
error_msg = (
f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}\n"
"Please try again or use another tool"
)
raise AgentToolExecutionError(error_msg, self.logger) from e

async def aexecute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
"""
Async version of execute_tool_call for handling async tools in async contexts.

Args:
tool_name (`str`): Name of the tool or managed agent to execute.
arguments (dict[str, str] | str): Arguments passed to the tool call.
"""
# Check if the tool exists
available_tools = {**self.tools, **self.managed_agents}
if tool_name not in available_tools:
raise AgentToolExecutionError(
f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
)

# Get the tool and substitute state variables in arguments
tool = available_tools[tool_name]
arguments = self._substitute_state_variables(arguments)
is_managed_agent = tool_name in self.managed_agents

# Validate arguments first
try:
validate_tool_arguments(tool, arguments)
except (ValueError, TypeError) as e:
raise AgentToolCallError(str(e), self.logger) from e

try:
# Call tool with appropriate arguments
if isinstance(arguments, dict):
return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
if hasattr(tool, 'is_async') and tool.is_async():
# Handle async tool
if isinstance(arguments, dict):
if is_managed_agent:
return await tool(**arguments)
else:
return await tool.acall(**arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, str):
if is_managed_agent:
return await tool(arguments)
else:
return await tool.acall(arguments, sanitize_inputs_outputs=True)
else:
raise TypeError(f"Unsupported arguments type: {type(arguments)}")
else:
return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)
# Handle sync tool
if isinstance(arguments, dict):
return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, str):
return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)
else:
raise TypeError(f"Unsupported arguments type: {type(arguments)}")

except Exception as e:
# Handle execution errors
Expand Down Expand Up @@ -1781,4 +1902,4 @@ def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "CodeAgent":
# Update with any additional kwargs
code_agent_kwargs.update(kwargs)
# Call the parent class's from_dict method
return super().from_dict(agent_dict, **code_agent_kwargs)
return super().from_dict(agent_dict, **code_agent_kwargs)
Loading