Skip to content

Commit a439239

Browse files
AntSan813KRRT7
authored andcommitted
Rename MCPServer sse_read_timeout to read_timeout and pass to ClientSession (pydantic#2240)
1 parent 9c5c529 commit a439239

File tree

2 files changed

+74
-15
lines changed

2 files changed

+74
-15
lines changed

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import base64
44
import functools
5+
import warnings
56
from abc import ABC, abstractmethod
67
from asyncio import Lock
78
from collections.abc import AsyncIterator, Awaitable, Sequence
89
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
910
from dataclasses import dataclass, field, replace
11+
from datetime import timedelta
1012
from pathlib import Path
1113
from typing import Any, Callable
1214

@@ -37,7 +39,7 @@
3739
) from _import_error
3840

3941
# after mcp imports so any import error maps to this file, not _mcp.py
40-
from . import _mcp, exceptions, messages, models
42+
from . import _mcp, _utils, exceptions, messages, models
4143

4244
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
4345

@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
5961
log_level: mcp_types.LoggingLevel | None = None
6062
log_handler: LoggingFnT | None = None
6163
timeout: float = 5
64+
read_timeout: float = 5 * 60
6265
process_tool_call: ProcessToolCallback | None = None
6366
allow_sampling: bool = True
6467
max_retries: int = 1
@@ -208,6 +211,7 @@ async def __aenter__(self) -> Self:
208211
write_stream=self._write_stream,
209212
sampling_callback=self._sampling_callback if self.allow_sampling else None,
210213
logging_callback=self.log_handler,
214+
read_timeout_seconds=timedelta(seconds=self.read_timeout),
211215
)
212216
self._client = await self._exit_stack.enter_async_context(client)
213217

@@ -401,7 +405,7 @@ def __repr__(self) -> str:
401405
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
402406

403407

404-
@dataclass
408+
@dataclass(init=False)
405409
class _MCPServerHTTP(MCPServer):
406410
url: str
407411
"""The URL of the endpoint on the MCP server."""
@@ -438,10 +442,10 @@ class _MCPServerHTTP(MCPServer):
438442
```
439443
"""
440444

441-
sse_read_timeout: float = 5 * 60
442-
"""Maximum time in seconds to wait for new SSE messages before timing out.
445+
read_timeout: float = 5 * 60
446+
"""Maximum time in seconds to wait for new messages before timing out.
443447
444-
This timeout applies to the long-lived SSE connection after it's established.
448+
This timeout applies to the long-lived connection after it's established.
445449
If no new messages are received within this time, the connection will be considered stale
446450
and may be closed. Defaults to 5 minutes (300 seconds).
447451
"""
@@ -485,6 +489,51 @@ class _MCPServerHTTP(MCPServer):
485489
sampling_model: models.Model | None = None
486490
"""The model to use for sampling."""
487491

492+
def __init__(
493+
self,
494+
*,
495+
url: str,
496+
headers: dict[str, str] | None = None,
497+
http_client: httpx.AsyncClient | None = None,
498+
read_timeout: float | None = None,
499+
tool_prefix: str | None = None,
500+
log_level: mcp_types.LoggingLevel | None = None,
501+
log_handler: LoggingFnT | None = None,
502+
timeout: float = 5,
503+
process_tool_call: ProcessToolCallback | None = None,
504+
allow_sampling: bool = True,
505+
max_retries: int = 1,
506+
sampling_model: models.Model | None = None,
507+
**kwargs: Any,
508+
):
509+
# Handle deprecated sse_read_timeout parameter
510+
if 'sse_read_timeout' in kwargs:
511+
if read_timeout is not None:
512+
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
513+
514+
warnings.warn(
515+
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
516+
)
517+
read_timeout = kwargs.pop('sse_read_timeout')
518+
519+
_utils.validate_empty_kwargs(kwargs)
520+
521+
if read_timeout is None:
522+
read_timeout = 5 * 60
523+
524+
self.url = url
525+
self.headers = headers
526+
self.http_client = http_client
527+
self.tool_prefix = tool_prefix
528+
self.log_level = log_level
529+
self.log_handler = log_handler
530+
self.timeout = timeout
531+
self.process_tool_call = process_tool_call
532+
self.allow_sampling = allow_sampling
533+
self.max_retries = max_retries
534+
self.sampling_model = sampling_model
535+
self.read_timeout = read_timeout
536+
488537
@property
489538
@abstractmethod
490539
def _transport_client(
@@ -522,7 +571,7 @@ async def client_streams(
522571
self._transport_client,
523572
url=self.url,
524573
timeout=self.timeout,
525-
sse_read_timeout=self.sse_read_timeout,
574+
sse_read_timeout=self.read_timeout,
526575
)
527576

528577
if self.http_client is not None:
@@ -549,7 +598,7 @@ def __repr__(self) -> str: # pragma: no cover
549598
return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
550599

551600

552-
@dataclass
601+
@dataclass(init=False)
553602
class MCPServerSSE(_MCPServerHTTP):
554603
"""An MCP server that connects over streamable HTTP connections.
555604

tests/test_mcp.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,30 @@ def test_sse_server():
140140

141141

142142
def test_sse_server_with_header_and_timeout():
143-
sse_server = MCPServerSSE(
144-
url='http://localhost:8000/sse',
145-
headers={'my-custom-header': 'my-header-value'},
146-
timeout=10,
147-
sse_read_timeout=100,
148-
log_level='info',
149-
)
143+
with pytest.warns(DeprecationWarning, match="'sse_read_timeout' is deprecated, use 'read_timeout' instead."):
144+
sse_server = MCPServerSSE(
145+
url='http://localhost:8000/sse',
146+
headers={'my-custom-header': 'my-header-value'},
147+
timeout=10,
148+
sse_read_timeout=100,
149+
log_level='info',
150+
)
150151
assert sse_server.url == 'http://localhost:8000/sse'
151152
assert sse_server.headers is not None and sse_server.headers['my-custom-header'] == 'my-header-value'
152153
assert sse_server.timeout == 10
153-
assert sse_server.sse_read_timeout == 100
154+
assert sse_server.read_timeout == 100
154155
assert sse_server.log_level == 'info'
155156

156157

158+
def test_sse_server_conflicting_timeout_params():
159+
with pytest.raises(TypeError, match="'read_timeout' and 'sse_read_timeout' cannot be set at the same time."):
160+
MCPServerSSE(
161+
url='http://localhost:8000/sse',
162+
read_timeout=50,
163+
sse_read_timeout=100,
164+
)
165+
166+
157167
@pytest.mark.vcr()
158168
async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent):
159169
async with agent:

0 commit comments

Comments
 (0)