Skip to content

Commit 397acec

Browse files
authored
Merge pull request #17342 from BerriAI/litellm_fix_mcp_auth_header_forwarding
Fix: litellm user auth not passing issue
2 parents 18a9af3 + 082c8af commit 397acec

File tree

4 files changed

+115
-10
lines changed

4 files changed

+115
-10
lines changed

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717
from fastapi import HTTPException
1818
from httpx import HTTPStatusError
1919
from mcp import ReadResourceResult, Resource
20+
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
2021
from mcp.types import (
21-
CallToolRequestParams as MCPCallToolRequestParams,
22+
CallToolResult,
2223
GetPromptRequestParams,
2324
GetPromptResult,
2425
Prompt,
2526
ResourceTemplate,
2627
)
27-
from mcp.types import CallToolResult
2828
from mcp.types import Tool as MCPTool
29-
3029
from pydantic import AnyUrl
3130

3231
import litellm
@@ -1949,7 +1948,12 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
19491948
) = split_server_prefix_from_name(tool_name)
19501949
if original_tool_name in self.tool_name_to_mcp_server_name_mapping:
19511950
for server in self.get_registry().values():
1952-
if normalize_server_name(server.name) == normalize_server_name(
1951+
if server.server_name is None:
1952+
if normalize_server_name(server.name) == normalize_server_name(
1953+
server_name_from_prefix
1954+
):
1955+
return server
1956+
elif normalize_server_name(server.server_name) == normalize_server_name(
19531957
server_name_from_prefix
19541958
):
19551959
return server

litellm/proxy/litellm_pre_call_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ def add_user_api_key_auth_to_request_metadata(
611611
data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
612612
user_api_key_dict, "end_user_max_budget", None
613613
)
614+
# Add the full UserAPIKeyAuth object for MCP server access control
615+
data[_metadata_variable_name]["user_api_key_auth"] = user_api_key_dict
614616
return data
615617

616618
@staticmethod

litellm/responses/main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from litellm._logging import verbose_logger
2323
from litellm.constants import request_timeout
2424
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
25+
from litellm.litellm_core_utils.prompt_templates.common_utils import (
26+
update_responses_input_with_model_file_ids,
27+
)
2528
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
2629
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
2730
from litellm.responses.litellm_completion_transformation.handler import (
@@ -38,9 +41,6 @@
3841
ToolChoice,
3942
ToolParam,
4043
)
41-
from litellm.litellm_core_utils.prompt_templates.common_utils import (
42-
update_responses_input_with_model_file_ids,
43-
)
4444

4545
# Handle ResponseText import with fallback
4646
if TYPE_CHECKING:
@@ -168,7 +168,8 @@ async def aresponses_api_with_mcp(
168168
) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)
169169

170170
# Process MCP tools through the complete pipeline (fetch + filter + deduplicate + transform)
171-
user_api_key_auth = kwargs.get("user_api_key_auth")
171+
# Extract user_api_key_auth from litellm_metadata (where it's added by add_user_api_key_auth_to_request_metadata)
172+
user_api_key_auth = kwargs.get("user_api_key_auth") or kwargs.get("litellm_metadata", {}).get("user_api_key_auth")
172173

173174
# Get original MCP tools (for events) and OpenAI tools (for LLM) by reusing existing methods
174175
(

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
sys.path.insert(0, "../../../../../")
1010

1111
import httpx
12+
from mcp import ReadResourceResult, Resource
13+
from mcp.types import GetPromptResult, Prompt, ResourceTemplate, TextResourceContents
1214

1315
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
1416
MCPServerManager,
@@ -17,8 +19,6 @@
1719
from litellm.proxy._types import LiteLLM_MCPServerTable, MCPTransport
1820
from litellm.types.mcp import MCPAuth
1921
from litellm.types.mcp_server.mcp_server_manager import MCPOAuthMetadata, MCPServer
20-
from mcp import ReadResourceResult, Resource
21-
from mcp.types import GetPromptResult, Prompt, ResourceTemplate, TextResourceContents
2222

2323

2424
class TestMCPServerManager:
@@ -1606,6 +1606,104 @@ async def mock_call_tool(params):
16061606
# Verify the MCP client call was awaited exactly once
16071607
assert mock_client.call_tool.await_count == 1
16081608

1609+
@pytest.mark.asyncio
1610+
async def test_get_allowed_mcp_servers_with_user_api_key_auth(self):
1611+
"""
1612+
Test that get_allowed_mcp_servers properly receives and uses user_api_key_auth
1613+
when called. This verifies the fix where user_api_key_auth is passed through
1614+
litellm_metadata from responses API.
1615+
"""
1616+
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
1617+
MCPRequestHandler,
1618+
)
1619+
from litellm.proxy._types import LiteLLM_ObjectPermissionTable, UserAPIKeyAuth
1620+
1621+
manager = MCPServerManager()
1622+
1623+
# Create a mock user_api_key_auth with object_permission
1624+
object_permission = LiteLLM_ObjectPermissionTable(
1625+
object_permission_id="perm_123",
1626+
mcp_servers=["test_server_1", "test_server_2"],
1627+
mcp_access_groups=[],
1628+
)
1629+
1630+
user_api_key_auth = UserAPIKeyAuth(
1631+
api_key="sk-test",
1632+
user_id="user-123",
1633+
object_permission=object_permission,
1634+
object_permission_id="perm_123",
1635+
)
1636+
1637+
# Mock MCPRequestHandler.get_allowed_mcp_servers to verify it receives user_api_key_auth
1638+
with patch.object(
1639+
MCPRequestHandler,
1640+
"get_allowed_mcp_servers",
1641+
new_callable=AsyncMock,
1642+
) as mock_get_allowed:
1643+
# Configure mock to return servers from object_permission
1644+
mock_get_allowed.return_value = ["test_server_1", "test_server_2"]
1645+
1646+
# Call get_allowed_mcp_servers with user_api_key_auth
1647+
result = await manager.get_allowed_mcp_servers(user_api_key_auth)
1648+
1649+
# Verify MCPRequestHandler.get_allowed_mcp_servers was called with user_api_key_auth
1650+
mock_get_allowed.assert_called_once()
1651+
call_args = mock_get_allowed.call_args
1652+
assert call_args[0][0] is user_api_key_auth # First positional arg should be user_api_key_auth
1653+
assert call_args[0][0].user_id == "user-123"
1654+
assert call_args[0][0].object_permission_id == "perm_123"
1655+
assert call_args[0][0].object_permission is not None
1656+
assert call_args[0][0].object_permission.mcp_servers == ["test_server_1", "test_server_2"]
1657+
1658+
# Verify result contains the expected servers
1659+
assert "test_server_1" in result
1660+
assert "test_server_2" in result
1661+
1662+
def test_get_mcp_server_from_tool_name_uses_server_name_not_name(self):
1663+
"""
1664+
Test that _get_mcp_server_from_tool_name uses server.server_name instead of server.name
1665+
when extracting server name from prefixed tool name (second case).
1666+
This ensures the fix for using server_name instead of name works correctly.
1667+
"""
1668+
from litellm.proxy._experimental.mcp_server.utils import (
1669+
add_server_prefix_to_name,
1670+
)
1671+
1672+
manager = MCPServerManager()
1673+
1674+
# Create a server where server_name differs from name
1675+
# This tests the scenario where server.name != server.server_name
1676+
server = MCPServer(
1677+
server_id="test-server-id",
1678+
name="Test Server Name", # Different from server_name
1679+
server_name="test_server", # This is what should be used
1680+
alias="test_server",
1681+
transport=MCPTransport.http,
1682+
)
1683+
1684+
# Register the server
1685+
manager.registry = {server.server_id: server}
1686+
1687+
# Create a tool with prefixed name
1688+
tool_name = "test_tool"
1689+
prefixed_tool_name = add_server_prefix_to_name(tool_name, "test_server")
1690+
1691+
# Populate the mapping with the original tool name
1692+
manager.tool_name_to_mcp_server_name_mapping[tool_name] = "test_server"
1693+
manager.tool_name_to_mcp_server_name_mapping[prefixed_tool_name] = "test_server"
1694+
1695+
# Test: _get_mcp_server_from_tool_name should find the server using server.server_name
1696+
# even when server.name is different
1697+
resolved_server = manager._get_mcp_server_from_tool_name(prefixed_tool_name)
1698+
1699+
# Verify the server was found correctly
1700+
assert resolved_server is not None
1701+
assert resolved_server.server_id == server.server_id
1702+
assert resolved_server.server_name == "test_server"
1703+
# Verify it matched using server_name, not name
1704+
assert resolved_server.name == "Test Server Name" # name is different
1705+
assert resolved_server.server_name == "test_server" # server_name matches
1706+
16091707

16101708
if __name__ == "__main__":
16111709
pytest.main([__file__])

0 commit comments

Comments
 (0)