Skip to content

Commit b12d327

Browse files
authored
Merge branch 'strands-agents:main' into feat-agent-interface
2 parents 8c88772 + 8cae18c commit b12d327

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+3516
-415
lines changed

src/strands/_async.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Private async execution utilities."""
22

33
import asyncio
4+
import contextvars
45
from concurrent.futures import ThreadPoolExecutor
56
from typing import Awaitable, Callable, TypeVar
67

@@ -27,5 +28,6 @@ def execute() -> T:
2728
return asyncio.run(execute_async())
2829

2930
with ThreadPoolExecutor() as executor:
30-
future = executor.submit(execute)
31+
context = contextvars.copy_context()
32+
future = executor.submit(context.run, execute)
3133
return future.result()

src/strands/agent/agent.py

Lines changed: 104 additions & 66 deletions
Large diffs are not rendered by default.

src/strands/agent/interrupt.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

src/strands/event_loop/event_loop.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def event_loop_cycle(
227227
)
228228
structured_output_context.set_forced_mode()
229229
logger.debug("Forcing structured output tool")
230-
agent._append_message(
230+
await agent._append_message(
231231
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
232232
)
233233

@@ -322,7 +322,7 @@ async def _handle_model_execution(
322322
model_id=model_id,
323323
)
324324
with trace_api.use_span(model_invoke_span):
325-
agent.hooks.invoke_callbacks(
325+
await agent.hooks.invoke_callbacks_async(
326326
BeforeModelCallEvent(
327327
agent=agent,
328328
)
@@ -335,14 +335,19 @@ async def _handle_model_execution(
335335
tool_specs = agent.tool_registry.get_all_tool_specs()
336336
try:
337337
async for event in stream_messages(
338-
agent.model, agent.system_prompt, agent.messages, tool_specs, structured_output_context.tool_choice
338+
agent.model,
339+
agent.system_prompt,
340+
agent.messages,
341+
tool_specs,
342+
system_prompt_content=agent._system_prompt_content,
343+
tool_choice=structured_output_context.tool_choice,
339344
):
340345
yield event
341346

342347
stop_reason, message, usage, metrics = event["stop"]
343348
invocation_state.setdefault("request_state", {})
344349

345-
agent.hooks.invoke_callbacks(
350+
await agent.hooks.invoke_callbacks_async(
346351
AfterModelCallEvent(
347352
agent=agent,
348353
stop_response=AfterModelCallEvent.ModelStopResponse(
@@ -363,7 +368,7 @@ async def _handle_model_execution(
363368
if model_invoke_span:
364369
tracer.end_span_with_error(model_invoke_span, str(e), e)
365370

366-
agent.hooks.invoke_callbacks(
371+
await agent.hooks.invoke_callbacks_async(
367372
AfterModelCallEvent(
368373
agent=agent,
369374
exception=e,
@@ -397,7 +402,7 @@ async def _handle_model_execution(
397402

398403
# Add the response message to the conversation
399404
agent.messages.append(message)
400-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
405+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message))
401406

402407
# Update metrics
403408
agent.event_loop_metrics.update_usage(usage)
@@ -502,7 +507,7 @@ async def _handle_tool_execution(
502507
}
503508

504509
agent.messages.append(tool_result_message)
505-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
510+
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message))
506511

507512
yield ToolResultMessageEvent(message=tool_result_message)
508513

src/strands/event_loop/streaming.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
TypedEvent,
2323
)
2424
from ..types.citations import CitationsContentBlock
25-
from ..types.content import ContentBlock, Message, Messages
25+
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
2626
from ..types.streaming import (
2727
ContentBlockDeltaEvent,
2828
ContentBlockStart,
@@ -418,16 +418,22 @@ async def stream_messages(
418418
system_prompt: Optional[str],
419419
messages: Messages,
420420
tool_specs: list[ToolSpec],
421+
*,
421422
tool_choice: Optional[Any] = None,
423+
system_prompt_content: Optional[list[SystemContentBlock]] = None,
424+
**kwargs: Any,
422425
) -> AsyncGenerator[TypedEvent, None]:
423426
"""Streams messages to the model and processes the response.
424427
425428
Args:
426429
model: Model provider.
427-
system_prompt: The system prompt to send.
430+
system_prompt: The system prompt string, used for backwards compatibility with models that expect it.
428431
messages: List of messages to send.
429432
tool_specs: The list of tool specs.
430433
tool_choice: Optional tool choice constraint for forcing specific tool usage.
434+
system_prompt_content: The authoritative system prompt content blocks that always contains the
435+
system prompt data.
436+
**kwargs: Additional keyword arguments for future extensibility.
431437
432438
Yields:
433439
The reason for stopping, the final message, and the usage metrics
@@ -436,7 +442,14 @@ async def stream_messages(
436442

437443
messages = _normalize_messages(messages)
438444
start_time = time.time()
439-
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice)
445+
446+
chunks = model.stream(
447+
messages,
448+
tool_specs if tool_specs else None,
449+
system_prompt,
450+
tool_choice=tool_choice,
451+
system_prompt_content=system_prompt_content,
452+
)
440453

441454
async for event in process_stream(chunks, start_time):
442455
yield event

src/strands/hooks/registry.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
via hook provider objects.
88
"""
99

10+
import inspect
1011
import logging
1112
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
13+
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
1314

1415
from ..interrupt import Interrupt, InterruptException
1516

@@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]):
122123
```python
123124
def my_callback(event: StartRequestEvent) -> None:
124125
print(f"Request started for agent: {event.agent.name}")
126+
127+
# Or
128+
129+
async def my_callback(event: StartRequestEvent) -> None:
130+
# await an async operation
125131
```
126132
"""
127133

128-
def __call__(self, event: TEvent) -> None:
134+
def __call__(self, event: TEvent) -> None | Awaitable[None]:
129135
"""Handle a hook event.
130136
131137
Args:
@@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent):
164170
registry.add_callback(StartRequestEvent, my_handler)
165171
```
166172
"""
173+
# Related issue: https://github.com/strands-agents/sdk-python/issues/330
174+
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
175+
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
176+
167177
callbacks = self._registered_callbacks.setdefault(event_type, [])
168178
callbacks.append(callback)
169179

@@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry):
189199
"""
190200
hook.register_hooks(self)
191201

202+
async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
203+
"""Invoke all registered callbacks for the given event.
204+
205+
This method finds all callbacks registered for the event's type and
206+
invokes them in the appropriate order. For events with should_reverse_callbacks=True,
207+
callbacks are invoked in reverse registration order. Any exceptions raised by callback
208+
functions will propagate to the caller.
209+
210+
Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows.
211+
212+
Args:
213+
event: The event to dispatch to registered callbacks.
214+
215+
Returns:
216+
The event dispatched to registered callbacks and any interrupts raised by the user.
217+
218+
Raises:
219+
ValueError: If interrupt name is used more than once.
220+
221+
Example:
222+
```python
223+
event = StartRequestEvent(agent=my_agent)
224+
await registry.invoke_callbacks_async(event)
225+
```
226+
"""
227+
interrupts: dict[str, Interrupt] = {}
228+
229+
for callback in self.get_callbacks_for(event):
230+
try:
231+
if inspect.iscoroutinefunction(callback):
232+
await callback(event)
233+
else:
234+
callback(event)
235+
236+
except InterruptException as exception:
237+
interrupt = exception.interrupt
238+
if interrupt.name in interrupts:
239+
message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once"
240+
logger.error(message)
241+
raise ValueError(message) from exception
242+
243+
# Each callback is allowed to raise their own interrupt.
244+
interrupts[interrupt.name] = interrupt
245+
246+
return event, list(interrupts.values())
247+
192248
def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
193249
"""Invoke all registered callbacks for the given event.
194250
@@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
206262
The event dispatched to registered callbacks and any interrupts raised by the user.
207263
208264
Raises:
265+
RuntimeError: If at least one callback is async.
209266
ValueError: If interrupt name is used more than once.
210267
211268
Example:
@@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte
214271
registry.invoke_callbacks(event)
215272
```
216273
"""
274+
callbacks = list(self.get_callbacks_for(event))
217275
interrupts: dict[str, Interrupt] = {}
218276

219-
for callback in self.get_callbacks_for(event):
277+
if any(inspect.iscoroutinefunction(callback) for callback in callbacks):
278+
raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback")
279+
280+
for callback in callbacks:
220281
try:
221282
callback(event)
222283
except InterruptException as exception:

0 commit comments

Comments
 (0)