|
1 | 1 | import json |
2 | 2 | import logging |
3 | 3 | from contextlib import AsyncExitStack |
| 4 | +from datetime import timedelta |
4 | 5 | from pathlib import Path |
5 | | -from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union |
| 6 | +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload |
6 | 7 |
|
7 | | -from typing_extensions import TypeAlias |
| 8 | +from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack |
8 | 9 |
|
9 | 10 | from ...utils._runtime import get_hf_hub_version |
10 | 11 | from .._generated._async_client import AsyncInferenceClient |
|
26 | 27 | # Type alias for tool names |
27 | 28 | ToolName: TypeAlias = str |
28 | 29 |
|
| 30 | +ServerType: TypeAlias = Literal["stdio", "sse", "http"] |
| 31 | + |
| 32 | + |
| 33 | +class StdioServerParameters_T(TypedDict): |
| 34 | + command: str |
| 35 | + args: NotRequired[List[str]] |
| 36 | + env: NotRequired[Dict[str, str]] |
| 37 | + cwd: NotRequired[Union[str, Path, None]] |
| 38 | + |
| 39 | + |
| 40 | +class SSEServerParameters_T(TypedDict): |
| 41 | + url: str |
| 42 | + headers: NotRequired[Dict[str, Any]] |
| 43 | + timeout: NotRequired[float] |
| 44 | + sse_read_timeout: NotRequired[float] |
| 45 | + |
| 46 | + |
| 47 | +class StreamableHTTPParameters_T(TypedDict): |
| 48 | + url: str |
| 49 | + headers: NotRequired[dict[str, Any]] |
| 50 | + timeout: NotRequired[timedelta] |
| 51 | + sse_read_timeout: NotRequired[timedelta] |
| 52 | + terminate_on_close: NotRequired[bool] |
| 53 | + |
29 | 54 |
|
30 | 55 | class MCPClient: |
31 | 56 | """ |
@@ -64,39 +89,84 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): |
64 | 89 | await self.client.__aexit__(exc_type, exc_val, exc_tb) |
65 | 90 | await self.cleanup() |
66 | 91 |
|
67 | | - async def add_mcp_server( |
68 | | - self, |
69 | | - *, |
70 | | - command: str, |
71 | | - args: Optional[List[str]] = None, |
72 | | - env: Optional[Dict[str, str]] = None, |
73 | | - cwd: Union[str, Path, None] = None, |
74 | | - ): |
| 92 | + @overload |
| 93 | + async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ... |
| 94 | + |
| 95 | + @overload |
| 96 | + async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ... |
| 97 | + |
| 98 | + @overload |
| 99 | + async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ... |
| 100 | + |
| 101 | + async def add_mcp_server(self, type: ServerType, **params: Any): |
75 | 102 | """Connect to an MCP server |
76 | 103 |
|
77 | 104 | Args: |
78 | | - command (str): |
79 | | - The command to run the MCP server. |
80 | | - args (List[str], optional): |
81 | | - Arguments for the command. |
82 | | - env (Dict[str, str], optional): |
83 | | - Environment variables for the command. Default is to inherit the parent environment. |
84 | | - cwd (Union[str, Path, None], optional): |
85 | | - Working directory for the command. Default to current directory. |
| 105 | + type (`str`): |
| 106 | + Type of the server to connect to. Can be one of: |
| 107 | + - "stdio": Standard input/output server (local) |
| 108 | + - "sse": Server-sent events (SSE) server |
| 109 | + - "http": StreamableHTTP server |
| 110 | + **params: Server parameters that can be either: |
| 111 | + - For stdio servers: |
| 112 | + - command (str): The command to run the MCP server |
| 113 | + - args (List[str], optional): Arguments for the command |
| 114 | + - env (Dict[str, str], optional): Environment variables for the command |
| 115 | + - cwd (Union[str, Path, None], optional): Working directory for the command |
| 116 | + - For SSE servers: |
| 117 | + - url (str): The URL of the SSE server |
| 118 | + - headers (Dict[str, Any], optional): Headers for the SSE connection |
| 119 | + - timeout (float, optional): Connection timeout |
| 120 | + - sse_read_timeout (float, optional): SSE read timeout |
| 121 | + - For StreamableHTTP servers: |
| 122 | + - url (str): The URL of the StreamableHTTP server |
| 123 | + - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection |
| 124 | + - timeout (timedelta, optional): Connection timeout |
| 125 | + - sse_read_timeout (timedelta, optional): SSE read timeout |
| 126 | + - terminate_on_close (bool, optional): Whether to terminate on close |
86 | 127 | """ |
87 | 128 | from mcp import ClientSession, StdioServerParameters |
88 | 129 | from mcp import types as mcp_types |
89 | | - from mcp.client.stdio import stdio_client |
90 | | - |
91 | | - logger.info(f"Connecting to MCP server with command: {command} {args}") |
92 | | - server_params = StdioServerParameters( |
93 | | - command=command, |
94 | | - args=args if args is not None else [], |
95 | | - env=env, |
96 | | - cwd=cwd, |
97 | | - ) |
98 | 130 |
|
99 | | - read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) |
| 131 | + # Determine server type and create appropriate parameters |
| 132 | + if type == "stdio": |
| 133 | + # Handle stdio server |
| 134 | + from mcp.client.stdio import stdio_client |
| 135 | + |
| 136 | + logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}") |
| 137 | + |
| 138 | + client_kwargs = {"command": params["command"]} |
| 139 | + for key in ["args", "env", "cwd"]: |
| 140 | + if params.get(key) is not None: |
| 141 | + client_kwargs[key] = params[key] |
| 142 | + server_params = StdioServerParameters(**client_kwargs) |
| 143 | + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) |
| 144 | + elif type == "sse": |
| 145 | + # Handle SSE server |
| 146 | + from mcp.client.sse import sse_client |
| 147 | + |
| 148 | + logger.info(f"Connecting to SSE MCP server at: {params['url']}") |
| 149 | + |
| 150 | + client_kwargs = {"url": params["url"]} |
| 151 | + for key in ["headers", "timeout", "sse_read_timeout"]: |
| 152 | + if params.get(key) is not None: |
| 153 | + client_kwargs[key] = params[key] |
| 154 | + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) |
| 155 | + elif type == "http": |
| 156 | + # Handle StreamableHTTP server |
| 157 | + from mcp.client.streamable_http import streamablehttp_client |
| 158 | + |
| 159 | + logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}") |
| 160 | + |
| 161 | + client_kwargs = {"url": params["url"]} |
| 162 | + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: |
| 163 | + if params.get(key) is not None: |
| 164 | + client_kwargs[key] = params[key] |
| 165 | + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) |
| 166 | + # ^ TODO: should be handle `get_session_id_callback`? (function to retrieve the current session ID) |
| 167 | + else: |
| 168 | + raise ValueError(f"Unsupported server type: {type}") |
| 169 | + |
100 | 170 | session = await self.exit_stack.enter_async_context( |
101 | 171 | ClientSession( |
102 | 172 | read_stream=read, |
|
0 commit comments