2
2
3
3
import base64
4
4
import functools
5
+ import warnings
5
6
from abc import ABC , abstractmethod
6
7
from asyncio import Lock
7
8
from collections .abc import AsyncIterator , Awaitable , Sequence
8
9
from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
9
10
from dataclasses import dataclass , field , replace
11
+ from datetime import timedelta
10
12
from pathlib import Path
11
13
from typing import Any , Callable
12
14
37
39
) from _import_error
38
40
39
41
# after mcp imports so any import error maps to this file, not _mcp.py
40
- from . import _mcp , exceptions , messages , models
42
+ from . import _mcp , _utils , exceptions , messages , models
41
43
42
44
__all__ = 'MCPServer' , 'MCPServerStdio' , 'MCPServerHTTP' , 'MCPServerSSE' , 'MCPServerStreamableHTTP'
43
45
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
59
61
log_level : mcp_types .LoggingLevel | None = None
60
62
log_handler : LoggingFnT | None = None
61
63
timeout : float = 5
64
+ read_timeout : float = 5 * 60
62
65
process_tool_call : ProcessToolCallback | None = None
63
66
allow_sampling : bool = True
64
67
max_retries : int = 1
@@ -208,6 +211,7 @@ async def __aenter__(self) -> Self:
208
211
write_stream = self ._write_stream ,
209
212
sampling_callback = self ._sampling_callback if self .allow_sampling else None ,
210
213
logging_callback = self .log_handler ,
214
+ read_timeout_seconds = timedelta (seconds = self .read_timeout ),
211
215
)
212
216
self ._client = await self ._exit_stack .enter_async_context (client )
213
217
@@ -401,7 +405,7 @@ def __repr__(self) -> str:
401
405
return f'MCPServerStdio(command={ self .command !r} , args={ self .args !r} , tool_prefix={ self .tool_prefix !r} )'
402
406
403
407
404
- @dataclass
408
+ @dataclass ( init = False )
405
409
class _MCPServerHTTP (MCPServer ):
406
410
url : str
407
411
"""The URL of the endpoint on the MCP server."""
@@ -438,10 +442,10 @@ class _MCPServerHTTP(MCPServer):
438
442
```
439
443
"""
440
444
441
- sse_read_timeout : float = 5 * 60
442
- """Maximum time in seconds to wait for new SSE messages before timing out.
445
+ read_timeout : float = 5 * 60
446
+ """Maximum time in seconds to wait for new messages before timing out.
443
447
444
- This timeout applies to the long-lived SSE connection after it's established.
448
+ This timeout applies to the long-lived connection after it's established.
445
449
If no new messages are received within this time, the connection will be considered stale
446
450
and may be closed. Defaults to 5 minutes (300 seconds).
447
451
"""
@@ -485,6 +489,51 @@ class _MCPServerHTTP(MCPServer):
485
489
sampling_model : models .Model | None = None
486
490
"""The model to use for sampling."""
487
491
492
+ def __init__ (
493
+ self ,
494
+ * ,
495
+ url : str ,
496
+ headers : dict [str , str ] | None = None ,
497
+ http_client : httpx .AsyncClient | None = None ,
498
+ read_timeout : float | None = None ,
499
+ tool_prefix : str | None = None ,
500
+ log_level : mcp_types .LoggingLevel | None = None ,
501
+ log_handler : LoggingFnT | None = None ,
502
+ timeout : float = 5 ,
503
+ process_tool_call : ProcessToolCallback | None = None ,
504
+ allow_sampling : bool = True ,
505
+ max_retries : int = 1 ,
506
+ sampling_model : models .Model | None = None ,
507
+ ** kwargs : Any ,
508
+ ):
509
+ # Handle deprecated sse_read_timeout parameter
510
+ if 'sse_read_timeout' in kwargs :
511
+ if read_timeout is not None :
512
+ raise TypeError ("'read_timeout' and 'sse_read_timeout' cannot be set at the same time." )
513
+
514
+ warnings .warn (
515
+ "'sse_read_timeout' is deprecated, use 'read_timeout' instead." , DeprecationWarning , stacklevel = 2
516
+ )
517
+ read_timeout = kwargs .pop ('sse_read_timeout' )
518
+
519
+ _utils .validate_empty_kwargs (kwargs )
520
+
521
+ if read_timeout is None :
522
+ read_timeout = 5 * 60
523
+
524
+ self .url = url
525
+ self .headers = headers
526
+ self .http_client = http_client
527
+ self .tool_prefix = tool_prefix
528
+ self .log_level = log_level
529
+ self .log_handler = log_handler
530
+ self .timeout = timeout
531
+ self .process_tool_call = process_tool_call
532
+ self .allow_sampling = allow_sampling
533
+ self .max_retries = max_retries
534
+ self .sampling_model = sampling_model
535
+ self .read_timeout = read_timeout
536
+
488
537
@property
489
538
@abstractmethod
490
539
def _transport_client (
@@ -522,7 +571,7 @@ async def client_streams(
522
571
self ._transport_client ,
523
572
url = self .url ,
524
573
timeout = self .timeout ,
525
- sse_read_timeout = self .sse_read_timeout ,
574
+ sse_read_timeout = self .read_timeout ,
526
575
)
527
576
528
577
if self .http_client is not None :
@@ -549,7 +598,7 @@ def __repr__(self) -> str: # pragma: no cover
549
598
return f'{ self .__class__ .__name__ } (url={ self .url !r} , tool_prefix={ self .tool_prefix !r} )'
550
599
551
600
552
- @dataclass
601
+ @dataclass ( init = False )
553
602
class MCPServerSSE (_MCPServerHTTP ):
554
603
"""An MCP server that connects over streamable HTTP connections.
555
604
0 commit comments