diff --git a/examples/foobar.py b/examples/foobar.py index 491d51c..81f85f2 100644 --- a/examples/foobar.py +++ b/examples/foobar.py @@ -69,7 +69,7 @@ async def app(scope, receive, send): if __name__ == "__main__": - s = Server(10, CustomInterceptor()) + s = Server(max_workers=10, interceptors=[CustomInterceptor()]) s.run( FooService(), BarService(), diff --git a/src/pydantic_rpc/core.py b/src/pydantic_rpc/core.py index 97677b0..f9c6247 100644 --- a/src/pydantic_rpc/core.py +++ b/src/pydantic_rpc/core.py @@ -11,7 +11,7 @@ import sys import time import types -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from concurrent import futures from connectrpc.code import Code as Errors # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio) @@ -2611,11 +2611,20 @@ def __init__( port: int = 50051, package_name: str = "", max_workers: int = 8, - *interceptors: Any, tls: Optional["GrpcTLSConfig"] = None, + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, + handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None, + options: Optional[Sequence[Tuple[str, Any]]] = None, + maximum_concurrent_rpcs: Optional[int] = None, + compression: Optional[grpc.Compression] = None, ) -> None: self._server: grpc.Server = grpc.server( - futures.ThreadPoolExecutor(max_workers), interceptors=interceptors + futures.ThreadPoolExecutor(max_workers), + handlers=handlers, + interceptors=interceptors, + options=options, + maximum_concurrent_rpcs=maximum_concurrent_rpcs, + compression=compression, ) self._service_names: list[str] = [] self._package_name: str = package_name diff --git a/tests/test_enhanced_api.py b/tests/test_enhanced_api.py index ddb0459..76f2267 100644 --- a/tests/test_enhanced_api.py +++ b/tests/test_enhanced_api.py @@ -195,3 +195,29 @@ async def calculate(self, request: SampleRequest) -> SampleResponse: # Test that the exception is still raised (decorator doesn't catch) with pytest.raises(ValueError): await service.calculate(SampleRequest(value="zero")) + + +def test_server_production_parameters(): + """Test Server constructor with production-ready parameters.""" + # Test with production parameters + options = [ + ('grpc.keepalive_time_ms', 10000), + ('grpc.keepalive_timeout_ms', 5000), + ('grpc.keepalive_permit_without_calls', True), + ] + + server = Server( + max_workers=10, + options=options, + maximum_concurrent_rpcs=100, + compression=grpc.Compression.Gzip, + ) + + assert server._server is not None + assert server._port == 50051 # default port + + # Test with minimal parameters + server2 = Server( + max_workers=5, + ) + assert server2._server is not None