Skip to content
76 changes: 54 additions & 22 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,12 +490,23 @@ async def direct_call_tool(
):
# The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function.
# See https://github.com/modelcontextprotocol/python-sdk#structured-output
if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured:
return structured['result']
return structured

mapped = [await self._map_tool_result_part(part) for part in result.content]
return mapped[0] if len(mapped) == 1 else mapped
return_value = (
structured['result']
if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured
else structured
)
else:
mapped = [await self._map_tool_result_part(part) for part in result.content]
return_value = mapped[0] if len(mapped) == 1 else mapped
if result.meta:
# The following branching cannot be tested until FastMCP is updated to version 2.13.1
# such that the MCP server can generate ToolResult and result.meta can be specified.
# TODO: Add tests for the following branching once FastMCP is updated.
return ( # pragma: no cover
messages.ToolReturn(return_value=return_value, metadata=result.meta)
)
else:
return return_value

async def call_tool(
self,
Expand Down Expand Up @@ -574,16 +585,24 @@ async def list_resource_templates(self) -> list[ResourceTemplate]:
return [ResourceTemplate.from_mcp_sdk(t) for t in result.resourceTemplates]

@overload
async def read_resource(self, uri: str) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ...
async def read_resource(
self, uri: str
) -> (
str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent]
): ...

@overload
async def read_resource(
self, uri: Resource
) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ...
) -> (
str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent]
): ...

async def read_resource(
self, uri: str | Resource
) -> str | messages.BinaryContent | list[str | messages.BinaryContent]:
) -> (
str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent]
):
"""Read the contents of a specific resource by URI.
Args:
Expand Down Expand Up @@ -682,24 +701,29 @@ async def _sampling_callback(

async def _map_tool_result_part(
self, part: mcp_types.ContentBlock
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
) -> str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]:
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

if isinstance(part, mcp_types.TextContent):
text = part.text
if text.startswith(('[', '{')):
try:
return pydantic_core.from_json(text)
except ValueError:
pass
return text
if part.meta:
return messages.TextContent(content=text, metadata=part.meta)
else:
if text.startswith(('[', '{')):
try:
return pydantic_core.from_json(text)
except ValueError:
pass
return text
elif isinstance(part, mcp_types.ImageContent):
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
return messages.BinaryContent(
data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta
)
elif isinstance(part, mcp_types.AudioContent):
# NOTE: The FastMCP server doesn't support audio content.
# See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
return messages.BinaryContent(
data=base64.b64decode(part.data), media_type=part.mimeType
data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta
) # pragma: no cover
elif isinstance(part, mcp_types.EmbeddedResource):
resource = part.resource
Expand All @@ -711,12 +735,18 @@ async def _map_tool_result_part(

def _get_content(
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
) -> str | messages.BinaryContent:
) -> str | messages.TextContent | messages.BinaryContent:
if isinstance(resource, mcp_types.TextResourceContents):
return resource.text
return (
resource.text
if not resource.meta
else messages.TextContent(content=resource.text, metadata=resource.meta)
)
elif isinstance(resource, mcp_types.BlobResourceContents):
return messages.BinaryContent(
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
data=base64.b64decode(resource.blob),
media_type=resource.mimeType or 'application/octet-stream',
metadata=resource.meta,
)
else:
assert_never(resource)
Expand Down Expand Up @@ -1178,10 +1208,12 @@ def __eq__(self, value: object, /) -> bool:

ToolResult = (
str
| messages.TextContent
| messages.BinaryContent
| messages.ToolReturn
| dict[str, Any]
| list[Any]
| Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
| Sequence[str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]]
)
"""The result type of an MCP tool call."""

Expand Down
43 changes: 41 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,33 @@ def format(self) -> DocumentFormat:
raise ValueError(f'Unknown document media type: {media_type}') from e


@dataclass(repr=False)
class TextContent:
"""A plain text response from a model with optional metadata."""

content: str
"""The text content of the response."""

_: KW_ONLY

provider_details: dict[str, Any] | None = None
"""Additional data returned by the provider that can't be mapped to standard fields.
This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""

metadata: Any = None
"""Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

kind: Literal['text'] = 'text'
"""Type identifier, this is available as a discriminator."""

def has_content(self) -> bool:
"""Return `True` if the text content is non-empty."""
return bool(self.content)

__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass(init=False, repr=False)
class BinaryContent:
"""Binary content, e.g. an audio or image file."""
Expand All @@ -486,6 +513,9 @@ class BinaryContent:
- `OpenAIChatModel`, `OpenAIResponsesModel`: `BinaryContent.vendor_metadata['detail']` is used as `detail` setting for images
"""

metadata: Any = None
"""Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

_identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field(
compare=False, default=None
)
Expand All @@ -500,6 +530,7 @@ def __init__(
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
metadata: Any = None,
kind: Literal['binary'] = 'binary',
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_identifier: str | None = None,
Expand All @@ -508,6 +539,7 @@ def __init__(
self.media_type = media_type
self._identifier = identifier or _identifier
self.vendor_metadata = vendor_metadata
self.metadata = metadata
self.kind = kind

@staticmethod
Expand All @@ -519,6 +551,7 @@ def narrow_type(bc: BinaryContent) -> BinaryContent | BinaryImage:
media_type=bc.media_type,
identifier=bc.identifier,
vendor_metadata=bc.vendor_metadata,
metadata=bc.metadata,
)
else:
return bc
Expand Down Expand Up @@ -622,11 +655,17 @@ def __init__(
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
metadata: Any = None,
kind: Literal['binary'] = 'binary',
_identifier: str | None = None,
):
super().__init__(
data=data, media_type=media_type, identifier=identifier or _identifier, vendor_metadata=vendor_metadata
data=data,
media_type=media_type,
identifier=identifier or _identifier,
vendor_metadata=vendor_metadata,
metadata=metadata,
kind=kind,
)

if not self.is_image:
Expand Down Expand Up @@ -657,7 +696,7 @@ class CachePoint:


MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
UserContent: TypeAlias = str | MultiModalContent | CachePoint
UserContent: TypeAlias = str | TextContent | MultiModalContent | CachePoint


@dataclass(repr=False)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
UserPromptPart,
VideoUrl,
_utils,
messages,
usage,
)
from pydantic_ai._run_context import RunContext
Expand Down Expand Up @@ -628,6 +629,8 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
for item in part.content:
if isinstance(item, str):
content.append({'text': item})
elif isinstance(item, messages.TextContent):
content.append({'text': item.content})
elif isinstance(item, BinaryContent):
format = item.format
if item.is_document:
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelResponseStreamEvent,
RetryPromptPart,
SystemPromptPart,
TextContent,
TextPart,
ThinkingPart,
ToolCallPart,
Expand Down Expand Up @@ -374,6 +375,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]
for item in part.content:
if isinstance(item, str):
content.append({'text': item})
elif isinstance(item, TextContent):
content.append({'text': item.content})
elif isinstance(item, BinaryContent):
base64_encoded = base64.b64encode(item.data).decode('utf-8')
content.append(
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelResponseStreamEvent,
RetryPromptPart,
SystemPromptPart,
TextContent,
TextPart,
ThinkingPart,
ToolCallPart,
Expand Down Expand Up @@ -601,6 +602,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
for item in part.content:
if isinstance(item, str):
content.append({'text': item})
elif isinstance(item, TextContent):
content.append({'text': item.content})
elif isinstance(item, BinaryContent):
inline_data_dict: BlobDict = {'data': item.data, 'mime_type': item.media_type}
part_dict: PartDict = {'inline_data': inline_data_dict}
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ModelResponseStreamEvent,
RetryPromptPart,
SystemPromptPart,
TextContent,
TextPart,
ThinkingPart,
ToolCallPart,
Expand Down Expand Up @@ -433,6 +434,8 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
for item in part.content:
if isinstance(item, str):
content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore
elif isinstance(item, TextContent):
content.append(ChatCompletionInputMessageChunk(type='text', text=item.content)) # type: ignore
elif isinstance(item, ImageUrl):
url = ChatCompletionInputURL(url=item.url) # type: ignore
content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore
Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
PartStartEvent,
RetryPromptPart,
SystemPromptPart,
TextContent,
TextPart,
ThinkingPart,
ToolCallPart,
Expand Down Expand Up @@ -914,6 +915,13 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse
for item in part.content:
if isinstance(item, str):
content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
elif isinstance(item, TextContent):
content.append(
ChatCompletionContentPartTextParam(
text=item.content,
type='text',
)
)
elif isinstance(item, ImageUrl):
image_url: ImageURL = {'url': item.url}
if metadata := item.vendor_metadata:
Expand Down Expand Up @@ -1754,6 +1762,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
for item in part.content:
if isinstance(item, str):
content.append(responses.ResponseInputTextParam(text=item, type='input_text'))
elif isinstance(item, TextContent):
content.append(responses.ResponseInputTextParam(text=item.content, type='input_text'))
elif isinstance(item, BinaryContent):
if item.is_image:
detail: Literal['auto', 'low', 'high'] = 'auto'
Expand Down
30 changes: 30 additions & 0 deletions tests/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ async def get_weather_forecast(location: str) -> str:
return f'The weather in {location} is sunny and 26 degrees Celsius.'


@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator'))
async def get_collatz_conjecture(n: int) -> TextContent:
"""Generate the Collatz conjecture sequence for a given number.
This tool attaches response metadata.
Args:
n: The starting number for the Collatz sequence.
Returns:
A list representing the Collatz sequence with attached metadata.
"""
if n <= 0:
raise ValueError('Starting number for the Collatz conjecture must be a positive integer.')

input_param_n = n # store the original input value

sequence = [n]
while n != 1:
if n % 2 == 0:
n = n // 2
else:
n = 3 * n + 1
sequence.append(n)

return TextContent(
type='text',
text=str(sequence),
_meta={'pydantic_ai': {'tool': 'collatz_conjecture', 'n': input_param_n, 'length': len(sequence)}},
)


@mcp.tool()
async def get_image_resource() -> EmbeddedResource:
data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes()
Expand Down
Loading
Loading