Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,15 +1277,14 @@ def _create_prefixed_tools(

name_to_use = prefixed_name if add_prefix else tool.name

tool_obj = MCPTool(
name=name_to_use,
description=tool.description,
inputSchema=tool.inputSchema,
)
prefixed_tools.append(tool_obj)
# Preserve all tool fields including metadata/_meta by mutating the original tool
# Similar to how _create_prefixed_prompts works
original_name = tool.name
tool.name = name_to_use
prefixed_tools.append(tool)

# Update tool to server mapping for resolution (support both forms)
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
self.tool_name_to_mcp_server_name_mapping[original_name] = prefix
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix

verbose_logger.info(
Expand Down
91 changes: 58 additions & 33 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import asyncio
import contextlib
from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union, cast

from fastapi import FastAPI, HTTPException
from pydantic import AnyUrl, ConfigDict
Expand Down Expand Up @@ -72,7 +72,13 @@
auth_context_var,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import EmbeddedResource, ImageContent, Prompt, TextContent
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
Prompt,
TextContent,
)
from mcp.types import Tool as MCPTool

from litellm.proxy._experimental.mcp_server.auth.litellm_auth_handler import (
Expand Down Expand Up @@ -234,7 +240,7 @@ async def list_tools() -> List[MCPTool]:
@server.call_tool()
async def mcp_server_tool_call(
name: str, arguments: Dict[str, Any] | None
) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
) -> CallToolResult:
"""
Call a specific tool with the provided arguments

Expand Down Expand Up @@ -300,26 +306,37 @@ async def mcp_server_tool_call(
)
except BlockedPiiEntityError as e:
verbose_logger.error(f"BlockedPiiEntityError in MCP tool call: {str(e)}")
# Return error as text content for MCP protocol
return [
TextContent(
text=f"Error: Blocked PII entity detected - {str(e)}", type="text"
)
]
return CallToolResult(
content=[
TextContent(
text=f"Error: Blocked PII entity detected - {str(e)}",
type="text",
)
],
isError=True,
)
except GuardrailRaisedException as e:
verbose_logger.error(f"GuardrailRaisedException in MCP tool call: {str(e)}")
# Return error as text content for MCP protocol
return [
TextContent(text=f"Error: Guardrail violation - {str(e)}", type="text")
]
return CallToolResult(
content=[
TextContent(
text=f"Error: Guardrail violation - {str(e)}", type="text"
)
],
isError=True,
)
except HTTPException as e:
verbose_logger.error(f"HTTPException in MCP tool call: {str(e)}")
# Return error as text content for MCP protocol
return [TextContent(text=f"Error: {str(e.detail)}", type="text")]
return CallToolResult(
content=[TextContent(text=f"Error: {str(e.detail)}", type="text")],
isError=True,
)
except Exception as e:
verbose_logger.exception(f"MCP mcp_server_tool_call - error: {e}")
# Return error as text content for MCP protocol
return [TextContent(text=f"Error: {str(e)}", type="text")]
return CallToolResult(
content=[TextContent(text=f"Error: {str(e)}", type="text")],
isError=True,
)

return response

Expand Down Expand Up @@ -1173,7 +1190,7 @@ async def call_mcp_tool(
oauth2_headers: Optional[Dict[str, str]] = None,
raw_headers: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
) -> CallToolResult:
"""
Call a specific tool with the provided arguments (handles prefixed tool names)
"""
Expand Down Expand Up @@ -1237,17 +1254,18 @@ async def call_mcp_tool(
"litellm_logging_obj", None
)
if litellm_logging_obj:
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
standard_logging_mcp_tool_call
)
litellm_logging_obj.model_call_details[
"mcp_tool_call_metadata"
] = standard_logging_mcp_tool_call
litellm_logging_obj.model = f"MCP: {name}"
# Check if tool exists in local registry first (for OpenAPI-based tools)
# These tools are registered with their prefixed names
#########################################################
local_tool = global_mcp_tool_registry.get_tool(name)
if local_tool:
verbose_logger.debug(f"Executing local registry tool: {name}")
response = await _handle_local_mcp_tool(name, arguments)
local_content = await _handle_local_mcp_tool(name, arguments)
response = CallToolResult(content=cast(Any, local_content), isError=False)

# Try managed MCP server tool (pass the full prefixed name)
# Primary and recommended way to use external MCP servers
Expand Down Expand Up @@ -1279,7 +1297,12 @@ async def call_mcp_tool(
# Deprecated: Local MCP Server Tool
#########################################################
else:
response = await _handle_local_mcp_tool(original_tool_name, arguments)
local_content = await _handle_local_mcp_tool(
original_tool_name, arguments
)
response = CallToolResult(
content=cast(Any, local_content), isError=False
)

#########################################################
# Post MCP Tool Call Hook
Expand Down Expand Up @@ -1432,7 +1455,7 @@ async def _handle_managed_mcp_tool(
oauth2_headers: Optional[Dict[str, str]] = None,
raw_headers: Optional[Dict[str, str]] = None,
litellm_logging_obj: Optional[Any] = None,
) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
) -> CallToolResult:
"""Handle tool execution for managed server tools"""
# Import here to avoid circular import
from litellm.proxy.proxy_server import proxy_logging_obj
Expand All @@ -1449,7 +1472,7 @@ async def _handle_managed_mcp_tool(
proxy_logging_obj=proxy_logging_obj,
)
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
return call_tool_result.content # type: ignore[return-value]
return call_tool_result

async def _handle_local_mcp_tool(
name: str, arguments: Dict[str, Any]
Expand Down Expand Up @@ -1741,14 +1764,16 @@ def set_auth_context(
)
auth_context_var.set(auth_user)

def get_auth_context() -> Tuple[
Optional[UserAPIKeyAuth],
Optional[str],
Optional[List[str]],
Optional[Dict[str, Dict[str, str]]],
Optional[Dict[str, str]],
Optional[Dict[str, str]],
]:
def get_auth_context() -> (
Tuple[
Optional[UserAPIKeyAuth],
Optional[str],
Optional[List[str]],
Optional[Dict[str, Dict[str, str]]],
Optional[Dict[str, str]],
Optional[Dict[str, str]],
]
):
"""
Get the UserAPIKeyAuth from the auth context variable.

Expand Down
6 changes: 5 additions & 1 deletion tests/mcp_tests/test_mcp_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def mock_client_constructor(*args, **kwargs):

# Add assertions
assert response is not None
response_list = list(response) # Convert iterable to list
# Handle CallToolResult - access .content for the list of content items
if isinstance(response, CallToolResult):
response_list = response.content
else:
response_list = list(response) # Convert iterable to list for backward compatibility
assert len(response_list) == 1
assert isinstance(response_list[0], TextContent)
assert response_list[0].text == "Test response"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Tests for MCP metadata preservation.

This module tests that tool metadata is preserved when creating prefixed tools,
which is critical for ChatGPT UI widget rendering.
"""

import sys

import pytest

# Add the parent directory to the path so we can import litellm
sys.path.insert(0, "../../../../../")

from mcp.types import Tool as MCPTool

from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager
from litellm.proxy._types import MCPTransport
from litellm.types.mcp_server.mcp_server_manager import MCPServer


class TestMCPMetadataPreservation:
"""Test that metadata is preserved when creating prefixed tools"""

def test_create_prefixed_tools_preserves_metadata(self):
"""Test that _create_prefixed_tools preserves metadata and _meta fields"""
manager = MCPServerManager()

# Create a mock server
mock_server = MCPServer(
server_id="test-server-1",
name="test_server",
alias="test",
server_name="Test Server",
url="https://test-server.com/mcp",
transport=MCPTransport.http,
)

# Create a tool with metadata
tool_with_metadata = MCPTool(
name="hello_widget",
description="Display a greeting widget",
inputSchema={"type": "object", "properties": {}},
)
# Add metadata using setattr since MCPTool might not have it in the constructor
tool_with_metadata.metadata = {
"openai/outputTemplate": "ui://widget/hello.html",
"openai/widgetDescription": "A greeting widget",
}
tool_with_metadata._meta = {
"openai/toolInvocation/invoking": "Preparing greeting...",
}

# Create prefixed tools
prefixed_tools = manager._create_prefixed_tools(
[tool_with_metadata], mock_server, add_prefix=True
)

# Verify
assert len(prefixed_tools) == 1
prefixed_tool = prefixed_tools[0]

# Check that name is prefixed
assert prefixed_tool.name == "test-hello_widget"

# Check that metadata is preserved
assert hasattr(prefixed_tool, "metadata")
assert prefixed_tool.metadata == {
"openai/outputTemplate": "ui://widget/hello.html",
"openai/widgetDescription": "A greeting widget",
}

# Check that _meta is preserved
assert hasattr(prefixed_tool, "_meta")
assert prefixed_tool._meta == {
"openai/toolInvocation/invoking": "Preparing greeting...",
}

# Check that other fields are preserved
assert prefixed_tool.description == "Display a greeting widget"
assert prefixed_tool.inputSchema == {"type": "object", "properties": {}}



if __name__ == "__main__":
pytest.main([__file__])

Loading