diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
index 2c03cbdae31b..3e619f28d476 100644
--- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
+++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
@@ -1277,15 +1277,14 @@ def _create_prefixed_tools(
             name_to_use = prefixed_name if add_prefix else tool.name
 
-            tool_obj = MCPTool(
-                name=name_to_use,
-                description=tool.description,
-                inputSchema=tool.inputSchema,
-            )
-            prefixed_tools.append(tool_obj)
+            # Preserve all tool fields including metadata/_meta by mutating the original tool
+            # Similar to how _create_prefixed_prompts works
+            original_name = tool.name
+            tool.name = name_to_use
+            prefixed_tools.append(tool)
 
             # Update tool to server mapping for resolution (support both forms)
-            self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
+            self.tool_name_to_mcp_server_name_mapping[original_name] = prefix
             self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix
 
             verbose_logger.info(
diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py
index 412a6de0059a..edf53e99573b 100644
--- a/litellm/proxy/_experimental/mcp_server/server.py
+++ b/litellm/proxy/_experimental/mcp_server/server.py
@@ -6,7 +6,7 @@
 import asyncio
 import contextlib
 from datetime import datetime
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union
+from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union, cast
 
 from fastapi import FastAPI, HTTPException
 from pydantic import AnyUrl, ConfigDict
@@ -72,7 +72,13 @@
         auth_context_var,
     )
     from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
-    from mcp.types import EmbeddedResource, ImageContent, Prompt, TextContent
+    from mcp.types import (
+        CallToolResult,
+        EmbeddedResource,
+        ImageContent,
+        Prompt,
+        TextContent,
+    )
     from mcp.types import Tool as MCPTool
 
     from litellm.proxy._experimental.mcp_server.auth.litellm_auth_handler import (
@@ -234,7 +240,7 @@ async def list_tools() -> List[MCPTool]:
     @server.call_tool()
     async def mcp_server_tool_call(
         name: str, arguments: Dict[str, Any] | None
-    ) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
+    ) -> CallToolResult:
         """
         Call a specific tool with the provided arguments
 
@@ -300,26 +306,37 @@ async def mcp_server_tool_call(
             )
         except BlockedPiiEntityError as e:
             verbose_logger.error(f"BlockedPiiEntityError in MCP tool call: {str(e)}")
-            # Return error as text content for MCP protocol
-            return [
-                TextContent(
-                    text=f"Error: Blocked PII entity detected - {str(e)}", type="text"
-                )
-            ]
+            return CallToolResult(
+                content=[
+                    TextContent(
+                        text=f"Error: Blocked PII entity detected - {str(e)}",
+                        type="text",
+                    )
+                ],
+                isError=True,
+            )
         except GuardrailRaisedException as e:
             verbose_logger.error(f"GuardrailRaisedException in MCP tool call: {str(e)}")
-            # Return error as text content for MCP protocol
-            return [
-                TextContent(text=f"Error: Guardrail violation - {str(e)}", type="text")
-            ]
+            return CallToolResult(
+                content=[
+                    TextContent(
+                        text=f"Error: Guardrail violation - {str(e)}", type="text"
+                    )
+                ],
+                isError=True,
+            )
         except HTTPException as e:
             verbose_logger.error(f"HTTPException in MCP tool call: {str(e)}")
-            # Return error as text content for MCP protocol
-            return [TextContent(text=f"Error: {str(e.detail)}", type="text")]
+            return CallToolResult(
+                content=[TextContent(text=f"Error: {str(e.detail)}", type="text")],
+                isError=True,
+            )
         except Exception as e:
             verbose_logger.exception(f"MCP mcp_server_tool_call - error: {e}")
-            # Return error as text content for MCP protocol
-            return [TextContent(text=f"Error: {str(e)}", type="text")]
+            return CallToolResult(
+                content=[TextContent(text=f"Error: {str(e)}", type="text")],
+                isError=True,
+            )
 
         return response
 
@@ -1173,7 +1190,7 @@ async def call_mcp_tool(
         oauth2_headers: Optional[Dict[str, str]] = None,
         raw_headers: Optional[Dict[str, str]] = None,
         **kwargs: Any,
-    ) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
+    ) -> CallToolResult:
         """
         Call a specific tool with the provided arguments (handles prefixed tool names)
         """
@@ -1237,9 +1254,9 @@ async def call_mcp_tool(
             "litellm_logging_obj", None
         )
         if litellm_logging_obj:
-            litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
-                standard_logging_mcp_tool_call
-            )
+            litellm_logging_obj.model_call_details[
+                "mcp_tool_call_metadata"
+            ] = standard_logging_mcp_tool_call
             litellm_logging_obj.model = f"MCP: {name}"
         # Check if tool exists in local registry first (for OpenAPI-based tools)
         # These tools are registered with their prefixed names
@@ -1247,7 +1264,8 @@ async def call_mcp_tool(
         local_tool = global_mcp_tool_registry.get_tool(name)
         if local_tool:
             verbose_logger.debug(f"Executing local registry tool: {name}")
-            response = await _handle_local_mcp_tool(name, arguments)
+            local_content = await _handle_local_mcp_tool(name, arguments)
+            response = CallToolResult(content=cast(Any, local_content), isError=False)
 
         # Try managed MCP server tool (pass the full prefixed name)
         # Primary and recommended way to use external MCP servers
@@ -1279,7 +1297,12 @@ async def call_mcp_tool(
         # Deprecated: Local MCP Server Tool
         #########################################################
         else:
-            response = await _handle_local_mcp_tool(original_tool_name, arguments)
+            local_content = await _handle_local_mcp_tool(
+                original_tool_name, arguments
+            )
+            response = CallToolResult(
+                content=cast(Any, local_content), isError=False
+            )
 
         #########################################################
         # Post MCP Tool Call Hook
@@ -1432,7 +1455,7 @@ async def _handle_managed_mcp_tool(
         oauth2_headers: Optional[Dict[str, str]] = None,
         raw_headers: Optional[Dict[str, str]] = None,
         litellm_logging_obj: Optional[Any] = None,
-    ) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
+    ) -> CallToolResult:
         """Handle tool execution for managed server tools"""
         # Import here to avoid circular import
         from litellm.proxy.proxy_server import proxy_logging_obj
@@ -1449,7 +1472,7 @@ async def _handle_managed_mcp_tool(
             proxy_logging_obj=proxy_logging_obj,
         )
         verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
-        return call_tool_result.content  # type: ignore[return-value]
+        return call_tool_result
 
     async def _handle_local_mcp_tool(
         name: str, arguments: Dict[str, Any]
@@ -1741,14 +1764,16 @@ def set_auth_context(
         )
         auth_context_var.set(auth_user)
 
-    def get_auth_context() -> Tuple[
-        Optional[UserAPIKeyAuth],
-        Optional[str],
-        Optional[List[str]],
-        Optional[Dict[str, Dict[str, str]]],
-        Optional[Dict[str, str]],
-        Optional[Dict[str, str]],
-    ]:
+    def get_auth_context() -> (
+        Tuple[
+            Optional[UserAPIKeyAuth],
+            Optional[str],
+            Optional[List[str]],
+            Optional[Dict[str, Dict[str, str]]],
+            Optional[Dict[str, str]],
+            Optional[Dict[str, str]],
+        ]
+    ):
         """
         Get the UserAPIKeyAuth from the auth context variable.
diff --git a/tests/mcp_tests/test_mcp_logging.py b/tests/mcp_tests/test_mcp_logging.py
index 44acd60ad495..47de30779796 100644
--- a/tests/mcp_tests/test_mcp_logging.py
+++ b/tests/mcp_tests/test_mcp_logging.py
@@ -109,7 +109,11 @@ def mock_client_constructor(*args, **kwargs):
 
     # Add assertions
    assert response is not None
-    response_list = list(response)  # Convert iterable to list
+    # Handle CallToolResult - access .content for the list of content items
+    if isinstance(response, CallToolResult):
+        response_list = response.content
+    else:
+        response_list = list(response)  # Convert iterable to list for backward compatibility
     assert len(response_list) == 1
     assert isinstance(response_list[0], TextContent)
     assert response_list[0].text == "Test response"
diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_metadata_preservation.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_metadata_preservation.py
new file mode 100644
index 000000000000..6311b6d74b26
--- /dev/null
+++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_metadata_preservation.py
@@ -0,0 +1,87 @@
+"""
+Tests for MCP metadata preservation.
+
+This module tests that tool metadata is preserved when creating prefixed tools,
+which is critical for ChatGPT UI widget rendering.
+"""
+
+import sys
+
+import pytest
+
+# Add the parent directory to the path so we can import litellm
+sys.path.insert(0, "../../../../../")
+
+from mcp.types import Tool as MCPTool
+
+from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager
+from litellm.proxy._types import MCPTransport
+from litellm.types.mcp_server.mcp_server_manager import MCPServer
+
+
+class TestMCPMetadataPreservation:
+    """Test that metadata is preserved when creating prefixed tools"""
+
+    def test_create_prefixed_tools_preserves_metadata(self):
+        """Test that _create_prefixed_tools preserves metadata and _meta fields"""
+        manager = MCPServerManager()
+
+        # Create a mock server
+        mock_server = MCPServer(
+            server_id="test-server-1",
+            name="test_server",
+            alias="test",
+            server_name="Test Server",
+            url="https://test-server.com/mcp",
+            transport=MCPTransport.http,
+        )
+
+        # Create a tool with metadata
+        tool_with_metadata = MCPTool(
+            name="hello_widget",
+            description="Display a greeting widget",
+            inputSchema={"type": "object", "properties": {}},
+        )
+        # Add metadata using setattr since MCPTool might not have it in the constructor
+        tool_with_metadata.metadata = {
+            "openai/outputTemplate": "ui://widget/hello.html",
+            "openai/widgetDescription": "A greeting widget",
+        }
+        tool_with_metadata._meta = {
+            "openai/toolInvocation/invoking": "Preparing greeting...",
+        }
+
+        # Create prefixed tools
+        prefixed_tools = manager._create_prefixed_tools(
+            [tool_with_metadata], mock_server, add_prefix=True
+        )
+
+        # Verify
+        assert len(prefixed_tools) == 1
+        prefixed_tool = prefixed_tools[0]
+
+        # Check that name is prefixed
+        assert prefixed_tool.name == "test-hello_widget"
+
+        # Check that metadata is preserved
+        assert hasattr(prefixed_tool, "metadata")
+        assert prefixed_tool.metadata == {
+            "openai/outputTemplate": "ui://widget/hello.html",
+            "openai/widgetDescription": "A greeting widget",
+        }
+
+        # Check that _meta is preserved
+        assert hasattr(prefixed_tool, "_meta")
+        assert prefixed_tool._meta == {
+            "openai/toolInvocation/invoking": "Preparing greeting...",
+        }
+
+        # Check that other fields are preserved
+        assert prefixed_tool.description == "Display a greeting widget"
+        assert prefixed_tool.inputSchema == {"type": "object", "properties": {}}
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
+