Changes from all commits (29 commits, all authored by aymeric-roucher):

da5d456  Transfer aggregation of streaming events off the Model class (Jun 17, 2025)
0c08dd6  Basic gradio ui with openai (Jun 17, 2025)
0a0456c  Working streaming gradio UI openaimodel (Jun 18, 2025)
201189d  Working models in ToolCallingAgent (Jun 18, 2025)
f9f7ded  Format (Jun 18, 2025)
9c68124  Fix missing definition for removed HfApiModel (Jun 18, 2025)
b4b8f8a  Add (Jun 18, 2025)
99dcf01  Merge branch 'main' into richer-streaming-events (Jun 18, 2025)
c1a24ff  Ruff (Jun 18, 2025)
ddebb8b  Type fixes (Jun 18, 2025)
15636c0  Changes (Jun 19, 2025)
dd2dc02  Merge branch 'main' into richer-streaming-events (Jun 19, 2025)
d0925ae  Fix variable name (Jun 19, 2025)
3860b22  Fuse yielding tool call loops (Jun 19, 2025)
80eb8c0  Separate ActionOutput and ToolOutput (Jun 19, 2025)
f5df2c6  ChatMessageToolCallStreamDelta (Jun 19, 2025)
86dcf04  Use ChatMessage in get_clean_message_list (Jun 19, 2025)
d366832  Remove HfApiModel (Jun 19, 2025)
077eae5  Format (Jun 19, 2025)
0b2d836  Reset gradio example (Jun 19, 2025)
aa556de  Replace Message with ChatMessage (Jun 19, 2025)
93d651e  Replace names in tests (Jun 19, 2025)
844c526  Format (Jun 19, 2025)
c99cb6d  Fix one lint error (Jun 19, 2025)
3b48659  Merge (Jun 19, 2025)
a1248a7  Fix model tests (Jun 19, 2025)
d9cf46a  Merge branch 'main' into stream-events (Jun 20, 2025)
fdec7de  Remove merging errors (Jun 20, 2025)
5665c92  Rename ChatMessageStreamEvent to ModelStreamEvent (Jun 21, 2025)
src/smolagents/agents.py (137 changes: 84 additions & 53 deletions)
@@ -211,8 +211,8 @@ class RunResult:
timing: Timing


StreamEvent: TypeAlias = Union[
ChatMessageStreamDelta,
RunItem: TypeAlias = Union[
ChatMessage,
ChatMessageToolCall,
ActionOutput,
ToolOutput,
@@ -222,6 +222,38 @@ class RunResult:
]


@dataclass
class RunStreamEvent:
"""Streaming event that wraps items generated by the agent during its run: action calls, planning steps, etc."""

name: Literal[
"output_message",
"tool_call",
"action_step",
"action_output",
"tool_output",
"planning_step",
"final_answer_step",
]
item: RunItem
type: Literal["run_stream_event"] = "run_stream_event"


@dataclass
class ModelStreamEvent:
"""Streaming event that wraps items generated by the model during its run: these are the stream deltas"""

name: Literal["chat_message_stream_delta"]
item: ChatMessageStreamDelta
type: Literal["chat_message_stream_event"] = "chat_message_stream_event"


StreamEvent: TypeAlias = Union[
RunStreamEvent,
ModelStreamEvent,
]
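
For reviewers, a minimal sketch of how a caller might consume this tagged union, assuming an already-constructed agent; the handler logic and the `event.item.output` field access are illustrative assumptions, not part of this diff:

# --- Reviewer sketch, not part of the diff ---------------------------------
# `type` discriminates the wrapper class; `name` says what kind of item a
# RunStreamEvent carries.
for event in agent._run_stream(task="What is 2+2?", max_steps=5):
    if event.type == "chat_message_stream_event":
        # ModelStreamEvent: incremental model text
        print(event.item.content or "", end="", flush=True)
    elif event.name == "tool_call":
        print(f"\n[tool call] {event.item.function.name}")
    elif event.name == "final_answer_step":
        print("\nFinal answer:", event.item.output)  # field name assumed
# ---------------------------------------------------------------------------

One design note: keeping `type` and `name` as plain string literals rather than enums makes every event trivially JSON-serializable, which suits the gradio and server-sent-event consumers this PR targets.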


class MultiStepAgent(ABC):
"""
Agent class that solves the given task step by step, using the ReAct framework:
@@ -276,15 +308,15 @@ def __init__(
self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
if prompt_templates is not None:
missing_keys = set(EMPTY_PROMPT_TEMPLATES.keys()) - set(prompt_templates.keys())
assert not missing_keys, (
f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
)
assert (
not missing_keys
), f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
for key, value in EMPTY_PROMPT_TEMPLATES.items():
if isinstance(value, dict):
for subkey in value.keys():
assert key in prompt_templates.keys() and (subkey in prompt_templates[key].keys()), (
f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
)
assert (
key in prompt_templates.keys() and (subkey in prompt_templates[key].keys())
), f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"

self.max_steps = max_steps
self.step_number = 0
@@ -338,9 +370,9 @@ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
"""Setup managed agents with proper logging."""
self.managed_agents = {}
if managed_agents:
assert all(agent.name and agent.description for agent in managed_agents), (
"All managed agents need both a name and a description!"
)
assert all(
agent.name and agent.description for agent in managed_agents
), "All managed agents need both a name and a description!"
self.managed_agents = {agent.name: agent for agent in managed_agents}
# Ensure managed agents can be called as tools by the model: set their inputs and output_type
for agent in self.managed_agents.values():
@@ -473,7 +505,8 @@ def run(

def _run_stream(
self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
) -> Generator[StreamEvent]:
final_answer = None
self.step_number = 1
returned_final_answer = False
while not returned_final_answer and self.step_number <= max_steps:
@@ -486,11 +519,11 @@ def _run_stream(
):
planning_start_time = time.time()
planning_step = None
for element in self._generate_planning_step(
task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
): # Don't use the attribute step_number here, because there can be steps from previous runs
yield element
planning_step = element
for event in self._generate_planning_step(
task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
):
yield event
planning_step = event.item
assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep
self.memory.steps.append(planning_step)
planning_end_time = time.time()
@@ -528,13 +561,13 @@ def _run_stream(
finally:
self._finalize_step(action_step)
self.memory.steps.append(action_step)
yield action_step
yield RunStreamEvent(name="action_step", item=action_step)
self.step_number += 1

if not returned_final_answer and self.step_number == max_steps + 1:
final_answer = self._handle_max_steps_reached(task, images)
yield action_step
yield FinalAnswerStep(handle_agent_output_types(final_answer))
yield RunStreamEvent(name="action_step", item=action_step)
yield RunStreamEvent(name="final_answer_step", item=FinalAnswerStep(handle_agent_output_types(final_answer)))

def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
@@ -551,7 +584,7 @@ def _finalize_step(self, memory_step: ActionStep):
memory_step, agent=self
)

def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"]) -> Any:
def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"] | None = None) -> Any:
action_step_start_time = time.time()
final_answer = self.provide_final_answer(task, images)
final_memory_step = ActionStep(
@@ -565,9 +598,7 @@ def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"])
self.memory.steps.append(final_memory_step)
return final_answer.content

def _generate_planning_step(
self, task, is_first_step: bool, step: int
) -> Generator[ChatMessageStreamDelta | PlanningStep]:
def _generate_planning_step(self, task, is_first_step: bool, step: int) -> Generator[StreamEvent]:
start_time = time.time()
if is_first_step:
input_messages = [
@@ -596,7 +627,7 @@ def _generate_planning_step(
if event.token_usage:
output_tokens += event.token_usage.output_tokens
input_tokens = event.token_usage.input_tokens
yield event
yield ModelStreamEvent(name="chat_message_stream_delta", item=event)
else:
plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
plan_message_content = plan_message.content
@@ -659,7 +690,7 @@ def _generate_planning_step(
if event.token_usage:
output_tokens += event.token_usage.output_tokens
input_tokens = event.token_usage.input_tokens
yield event
yield ModelStreamEvent(name="chat_message_stream_delta", item=event)
else:
plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
plan_message_content = plan_message.content
@@ -673,13 +704,14 @@ def _generate_planning_step(
)
log_headline = "Initial plan" if is_first_step else "Updated plan"
self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
yield PlanningStep(
planning_step = PlanningStep(
model_input_messages=input_messages,
plan=plan,
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
timing=Timing(start_time=start_time, end_time=time.time()),
)
yield RunStreamEvent(name="planning_step", item=planning_step)

@property
def logs(self):
@@ -711,7 +743,7 @@ def write_memory_to_messages(
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput | ToolOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[StreamEvent]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields StreamEvent objects during the run; model stream deltas are included when streaming is enabled.
@@ -966,7 +998,7 @@ def to_dict(self) -> dict[str, Any]:
requirements = tool_requirements | managed_agents_requirements
if hasattr(self, "authorized_imports"):
requirements.update(
{package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
{package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES} # type: ignore
)

agent_dict = {
@@ -1122,7 +1154,7 @@ def push_to_hub(
repo_id: str,
commit_message: str = "Upload agent",
private: bool | None = None,
token: bool | str | None = None,
token: str | None = None,
create_pr: bool = False,
) -> str:
"""
@@ -1136,7 +1168,7 @@ def push_to_hub(
Message to commit while pushing.
private (`bool`, *optional*, defaults to `None`):
Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
token (`bool` or `str`, *optional*):
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
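
For context, a typical call under the updated signature; the repo id is a placeholder:

# --- Reviewer sketch, not part of the diff ---------------------------------
agent.push_to_hub("your-username/my-agent", commit_message="Upload agent", create_pr=True)
# ---------------------------------------------------------------------------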
@@ -1233,7 +1265,7 @@ def initialize_system_prompt(self) -> str:
)
return system_prompt

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ToolOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[StreamEvent]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields StreamEvent objects during the run; model stream deltas are included when streaming is enabled.
@@ -1261,7 +1293,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
live.update(
Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
)
yield event
yield ModelStreamEvent(name="chat_message_stream_delta", item=event)
chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
else:
chat_message: ChatMessage = self.model.generate(
@@ -1303,15 +1335,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
Yields:
`RunStreamEvent`: Tool call and tool output events; the final `ToolOutput` event carries the final answer when one is returned.
"""
model_outputs = []
tool_calls = []
observations = []
model_outputs: list[str] = []
tool_calls: list[ToolCall] = []
observations: list[str] = []

final_answer_call = None
parallel_calls = []
assert chat_message.tool_calls is not None
for tool_call in chat_message.tool_calls:
yield tool_call
yield RunStreamEvent(name="tool_call", item=tool_call)
tool_name = tool_call.function.name
tool_arguments = tool_call.function.arguments
model_outputs.append(str(f"Called Tool: '{tool_name}' with arguments: {tool_arguments}"))
@@ -1323,7 +1355,6 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
else:
parallel_calls.append((tool_name, tool_arguments))

# Helper function to process a single tool call
def process_single_tool_call(call_info):
tool_name, tool_arguments = call_info
self.logger.log(
@@ -1355,14 +1386,14 @@ def process_single_tool_call(call_info):
if len(parallel_calls) == 1:
# If there's only one call, process it directly
observations.append(process_single_tool_call(parallel_calls[0]))
yield ToolOutput(output=None, is_final_answer=False)
yield RunStreamEvent(name="tool_output", item=ToolOutput(output=None, is_final_answer=False))
else:
# If multiple tool calls, process them in parallel
with ThreadPoolExecutor(self.max_tool_threads) as executor:
futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
for future in as_completed(futures):
observations.append(future.result())
yield ToolOutput(output=None, is_final_answer=False)
yield RunStreamEvent(name="tool_output", item=ToolOutput(output=None, is_final_answer=False))

# Process final_answer call if present
if final_answer_call:
@@ -1392,7 +1423,7 @@ def process_single_tool_call(call_info):
level=LogLevel.INFO,
)
memory_step.action_output = final_answer
yield ToolOutput(output=final_answer, is_final_answer=True)
yield RunStreamEvent(name="tool_output", item=ToolOutput(output=final_answer, is_final_answer=True))

# Update memory step with all results
if model_outputs:
@@ -1584,7 +1615,7 @@ def initialize_system_prompt(self) -> str:
)
return system_prompt

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[StreamEvent]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields StreamEvent objects during the run; model stream deltas are included when streaming is enabled.
@@ -1619,13 +1650,13 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
memory_step.model_output_message = chat_message
output_text = chat_message.content
else:
chat_message: ChatMessage = self.model.generate(
chat_message = self.model.generate(
input_messages,
stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
**additional_args,
)
memory_step.model_output_message = chat_message
output_text = chat_message.content
output_text = chat_message.content or ""
self.logger.log_markdown(
content=output_text,
title="Output message of the LLM:",
@@ -1655,13 +1686,13 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg, self.logger)

memory_step.tool_calls = [
ToolCall(
name="python_interpreter",
arguments=code_action,
id=f"call_{len(self.memory.steps)}",
)
]
tool_call = ToolCall(
name="python_interpreter",
arguments=code_action,
id=f"call_{len(self.memory.steps)}",
)
yield RunStreamEvent(name="action_call", item=tool_call)
memory_step.tool_calls = [tool_call]

### Execute action ###
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
@@ -1705,7 +1736,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
]
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output
yield ActionOutput(output=output, is_final_answer=is_final_answer)
yield RunStreamEvent(name="action_output", item=ActionOutput(output=output, is_final_answer=is_final_answer))

def to_dict(self) -> dict[str, Any]:
"""Convert the agent to a dictionary representation.
@@ -1721,7 +1752,7 @@ def to_dict(self) -> dict[str, Any]:
return agent_dict

@classmethod
def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "CodeAgent":
def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
"""Create CodeAgent from a dictionary representation.

Args:
src/smolagents/models.py (27 changes: 27 additions & 0 deletions)
@@ -488,6 +488,33 @@ def generate(
"""
raise NotImplementedError("This method must be implemented in child classes")

def generate_stream(
self,
messages: list[ChatMessage],
stop_sequences: list[str] | None = None,
response_format: dict[str, str] | None = None,
tools_to_call_from: list[Tool] | None = None,
**kwargs,
) -> Generator[ChatMessageStreamDelta]:
"""Process the input messages and return the model's response in streaming mode.

Parameters:
messages (`list[ChatMessage]`):
The chat messages to be processed.
stop_sequences (`List[str]`, *optional*):
A list of strings that will stop the generation if encountered in the model's output.
response_format (`dict[str, str]`, *optional*):
The response format to use in the model's response.
tools_to_call_from (`List[Tool]`, *optional*):
A list of tools that the model can use to generate responses.
**kwargs:
Additional keyword arguments to be passed to the underlying model.

Yields:
`ChatMessageStreamDelta`: Successive deltas of the model's streamed response.
"""
raise NotImplementedError("This method must be implemented in child classes")

def __call__(self, *args, **kwargs):
return self.generate(*args, **kwargs)
