Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/mcp_agent/workflows/llm/augmented_llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,7 @@ async def generate(

if params.use_history:
messages.extend(self.history.get())
messages.extend(
AnthropicConverter.convert_mixed_messages_to_anthropic(message)
)
messages.extend(AnthropicConverter.from_mixed_messages(message))

list_tools_result = await self.agent.list_tools()
available_tools: List[ToolParam] = [
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def generate(self, message, request_params: RequestParams | None = None):
messages.append(SystemMessage(content=system_prompt))
span.set_attribute("system_prompt", system_prompt)

messages.extend(AzureConverter.convert_mixed_messages_to_azure(message))
messages.extend(AzureConverter.from_mixed_messages(message))

response = await self.agent.list_tools()

Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def generate(self, message, request_params: RequestParams | None = None):
if params.use_history:
messages.extend(self.history.get())

messages.extend(BedrockConverter.convert_mixed_messages_to_bedrock(message))
messages.extend(BedrockConverter.from_mixed_messages(message))

response = await self.agent.list_tools()

Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def generate(self, message, request_params: RequestParams | None = None):
if params.use_history:
messages.extend(self.history.get())

messages.extend(GoogleConverter.convert_mixed_messages_to_google(message))
messages.extend(GoogleConverter.from_mixed_messages(message))

response = await self.agent.list_tools()

Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def generate(
role="system", content=system_prompt
)
)
messages.extend((OpenAIConverter.convert_mixed_messages_to_openai(message)))
messages.extend((OpenAIConverter.from_mixed_messages(message)))

response: ListToolsResult = await self.agent.list_tools()
available_tools: List[ChatCompletionToolParam] = [
Expand Down
32 changes: 32 additions & 0 deletions src/mcp_agent/workflows/llm/multipart_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Generic, Protocol

from mcp_agent.utils.prompt_message_multipart import PromptMessageMultipart
from mcp.types import PromptMessage, CallToolResult
from mcp_agent.workflows.llm.augmented_llm import MessageTypes
from mcp_agent.workflows.llm.augmented_llm import MessageParamT, MessageT


class MessageConverter(Protocol, Generic[MessageParamT, MessageT]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want an API here for converting from provider message type to PromptMessage? i.e. a to_prompt_message?

@staticmethod
def from_prompt_message_multipart(
multipart_msg: PromptMessageMultipart, concatenate_text_blocks: bool = False
) -> MessageParamT:
"""Convert a PromptMessageMultipart to a Provider-compatible message param type"""
...

@staticmethod
def from_prompt_message(message: PromptMessage) -> MessageParamT:
"""Convert a MCP PromptMessage to a Provider-compatible message param type"""
...

@staticmethod
def from_mixed_messages(message: MessageTypes) -> list[MessageParamT]:
"""Convert a mixed message type to a list of Provider-compatible message param types"""
...

@staticmethod
def from_tool_results(
tool_results: list[tuple[str, CallToolResult]],
) -> MessageParamT | list[MessageParamT]:
"""Convert a list of MCP CallToolResult to Provider-compatible message param type"""
...
26 changes: 13 additions & 13 deletions src/mcp_agent/workflows/llm/multipart_converter_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ToolResultBlockParam,
URLImageSourceParam,
URLPDFSourceParam,
Message,
)
from mcp.types import (
BlobResourceContents,
Expand Down Expand Up @@ -40,14 +41,15 @@
from mcp_agent.utils.prompt_message_multipart import PromptMessageMultipart
from mcp_agent.utils.resource_utils import extract_title_from_uri
from mcp_agent.workflows.llm.augmented_llm import MessageTypes
from mcp_agent.workflows.llm.multipart_converter import MessageConverter

_logger = get_logger("multipart_converter_anthropic")

# List of image MIME types supported by Anthropic API
SUPPORTED_IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}


class AnthropicConverter:
class AnthropicConverter(MessageConverter[MessageParam, Message]):
"""Converts MCP message types to Anthropic API format."""

@staticmethod
Expand All @@ -63,7 +65,9 @@ def _is_supported_image_type(mime_type: str) -> bool:
return mime_type in SUPPORTED_IMAGE_MIME_TYPES

@staticmethod
def convert_to_anthropic(multipart_msg: PromptMessageMultipart) -> MessageParam:
def from_prompt_message_multipart(
multipart_msg: PromptMessageMultipart,
) -> MessageParam:
"""
Convert a PromptMessageMultipart message to Anthropic API format.

Expand Down Expand Up @@ -100,7 +104,7 @@ def convert_to_anthropic(multipart_msg: PromptMessageMultipart) -> MessageParam:
return MessageParam(role=role, content=anthropic_blocks)

@staticmethod
def convert_prompt_message_to_anthropic(message: PromptMessage) -> MessageParam:
def from_prompt_message(message: PromptMessage) -> MessageParam:
"""
Convert a standard PromptMessage to Anthropic API format.

Expand All @@ -114,7 +118,7 @@ def convert_prompt_message_to_anthropic(message: PromptMessage) -> MessageParam:
multipart = PromptMessageMultipart(role=message.role, content=[message.content])

# Use the existing conversion method
return AnthropicConverter.convert_to_anthropic(multipart)
return AnthropicConverter.from_prompt_message_multipart(multipart)

@staticmethod
def _convert_content_items(
Expand Down Expand Up @@ -364,7 +368,7 @@ def _create_fallback_text(
return TextBlockParam(type="text", text=f"[{message}]")

@staticmethod
def convert_tool_result_to_anthropic(
def from_tool_result(
tool_result: CallToolResult, tool_use_id: str
) -> ToolResultBlockParam:
"""
Expand Down Expand Up @@ -409,7 +413,7 @@ def convert_tool_result_to_anthropic(
)

@staticmethod
def create_tool_results_message(
def from_tool_results(
tool_results: List[tuple[str, CallToolResult]],
) -> MessageParam:
"""
Expand Down Expand Up @@ -482,7 +486,7 @@ def create_tool_results_message(
return MessageParam(role="user", content=content_blocks)

@staticmethod
def convert_mixed_messages_to_anthropic(
def from_mixed_messages(
message: MessageTypes,
) -> List[MessageParam]:
"""
Expand All @@ -499,15 +503,11 @@ def convert_mixed_messages_to_anthropic(
if isinstance(message, str):
messages.append(MessageParam(role="user", content=message))
elif isinstance(message, PromptMessage):
messages.append(
AnthropicConverter.convert_prompt_message_to_anthropic(message)
)
messages.append(AnthropicConverter.from_prompt_message(message))
elif isinstance(message, list):
for m in message:
if isinstance(m, PromptMessage):
messages.append(
AnthropicConverter.convert_prompt_message_to_anthropic(m)
)
messages.append(AnthropicConverter.from_prompt_message(m))
elif isinstance(m, str):
messages.append(MessageParam(role="user", content=m))
else:
Expand Down
41 changes: 22 additions & 19 deletions src/mcp_agent/workflows/llm/multipart_converter_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AssistantMessage,
ToolMessage,
DeveloperMessage,
ChatResponseMessage,
)
from mcp.types import (
BlobResourceContents,
Expand Down Expand Up @@ -38,21 +39,31 @@
from mcp_agent.utils.prompt_message_multipart import PromptMessageMultipart
from mcp_agent.utils.resource_utils import extract_title_from_uri
from mcp_agent.workflows.llm.augmented_llm import MessageTypes
from mcp_agent.workflows.llm.multipart_converter import MessageConverter

_logger = get_logger("multipart_converter_azure")

AzureMessageParam = Union[
SystemMessage, UserMessage, AssistantMessage, ToolMessage, DeveloperMessage
]

SUPPORTED_IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}


class AzureConverter:
class AzureConverter(
MessageConverter[
AzureMessageParam,
ChatResponseMessage,
]
):
"""Converts MCP message types to Azure API format."""

@staticmethod
def _is_supported_image_type(mime_type: str) -> bool:
return mime_type in SUPPORTED_IMAGE_MIME_TYPES

@staticmethod
def convert_to_azure(
def from_prompt_message_multipart(
multipart_msg: PromptMessageMultipart,
) -> UserMessage | AssistantMessage:
"""
Expand Down Expand Up @@ -92,7 +103,7 @@ def convert_to_azure(
return UserMessage(content=content)

@staticmethod
def convert_prompt_message_to_azure(
def from_prompt_message(
message: PromptMessage,
) -> UserMessage | AssistantMessage:
"""
Expand All @@ -105,7 +116,7 @@ def convert_prompt_message_to_azure(
An Azure UserMessage or AssistantMessage object
"""
multipart = PromptMessageMultipart(role=message.role, content=[message.content])
return AzureConverter.convert_to_azure(multipart)
return AzureConverter.from_prompt_message_multipart(multipart)

@staticmethod
def _convert_content_items(
Expand Down Expand Up @@ -248,9 +259,7 @@ def _create_fallback_text(
return TextContentItem(text=f"[{message}]")

@staticmethod
def convert_tool_result_to_azure(
tool_result: CallToolResult, tool_use_id: str
) -> ToolMessage:
def from_tool_result(tool_result: CallToolResult, tool_use_id: str) -> ToolMessage:
"""
Convert an MCP CallToolResult to an Azure ToolMessage.

Expand Down Expand Up @@ -308,7 +317,7 @@ def _extract_text_from_azure_content_blocks(
return "\n".join(texts)

@staticmethod
def create_tool_results_message(
def from_tool_results(
tool_results: List[tuple[str, CallToolResult]],
) -> List[ToolMessage]:
"""
Expand All @@ -322,20 +331,14 @@ def create_tool_results_message(
"""
tool_messages = []
for tool_use_id, result in tool_results:
tool_message = AzureConverter.convert_tool_result_to_azure(
result, tool_use_id
)
tool_message = AzureConverter.from_tool_result(result, tool_use_id)
tool_messages.append(tool_message)
return tool_messages

@staticmethod
def convert_mixed_messages_to_azure(
def from_mixed_messages(
message: MessageTypes,
) -> List[
Union[
SystemMessage, UserMessage, AssistantMessage, ToolMessage, DeveloperMessage
]
]:
) -> List[AzureMessageParam]:
"""
Convert a list of mixed messages to a list of Azure-compatible messages.

Expand All @@ -351,11 +354,11 @@ def convert_mixed_messages_to_azure(
if isinstance(message, str):
messages.append(UserMessage(content=message))
elif isinstance(message, PromptMessage):
messages.append(AzureConverter.convert_prompt_message_to_azure(message))
messages.append(AzureConverter.from_prompt_message(message))
elif isinstance(message, list):
for m in message:
if isinstance(m, PromptMessage):
messages.append(AzureConverter.convert_prompt_message_to_azure(m))
messages.append(AzureConverter.from_prompt_message(m))
elif isinstance(m, str):
messages.append(UserMessage(content=m))
else:
Expand Down
21 changes: 10 additions & 11 deletions src/mcp_agent/workflows/llm/multipart_converter_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from mcp_agent.utils.prompt_message_multipart import PromptMessageMultipart
from mcp_agent.utils.resource_utils import extract_title_from_uri
from mcp_agent.workflows.llm.augmented_llm import MessageTypes
from mcp_agent.workflows.llm.multipart_converter import MessageConverter

if TYPE_CHECKING:
from mypy_boto3_bedrock_runtime.type_defs import (
Expand All @@ -44,7 +45,7 @@
SUPPORTED_IMAGE_MIME_TYPES = {"image/jpeg", "image/png"}


class BedrockConverter:
class BedrockConverter(MessageConverter[MessageUnionTypeDef, MessageUnionTypeDef]):
"""Converts MCP message types to Amazon Bedrock API format."""

@staticmethod
Expand All @@ -53,7 +54,7 @@ def _is_supported_image_type(mime_type: str) -> bool:
return mime_type in SUPPORTED_IMAGE_MIME_TYPES

@staticmethod
def convert_to_bedrock(
def from_prompt_message_multipart(
multipart_msg: PromptMessageMultipart,
) -> MessageUnionTypeDef:
"""
Expand All @@ -68,14 +69,14 @@ def convert_to_bedrock(
return {"role": role, "content": bedrock_blocks}

@staticmethod
def convert_prompt_message_to_bedrock(
def from_prompt_message(
message: PromptMessage,
) -> MessageUnionTypeDef:
"""
Convert a standard PromptMessage to Bedrock API format.
"""
multipart = PromptMessageMultipart(role=message.role, content=[message.content])
return BedrockConverter.convert_to_bedrock(multipart)
return BedrockConverter.from_prompt_message_multipart(multipart)

@staticmethod
def _convert_content_items(
Expand Down Expand Up @@ -227,7 +228,7 @@ def _create_fallback_text(
return {"text": f"[{message}]"}

@staticmethod
def convert_tool_result_to_bedrock(
def from_tool_result(
tool_result: CallToolResult, tool_use_id: str
) -> ToolResultBlockTypeDef:
"""
Expand All @@ -245,7 +246,7 @@ def convert_tool_result_to_bedrock(
}

@staticmethod
def create_tool_results_message(
def from_tool_results(
tool_results: List[tuple[str, CallToolResult]],
) -> MessageUnionTypeDef:
"""
Expand All @@ -268,7 +269,7 @@ def create_tool_results_message(
return {"role": "user", "content": content_blocks}

@staticmethod
def convert_mixed_messages_to_bedrock(
def from_mixed_messages(
message: MessageTypes,
) -> List[MessageUnionTypeDef]:
"""
Expand All @@ -286,13 +287,11 @@ def convert_mixed_messages_to_bedrock(
if isinstance(message, str):
messages.append({"role": "user", "content": [{"text": message}]})
elif isinstance(message, PromptMessage):
messages.append(BedrockConverter.convert_prompt_message_to_bedrock(message))
messages.append(BedrockConverter.from_prompt_message(message))
elif isinstance(message, list):
for m in message:
if isinstance(m, PromptMessage):
messages.append(
BedrockConverter.convert_prompt_message_to_bedrock(m)
)
messages.append(BedrockConverter.from_prompt_message(m))
elif isinstance(m, str):
messages.append({"role": "user", "content": [{"text": m}]})
else:
Expand Down
Loading
Loading