Skip to content

Commit 3061116

Browse files
feat(tools): Support string descriptions in Annotated parameters (#1089)
--------- Co-authored-by: Dean Schmigelski <dbschmigelski+github@gmail.com>
1 parent 2b0c6e6 commit 3061116

File tree

2 files changed

+278
-12
lines changed

2 files changed

+278
-12
lines changed

src/strands/tools/decorator.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4545
import inspect
4646
import logging
4747
from typing import (
48+
Annotated,
4849
Any,
4950
Callable,
5051
Generic,
@@ -54,12 +55,15 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5455
TypeVar,
5556
Union,
5657
cast,
58+
get_args,
59+
get_origin,
5760
get_type_hints,
5861
overload,
5962
)
6063

6164
import docstring_parser
6265
from pydantic import BaseModel, Field, create_model
66+
from pydantic.fields import FieldInfo
6367
from typing_extensions import override
6468

6569
from ..interrupt import InterruptException
@@ -105,15 +109,66 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
105109
# Parse the docstring with docstring_parser
106110
doc_str = inspect.getdoc(func) or ""
107111
self.doc = docstring_parser.parse(doc_str)
108-
109-
# Get parameter descriptions from parsed docstring
110-
self.param_descriptions = {
112+
self.param_descriptions: dict[str, str] = {
111113
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
112114
}
113115

114116
# Create a Pydantic model for validation
115117
self.input_model = self._create_input_model()
116118

119+
def _extract_annotated_metadata(
120+
self, annotation: Any, param_name: str, param_default: Any
121+
) -> tuple[Any, FieldInfo]:
122+
"""Extracts type and a simple string description from an Annotated type hint.
123+
124+
Returns:
125+
A tuple of (actual_type, field_info), where field_info is a new, simple
126+
Pydantic FieldInfo instance created from the extracted metadata.
127+
"""
128+
actual_type = annotation
129+
description: str | None = None
130+
131+
if get_origin(annotation) is Annotated:
132+
args = get_args(annotation)
133+
actual_type = args[0]
134+
135+
# Look through metadata for a string description or a FieldInfo object
136+
for meta in args[1:]:
137+
if isinstance(meta, str):
138+
description = meta
139+
elif isinstance(meta, FieldInfo):
140+
# --- Future Contributor Note ---
141+
# We are explicitly blocking the use of `pydantic.Field` within `Annotated`
142+
# because of the complexities of Pydantic v2's immutable Core Schema.
143+
#
144+
# Once a Pydantic model's schema is built, its `FieldInfo` objects are
145+
# effectively frozen. Attempts to mutate a `FieldInfo` object after
146+
# creation (e.g., by copying it and setting `.description` or `.default`)
147+
# are unreliable because the underlying Core Schema does not see these changes.
148+
#
149+
# The correct way to support this would be to reliably extract all
150+
# constraints (ge, le, pattern, etc.) from the original FieldInfo and
151+
# rebuild a new one from scratch. However, these constraints are not
152+
# stored as public attributes, making them difficult to inspect reliably.
153+
#
154+
# Deferring this complexity until there is clear demand and a robust
155+
# pattern for inspecting FieldInfo constraints is established.
156+
raise NotImplementedError(
157+
"Using pydantic.Field within Annotated is not yet supported for tool decorators. "
158+
"Please use a simple string for the description, or define constraints in the function's "
159+
"docstring."
160+
)
161+
162+
# Determine the final description with a clear priority order
163+
# Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback
164+
final_description = description
165+
if final_description is None:
166+
final_description = self.param_descriptions.get(param_name) or f"Parameter {param_name}"
167+
# Create FieldInfo object from scratch
168+
final_field = Field(default=param_default, description=final_description)
169+
170+
return actual_type, final_field
171+
117172
def _validate_signature(self) -> None:
118173
"""Verify that ToolContext is used correctly in the function signature."""
119174
for param in self.signature.parameters.values():
@@ -146,22 +201,21 @@ def _create_input_model(self) -> Type[BaseModel]:
146201
if self._is_special_parameter(name):
147202
continue
148203

149-
# Get parameter type and default
150-
param_type = self.type_hints.get(name, Any)
204+
# Use param.annotation directly to get the raw type hint. Using get_type_hints()
205+
# can cause inconsistent behavior across Python versions for complex Annotated types.
206+
param_type = param.annotation
207+
if param_type is inspect.Parameter.empty:
208+
param_type = Any
151209
default = ... if param.default is inspect.Parameter.empty else param.default
152-
description = self.param_descriptions.get(name, f"Parameter {name}")
153210

154-
# Create Field with description and default
155-
field_definitions[name] = (param_type, Field(default=default, description=description))
211+
actual_type, field_info = self._extract_annotated_metadata(param_type, name, default)
212+
field_definitions[name] = (actual_type, field_info)
156213

157-
# Create model name based on function name
158214
model_name = f"{self.func.__name__.capitalize()}Tool"
159215

160-
# Create and return the model
161216
if field_definitions:
162217
return create_model(model_name, **field_definitions)
163218
else:
164-
# Handle case with no parameters
165219
return create_model(model_name)
166220

167221
def _extract_description_from_docstring(self) -> str:

tests/strands/tools/test_decorator.py

Lines changed: 213 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"""
44

55
from asyncio import Queue
6-
from typing import Any, AsyncGenerator, Dict, Optional, Union
6+
from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union
77
from unittest.mock import MagicMock
88

99
import pytest
10+
from pydantic import Field
1011

1112
import strands
1213
from strands import Agent
@@ -1611,3 +1612,214 @@ def test_function_tool_metadata_validate_signature_missing_context_config():
16111612
@strands.tool
16121613
def my_tool(tool_context: ToolContext):
16131614
pass
1615+
1616+
1617+
def test_tool_decorator_annotated_string_description():
1618+
"""Test tool decorator with Annotated type hints for descriptions."""
1619+
1620+
@strands.tool
1621+
def annotated_tool(
1622+
name: Annotated[str, "The user's full name"],
1623+
age: Annotated[int, "The user's age in years"],
1624+
city: str, # No annotation - should use docstring or generic
1625+
) -> str:
1626+
"""Tool with annotated parameters.
1627+
1628+
Args:
1629+
city: The user's city (from docstring)
1630+
"""
1631+
return f"{name}, {age}, {city}"
1632+
1633+
spec = annotated_tool.tool_spec
1634+
schema = spec["inputSchema"]["json"]
1635+
1636+
# Check that annotated descriptions are used
1637+
assert schema["properties"]["name"]["description"] == "The user's full name"
1638+
assert schema["properties"]["age"]["description"] == "The user's age in years"
1639+
1640+
# Check that docstring is still used for non-annotated params
1641+
assert schema["properties"]["city"]["description"] == "The user's city (from docstring)"
1642+
1643+
# Verify all are required
1644+
assert set(schema["required"]) == {"name", "age", "city"}
1645+
1646+
1647+
def test_tool_decorator_annotated_pydantic_field_constraints():
1648+
"""Test that using pydantic.Field in Annotated raises a NotImplementedError."""
1649+
with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"):
1650+
1651+
@strands.tool
1652+
def field_annotated_tool(
1653+
email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.w+$")],
1654+
score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50,
1655+
) -> str:
1656+
"""Tool with Pydantic Field annotations."""
1657+
return f"{email}: {score}"
1658+
1659+
1660+
def test_tool_decorator_annotated_overrides_docstring():
1661+
"""Test that Annotated descriptions override docstring descriptions."""
1662+
1663+
@strands.tool
1664+
def override_tool(param: Annotated[str, "Description from annotation"]) -> str:
1665+
"""Tool with both annotation and docstring.
1666+
1667+
Args:
1668+
param: Description from docstring (should be overridden)
1669+
"""
1670+
return param
1671+
1672+
spec = override_tool.tool_spec
1673+
schema = spec["inputSchema"]["json"]
1674+
1675+
# Annotated description should win
1676+
assert schema["properties"]["param"]["description"] == "Description from annotation"
1677+
1678+
1679+
def test_tool_decorator_annotated_optional_type():
1680+
"""Test tool with Optional types in Annotated."""
1681+
1682+
@strands.tool
1683+
def optional_annotated_tool(
1684+
required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None
1685+
) -> str:
1686+
"""Tool with optional annotated parameter."""
1687+
return f"{required}, {optional}"
1688+
1689+
spec = optional_annotated_tool.tool_spec
1690+
schema = spec["inputSchema"]["json"]
1691+
1692+
# Check descriptions
1693+
assert schema["properties"]["required"]["description"] == "Required parameter"
1694+
assert schema["properties"]["optional"]["description"] == "Optional parameter"
1695+
1696+
# Check required list
1697+
assert "required" in schema["required"]
1698+
assert "optional" not in schema["required"]
1699+
1700+
1701+
def test_tool_decorator_annotated_complex_types():
1702+
"""Test tool with complex types in Annotated."""
1703+
1704+
@strands.tool
1705+
def complex_annotated_tool(
1706+
tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"]
1707+
) -> str:
1708+
"""Tool with complex annotated types."""
1709+
return f"Tags: {len(tags)}, Config: {len(config)}"
1710+
1711+
spec = complex_annotated_tool.tool_spec
1712+
schema = spec["inputSchema"]["json"]
1713+
1714+
# Check descriptions
1715+
assert schema["properties"]["tags"]["description"] == "List of tag strings"
1716+
assert schema["properties"]["config"]["description"] == "Configuration dictionary"
1717+
1718+
# Check types are preserved
1719+
assert schema["properties"]["tags"]["type"] == "array"
1720+
assert schema["properties"]["config"]["type"] == "object"
1721+
1722+
1723+
def test_tool_decorator_annotated_mixed_styles():
1724+
"""Test that using pydantic.Field in a mixed-style annotation raises NotImplementedError."""
1725+
with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"):
1726+
1727+
@strands.tool
1728+
def mixed_tool(
1729+
plain: str,
1730+
annotated_str: Annotated[str, "String description"],
1731+
annotated_field: Annotated[int, Field(description="Field description", ge=0)],
1732+
docstring_only: int,
1733+
) -> str:
1734+
"""Tool with mixed parameter styles.
1735+
1736+
Args:
1737+
plain: Plain parameter description
1738+
docstring_only: Docstring description for this param
1739+
"""
1740+
return "mixed"
1741+
1742+
1743+
@pytest.mark.asyncio
1744+
async def test_tool_decorator_annotated_execution(alist):
1745+
"""Test that annotated tools execute correctly."""
1746+
1747+
@strands.tool
1748+
def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str:
1749+
"""Test execution with annotations."""
1750+
return f"Hello {name} " * count
1751+
1752+
# Test tool use
1753+
tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}}
1754+
stream = execution_test.stream(tool_use, {})
1755+
1756+
result = (await alist(stream))[-1]
1757+
assert result["tool_result"]["status"] == "success"
1758+
assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"]
1759+
1760+
# Test direct call
1761+
direct_result = execution_test("Bob", 3)
1762+
assert direct_result == "Hello Bob Hello Bob Hello Bob "
1763+
1764+
1765+
def test_tool_decorator_annotated_no_description_fallback():
1766+
"""Test that Annotated with a Field raises NotImplementedError."""
1767+
with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"):
1768+
1769+
@strands.tool
1770+
def no_desc_annotated(
1771+
param: Annotated[str, Field()], # Field without description
1772+
) -> str:
1773+
"""Tool with Annotated but no description.
1774+
1775+
Args:
1776+
param: Docstring description
1777+
"""
1778+
return param
1779+
1780+
1781+
def test_tool_decorator_annotated_empty_string_description():
1782+
"""Test handling of empty string descriptions in Annotated."""
1783+
1784+
@strands.tool
1785+
def empty_desc_tool(
1786+
param: Annotated[str, ""], # Empty string description
1787+
) -> str:
1788+
"""Tool with empty annotation description.
1789+
1790+
Args:
1791+
param: Docstring description
1792+
"""
1793+
return param
1794+
1795+
spec = empty_desc_tool.tool_spec
1796+
schema = spec["inputSchema"]["json"]
1797+
1798+
# Empty string is still a valid description, should not fall back
1799+
assert schema["properties"]["param"]["description"] == ""
1800+
1801+
1802+
@pytest.mark.asyncio
1803+
async def test_tool_decorator_annotated_validation_error(alist):
1804+
"""Test that validation works correctly with annotated parameters."""
1805+
1806+
@strands.tool
1807+
def validation_tool(age: Annotated[int, "User age"]) -> str:
1808+
"""Tool for validation testing."""
1809+
return f"Age: {age}"
1810+
1811+
# Test with wrong type
1812+
tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}}
1813+
stream = validation_tool.stream(tool_use, {})
1814+
1815+
result = (await alist(stream))[-1]
1816+
assert result["tool_result"]["status"] == "error"
1817+
1818+
1819+
def test_tool_decorator_annotated_field_with_inner_default():
1820+
"""Test that a default value in an Annotated Field raises NotImplementedError."""
1821+
with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"):
1822+
1823+
@strands.tool
1824+
def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str:
1825+
return f"{name} is at level {level}"

0 commit comments

Comments
 (0)