|
9 | 9 | sys.path.insert(0, "../../../../../") |
10 | 10 |
|
11 | 11 | import httpx |
| 12 | +from mcp import ReadResourceResult, Resource |
| 13 | +from mcp.types import GetPromptResult, Prompt, ResourceTemplate, TextResourceContents |
12 | 14 |
|
13 | 15 | from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( |
14 | 16 | MCPServerManager, |
|
17 | 19 | from litellm.proxy._types import LiteLLM_MCPServerTable, MCPTransport |
18 | 20 | from litellm.types.mcp import MCPAuth |
19 | 21 | from litellm.types.mcp_server.mcp_server_manager import MCPOAuthMetadata, MCPServer |
20 | | -from mcp import ReadResourceResult, Resource |
21 | | -from mcp.types import GetPromptResult, Prompt, ResourceTemplate, TextResourceContents |
22 | 22 |
|
23 | 23 |
|
24 | 24 | class TestMCPServerManager: |
@@ -1606,6 +1606,104 @@ async def mock_call_tool(params): |
1606 | 1606 | # Verify the MCP client call was awaited exactly once |
1607 | 1607 | assert mock_client.call_tool.await_count == 1 |
1608 | 1608 |
|
| 1609 | + @pytest.mark.asyncio |
| 1610 | + async def test_get_allowed_mcp_servers_with_user_api_key_auth(self): |
| 1611 | + """ |
| 1612 | + Test that get_allowed_mcp_servers properly receives and uses user_api_key_auth |
| 1613 | + when called. This verifies the fix where user_api_key_auth is passed through |
| 1614 | + litellm_metadata from responses API. |
| 1615 | + """ |
| 1616 | + from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( |
| 1617 | + MCPRequestHandler, |
| 1618 | + ) |
| 1619 | + from litellm.proxy._types import LiteLLM_ObjectPermissionTable, UserAPIKeyAuth |
| 1620 | + |
| 1621 | + manager = MCPServerManager() |
| 1622 | + |
| 1623 | + # Create a mock user_api_key_auth with object_permission |
| 1624 | + object_permission = LiteLLM_ObjectPermissionTable( |
| 1625 | + object_permission_id="perm_123", |
| 1626 | + mcp_servers=["test_server_1", "test_server_2"], |
| 1627 | + mcp_access_groups=[], |
| 1628 | + ) |
| 1629 | + |
| 1630 | + user_api_key_auth = UserAPIKeyAuth( |
| 1631 | + api_key="sk-test", |
| 1632 | + user_id="user-123", |
| 1633 | + object_permission=object_permission, |
| 1634 | + object_permission_id="perm_123", |
| 1635 | + ) |
| 1636 | + |
| 1637 | + # Mock MCPRequestHandler.get_allowed_mcp_servers to verify it receives user_api_key_auth |
| 1638 | + with patch.object( |
| 1639 | + MCPRequestHandler, |
| 1640 | + "get_allowed_mcp_servers", |
| 1641 | + new_callable=AsyncMock, |
| 1642 | + ) as mock_get_allowed: |
| 1643 | + # Configure mock to return servers from object_permission |
| 1644 | + mock_get_allowed.return_value = ["test_server_1", "test_server_2"] |
| 1645 | + |
| 1646 | + # Call get_allowed_mcp_servers with user_api_key_auth |
| 1647 | + result = await manager.get_allowed_mcp_servers(user_api_key_auth) |
| 1648 | + |
| 1649 | + # Verify MCPRequestHandler.get_allowed_mcp_servers was called with user_api_key_auth |
| 1650 | + mock_get_allowed.assert_called_once() |
| 1651 | + call_args = mock_get_allowed.call_args |
| 1652 | + assert call_args[0][0] is user_api_key_auth # First positional arg should be user_api_key_auth |
| 1653 | + assert call_args[0][0].user_id == "user-123" |
| 1654 | + assert call_args[0][0].object_permission_id == "perm_123" |
| 1655 | + assert call_args[0][0].object_permission is not None |
| 1656 | + assert call_args[0][0].object_permission.mcp_servers == ["test_server_1", "test_server_2"] |
| 1657 | + |
| 1658 | + # Verify result contains the expected servers |
| 1659 | + assert "test_server_1" in result |
| 1660 | + assert "test_server_2" in result |
| 1661 | + |
| 1662 | + def test_get_mcp_server_from_tool_name_uses_server_name_not_name(self): |
| 1663 | + """ |
| 1664 | + Test that _get_mcp_server_from_tool_name uses server.server_name instead of server.name |
| 1665 | + when extracting server name from prefixed tool name (second case). |
| 1666 | + This ensures the fix for using server_name instead of name works correctly. |
| 1667 | + """ |
| 1668 | + from litellm.proxy._experimental.mcp_server.utils import ( |
| 1669 | + add_server_prefix_to_name, |
| 1670 | + ) |
| 1671 | + |
| 1672 | + manager = MCPServerManager() |
| 1673 | + |
| 1674 | + # Create a server where server_name differs from name |
| 1675 | + # This tests the scenario where server.name != server.server_name |
| 1676 | + server = MCPServer( |
| 1677 | + server_id="test-server-id", |
| 1678 | + name="Test Server Name", # Different from server_name |
| 1679 | + server_name="test_server", # This is what should be used |
| 1680 | + alias="test_server", |
| 1681 | + transport=MCPTransport.http, |
| 1682 | + ) |
| 1683 | + |
| 1684 | + # Register the server |
| 1685 | + manager.registry = {server.server_id: server} |
| 1686 | + |
| 1687 | + # Create a tool with prefixed name |
| 1688 | + tool_name = "test_tool" |
| 1689 | + prefixed_tool_name = add_server_prefix_to_name(tool_name, "test_server") |
| 1690 | + |
| 1691 | + # Populate the mapping with the original tool name |
| 1692 | + manager.tool_name_to_mcp_server_name_mapping[tool_name] = "test_server" |
| 1693 | + manager.tool_name_to_mcp_server_name_mapping[prefixed_tool_name] = "test_server" |
| 1694 | + |
| 1695 | + # Test: _get_mcp_server_from_tool_name should find the server using server.server_name |
| 1696 | + # even when server.name is different |
| 1697 | + resolved_server = manager._get_mcp_server_from_tool_name(prefixed_tool_name) |
| 1698 | + |
| 1699 | + # Verify the server was found correctly |
| 1700 | + assert resolved_server is not None |
| 1701 | + assert resolved_server.server_id == server.server_id |
| 1702 | + assert resolved_server.server_name == "test_server" |
| 1703 | + # Verify it matched using server_name, not name |
| 1704 | + assert resolved_server.name == "Test Server Name" # name is different |
| 1705 | + assert resolved_server.server_name == "test_server" # server_name matches |
| 1706 | + |
1609 | 1707 |
|
1610 | 1708 | if __name__ == "__main__": |
1611 | 1709 | pytest.main([__file__]) |
0 commit comments