From 8f0c481e618d0546de3556d3e7b864b63737c677 Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Thu, 27 Nov 2025 22:40:30 +0900 Subject: [PATCH 1/6] chore: Started off with capturing metadata for structured result. More to come! --- pydantic_ai_slim/pydantic_ai/mcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ac3cfeae5c..2a983686b6 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -490,9 +490,10 @@ async def direct_call_tool( ): # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function. # See https://github.com/modelcontextprotocol/python-sdk#structured-output + return_value = structured if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured: - return structured['result'] - return structured + return_value = structured['result'] + return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value mapped = [await self._map_tool_result_part(part) for part in result.content] return mapped[0] if len(mapped) == 1 else mapped @@ -1179,6 +1180,7 @@ def __eq__(self, value: object, /) -> bool: ToolResult = ( str | messages.BinaryContent + | messages.ToolReturn | dict[str, Any] | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] From 39e0a496b1f7f9f415345374186da4cd14306a06 Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Fri, 28 Nov 2025 07:24:00 +0900 Subject: [PATCH 2/6] feat: Added parsing of metadata, nested or not, in binary and text content including resources. chore: Updated tests but have not added tests specific to metadata, hence coverage should fail. --- pydantic_ai_slim/pydantic_ai/mcp.py | 31 +++++++++++++++--------- pydantic_ai_slim/pydantic_ai/messages.py | 17 ++++++++++++- tests/test_agent.py | 26 +++++++++++++++++--- tests/test_messages.py | 3 +++ 4 files changed, 62 insertions(+), 15 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 2a983686b6..9feb7fad9e 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -575,16 +575,18 @@ async def list_resource_templates(self) -> list[ResourceTemplate]: return [ResourceTemplate.from_mcp_sdk(t) for t in result.resourceTemplates] @overload - async def read_resource(self, uri: str) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ... + async def read_resource( + self, uri: str + ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ... @overload async def read_resource( self, uri: Resource - ) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ... + ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ... async def read_resource( self, uri: str | Resource - ) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: + ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: """Read the contents of a specific resource by URI. Args: @@ -683,7 +685,7 @@ async def _sampling_callback( async def _map_tool_result_part( self, part: mcp_types.ContentBlock - ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: + ) -> str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]: # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values if isinstance(part, mcp_types.TextContent): @@ -693,14 +695,16 @@ async def _map_tool_result_part( return pydantic_core.from_json(text) except ValueError: pass - return text + return text if not part.meta else messages.TextPart(content=text, metadata=part.meta) elif isinstance(part, mcp_types.ImageContent): - return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) + return messages.BinaryContent( + data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta + ) elif isinstance(part, mcp_types.AudioContent): # NOTE: The FastMCP server doesn't support audio content. # See for more details. return messages.BinaryContent( - data=base64.b64decode(part.data), media_type=part.mimeType + data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta ) # pragma: no cover elif isinstance(part, mcp_types.EmbeddedResource): resource = part.resource @@ -712,12 +716,16 @@ async def _map_tool_result_part( def _get_content( self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents - ) -> str | messages.BinaryContent: + ) -> str | messages.TextPart | messages.BinaryContent: if isinstance(resource, mcp_types.TextResourceContents): - return resource.text + return ( + resource.text if not resource.meta else messages.TextPart(content=resource.text, metadata=resource.meta) + ) elif isinstance(resource, mcp_types.BlobResourceContents): return messages.BinaryContent( - data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream' + data=base64.b64decode(resource.blob), + media_type=resource.mimeType or 'application/octet-stream', + metadata=resource.meta, ) else: assert_never(resource) @@ -1179,11 +1187,12 @@ def __eq__(self, value: object, /) -> bool: ToolResult = ( str + | messages.TextPart | messages.BinaryContent | messages.ToolReturn | dict[str, Any] | list[Any] - | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] + | Sequence[str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]] ) """The result type of an MCP tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index ac0fb0da6d..5118817181 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -486,6 +486,9 @@ class BinaryContent: - `OpenAIChatModel`, `OpenAIResponsesModel`: `BinaryContent.vendor_metadata['detail']` is used as `detail` setting for images """ + metadata: Any = None + """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" + _identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field( compare=False, default=None ) @@ -500,6 +503,7 @@ def __init__( media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str, identifier: str | None = None, vendor_metadata: dict[str, Any] | None = None, + metadata: Any = None, kind: Literal['binary'] = 'binary', # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _identifier: str | None = None, @@ -508,6 +512,7 @@ def __init__( self.media_type = media_type self._identifier = identifier or _identifier self.vendor_metadata = vendor_metadata + self.metadata = metadata self.kind = kind @staticmethod @@ -519,6 +524,7 @@ def narrow_type(bc: BinaryContent) -> BinaryContent | BinaryImage: media_type=bc.media_type, identifier=bc.identifier, vendor_metadata=bc.vendor_metadata, + metadata=bc.metadata, ) else: return bc @@ -622,11 +628,17 @@ def __init__( identifier: str | None = None, vendor_metadata: dict[str, Any] | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. + metadata: Any = None, kind: Literal['binary'] = 'binary', _identifier: str | None = None, ): super().__init__( - data=data, media_type=media_type, identifier=identifier or _identifier, vendor_metadata=vendor_metadata + data=data, + media_type=media_type, + identifier=identifier or _identifier, + vendor_metadata=vendor_metadata, + metadata=metadata, + kind=kind, ) if not self.is_image: @@ -1031,6 +1043,9 @@ class TextPart: This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically.""" + metadata: Any = None + """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" + part_kind: Literal['text'] = 'text' """Part type identifier, this is available on all parts as a discriminator.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index c912334434..0226ed172d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3785,6 +3785,7 @@ def test_binary_content_serializable(): 'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, + 'metadata': None, 'kind': 'binary', 'identifier': 'f7ff9e', }, @@ -3800,7 +3801,13 @@ def test_binary_content_serializable(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 56, @@ -3862,7 +3869,13 @@ def test_image_url_serializable_missing_media_type(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 51, @@ -3931,7 +3944,13 @@ def test_image_url_serializable(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 51, @@ -3978,6 +3997,7 @@ def test_tool_return_part_binary_content_serialization(): 'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=', 'media_type': 'image/png', 'vendor_metadata': None, + 'metadata': None, '_identifier': None, 'kind': 'binary', } diff --git a/tests/test_messages.py b/tests/test_messages.py index d6d9617247..ae806fe962 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -457,6 +457,7 @@ def test_file_part_serialization_roundtrip(): 'data': 'ZmFrZQ==', 'media_type': 'image/jpeg', 'identifier': 'c053ec', + 'metadata': None, 'vendor_metadata': None, 'kind': 'binary', }, @@ -605,6 +606,7 @@ def test_binary_content_validation_with_optional_identifier(): 'data': b'fake', 'vendor_metadata': None, 'kind': 'binary', + 'metadata': None, 'media_type': 'image/jpeg', 'identifier': 'c053ec', } @@ -621,6 +623,7 @@ def test_binary_content_validation_with_optional_identifier(): 'data': b'fake', 'vendor_metadata': None, 'kind': 'binary', + 'metadata': None, 'media_type': 'image/png', 'identifier': 'foo', } From e021d47da7b3207fe34260d04f849eba12c4a6a7 Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Fri, 28 Nov 2025 07:49:42 +0900 Subject: [PATCH 3/6] chore: Added top-level ToolReturn metadata. --- pydantic_ai_slim/pydantic_ai/mcp.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 9feb7fad9e..d77954dbcf 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -496,7 +496,14 @@ async def direct_call_tool( return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value mapped = [await self._map_tool_result_part(part) for part in result.content] - return mapped[0] if len(mapped) == 1 else mapped + if result.meta: + return ( + messages.ToolReturn(return_value=mapped[0], metadata=result.meta) + if len(mapped) == 1 + else messages.ToolReturn(return_value=mapped, metadata=result.meta) + ) + else: + return mapped[0] if len(mapped) == 1 else mapped async def call_tool( self, From ae1b882c4e0d2bc2bb795f423b473b4800d69596 Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Fri, 28 Nov 2025 08:53:40 +0900 Subject: [PATCH 4/6] test: Added first metadata test with a single text part. --- pydantic_ai_slim/pydantic_ai/mcp.py | 15 +++++++++------ tests/mcp_server.py | 30 +++++++++++++++++++++++++++++ tests/test_mcp.py | 19 ++++++++++++++++-- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index d77954dbcf..a22571bc83 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -697,12 +697,15 @@ async def _map_tool_result_part( if isinstance(part, mcp_types.TextContent): text = part.text - if text.startswith(('[', '{')): - try: - return pydantic_core.from_json(text) - except ValueError: - pass - return text if not part.meta else messages.TextPart(content=text, metadata=part.meta) + if part.meta: + return messages.TextPart(content=text, metadata=part.meta) + else: + if text.startswith(('[', '{')): + try: + return pydantic_core.from_json(text) + except ValueError: + pass + return text elif isinstance(part, mcp_types.ImageContent): return messages.BinaryContent( data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 8ba9b9997f..b4fe3dea04 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -47,6 +47,36 @@ async def get_weather_forecast(location: str) -> str: return f'The weather in {location} is sunny and 26 degrees Celsius.' +@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator')) +async def get_collatz_conjecture(n: int) -> TextContent: + """Generate the Collatz conjecture sequence for a given number. + This tool attaches response metadata. + + Args: + n: The starting number for the Collatz sequence. + Returns: + A list representing the Collatz sequence with attached metadata. + """ + if n <= 0: + raise ValueError('Starting number for the Collatz conjecture must be a positive integer.') + + input_param_n = n # store the original input value + + sequence = [n] + while n != 1: + if n % 2 == 0: + n = n // 2 + else: + n = 3 * n + 1 + sequence.append(n) + + return TextContent( + type='text', + text=str(sequence), + _meta={'pydantic_ai': {'tool': 'collatz_conjecture', 'n': input_param_n, 'length': len(sequence)}}, + ) + + @mcp.tool() async def get_image_resource() -> EmbeddedResource: data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes() diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 221ad37548..ddcf666bc3 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -95,7 +95,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -105,6 +105,21 @@ async def test_stdio_server(run_context: RunContext[int]): assert result == snapshot(32.0) +async def test_tool_response_single_text_part_metadata(run_context: RunContext[int]): + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] + assert len(tools) == snapshot(19) + assert tools[2].name == 'get_collatz_conjecture' + assert isinstance(tools[2].description, str) + assert tools[2].description.startswith('Generate the Collatz conjecture sequence for a given number.') + + result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7}) + assert isinstance(result, TextPart) + assert result.content == snapshot('[7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]') + assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}}) + + async def test_reentrant_context_manager(): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: @@ -156,7 +171,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) async def test_process_tool_call(run_context: RunContext[int]) -> int: From e9d64756de2f4d29180e94a0704b9e3f7c88b6ac Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Fri, 28 Nov 2025 22:22:47 +0900 Subject: [PATCH 5/6] chore: Added a no-cover directive and a note why ToolResult.meta cannot be tested until FastMCP can be upgraded to 2.13.1. --- pydantic_ai_slim/pydantic_ai/mcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index a22571bc83..7c5cd823d6 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -497,7 +497,10 @@ async def direct_call_tool( mapped = [await self._map_tool_result_part(part) for part in result.content] if result.meta: - return ( + # The following branching cannot be tested until FastMCP is updated to version 2.13.1 + # such that the MCP server can generate ToolResult and result.meta can be specified. + # TODO: Add tests for the following branching once FastMCP is updated. + return ( # pragma: no cover messages.ToolReturn(return_value=mapped[0], metadata=result.meta) if len(mapped) == 1 else messages.ToolReturn(return_value=mapped, metadata=result.meta) From 191de6cb7e92d922dd3f33b855c2670d197cad6e Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Sat, 29 Nov 2025 12:42:14 +0900 Subject: [PATCH 6/6] feat: Added a TextContent. --- pydantic_ai_slim/pydantic_ai/mcp.py | 46 +++++++++++-------- pydantic_ai_slim/pydantic_ai/messages.py | 32 +++++++++++-- .../pydantic_ai/models/bedrock.py | 3 ++ pydantic_ai_slim/pydantic_ai/models/gemini.py | 3 ++ pydantic_ai_slim/pydantic_ai/models/google.py | 3 ++ .../pydantic_ai/models/huggingface.py | 3 ++ pydantic_ai_slim/pydantic_ai/models/openai.py | 10 ++++ tests/test_mcp.py | 3 +- 8 files changed, 79 insertions(+), 24 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 7c5cd823d6..353d411d93 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -490,23 +490,23 @@ async def direct_call_tool( ): # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function. # See https://github.com/modelcontextprotocol/python-sdk#structured-output - return_value = structured - if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured: - return_value = structured['result'] - return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value - - mapped = [await self._map_tool_result_part(part) for part in result.content] + return_value = ( + structured['result'] + if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured + else structured + ) + else: + mapped = [await self._map_tool_result_part(part) for part in result.content] + return_value = mapped[0] if len(mapped) == 1 else mapped if result.meta: # The following branching cannot be tested until FastMCP is updated to version 2.13.1 # such that the MCP server can generate ToolResult and result.meta can be specified. # TODO: Add tests for the following branching once FastMCP is updated. return ( # pragma: no cover - messages.ToolReturn(return_value=mapped[0], metadata=result.meta) - if len(mapped) == 1 - else messages.ToolReturn(return_value=mapped, metadata=result.meta) + messages.ToolReturn(return_value=return_value, metadata=result.meta) ) else: - return mapped[0] if len(mapped) == 1 else mapped + return return_value async def call_tool( self, @@ -587,16 +587,22 @@ async def list_resource_templates(self) -> list[ResourceTemplate]: @overload async def read_resource( self, uri: str - ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ... + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): ... @overload async def read_resource( self, uri: Resource - ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ... + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): ... async def read_resource( self, uri: str | Resource - ) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): """Read the contents of a specific resource by URI. Args: @@ -695,13 +701,13 @@ async def _sampling_callback( async def _map_tool_result_part( self, part: mcp_types.ContentBlock - ) -> str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]: + ) -> str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]: # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values if isinstance(part, mcp_types.TextContent): text = part.text if part.meta: - return messages.TextPart(content=text, metadata=part.meta) + return messages.TextContent(content=text, metadata=part.meta) else: if text.startswith(('[', '{')): try: @@ -729,10 +735,12 @@ async def _map_tool_result_part( def _get_content( self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents - ) -> str | messages.TextPart | messages.BinaryContent: + ) -> str | messages.TextContent | messages.BinaryContent: if isinstance(resource, mcp_types.TextResourceContents): return ( - resource.text if not resource.meta else messages.TextPart(content=resource.text, metadata=resource.meta) + resource.text + if not resource.meta + else messages.TextContent(content=resource.text, metadata=resource.meta) ) elif isinstance(resource, mcp_types.BlobResourceContents): return messages.BinaryContent( @@ -1200,12 +1208,12 @@ def __eq__(self, value: object, /) -> bool: ToolResult = ( str - | messages.TextPart + | messages.TextContent | messages.BinaryContent | messages.ToolReturn | dict[str, Any] | list[Any] - | Sequence[str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]] + | Sequence[str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]] ) """The result type of an MCP tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 5118817181..6fcb28ff2f 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -466,6 +466,33 @@ def format(self) -> DocumentFormat: raise ValueError(f'Unknown document media type: {media_type}') from e +@dataclass(repr=False) +class TextContent: + """A plain text response from a model with optional metadata.""" + + content: str + """The text content of the response.""" + + _: KW_ONLY + + provider_details: dict[str, Any] | None = None + """Additional data returned by the provider that can't be mapped to standard fields. + + This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically.""" + + metadata: Any = None + """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" + + kind: Literal['text'] = 'text' + """Type identifier, this is available as a discriminator.""" + + def has_content(self) -> bool: + """Return `True` if the text content is non-empty.""" + return bool(self.content) + + __repr__ = _utils.dataclasses_no_defaults_repr + + @dataclass(init=False, repr=False) class BinaryContent: """Binary content, e.g. an audio or image file.""" @@ -669,7 +696,7 @@ class CachePoint: MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent -UserContent: TypeAlias = str | MultiModalContent | CachePoint +UserContent: TypeAlias = str | TextContent | MultiModalContent | CachePoint @dataclass(repr=False) @@ -1043,9 +1070,6 @@ class TextPart: This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically.""" - metadata: Any = None - """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" - part_kind: Literal['text'] = 'text' """Part type identifier, this is available on all parts as a discriminator.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ff03460904..25f9446800 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -37,6 +37,7 @@ UserPromptPart, VideoUrl, _utils, + messages, usage, ) from pydantic_ai._run_context import RunContext @@ -628,6 +629,8 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, messages.TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): format = item.format if item.is_document: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4da92018fd..d4474756de 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -31,6 +31,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -374,6 +375,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion] for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') content.append( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 89290ea3ce..c37938464a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -30,6 +30,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -601,6 +602,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): inline_data_dict: BlobDict = {'data': item.data, 'mime_type': item.media_type} part_dict: PartDict = {'inline_data': inline_data_dict} diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 790b30bec3..2ac49116de 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -30,6 +30,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -433,6 +434,8 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: for item in part.content: if isinstance(item, str): content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, TextContent): + content.append(ChatCompletionInputMessageChunk(type='text', text=item.content)) # type: ignore elif isinstance(item, ImageUrl): url = ChatCompletionInputURL(url=item.url) # type: ignore content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 10af284ee8..c3a9ecdba4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -40,6 +40,7 @@ PartStartEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -914,6 +915,13 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse for item in part.content: if isinstance(item, str): content.append(ChatCompletionContentPartTextParam(text=item, type='text')) + elif isinstance(item, TextContent): + content.append( + ChatCompletionContentPartTextParam( + text=item.content, + type='text', + ) + ) elif isinstance(item, ImageUrl): image_url: ImageURL = {'url': item.url} if metadata := item.vendor_metadata: @@ -1754,6 +1762,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa for item in part.content: if isinstance(item, str): content.append(responses.ResponseInputTextParam(text=item, type='input_text')) + elif isinstance(item, TextContent): + content.append(responses.ResponseInputTextParam(text=item.content, type='input_text')) elif isinstance(item, BinaryContent): if item.is_image: detail: Literal['auto', 'low', 'high'] = 'auto' diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ddcf666bc3..0cb3acfc17 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -22,6 +22,7 @@ ToolCallPart, ToolReturnPart, UserPromptPart, + messages, ) from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ( @@ -115,7 +116,7 @@ async def test_tool_response_single_text_part_metadata(run_context: RunContext[i assert tools[2].description.startswith('Generate the Collatz conjecture sequence for a given number.') result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7}) - assert isinstance(result, TextPart) + assert isinstance(result, messages.TextContent) assert result.content == snapshot('[7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]') assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}})