diff --git a/faststream/app.py b/faststream/app.py index 67d9be7600..1dbf314672 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -9,10 +9,8 @@ TypeVar, ) -import anyio from typing_extensions import ParamSpec -from faststream._compat import ExceptionGroup from faststream._internal.application import Application from faststream.asgi.app import AsgiFastStream from faststream.cli.supervisors.utils import set_exit @@ -42,15 +40,9 @@ async def run( async with catch_startup_validation_error(), self.lifespan_context( **(run_extra_options or {}) ): - try: - async with anyio.create_task_group() as tg: - tg.start_soon(self._startup, log_level, run_extra_options) - await self._main_loop(sleep_time) - await self._shutdown(log_level) - tg.cancel_scope.cancel() - except ExceptionGroup as e: - for ex in e.exceptions: - raise ex from None + await self._startup(log_level, run_extra_options) + await self._main_loop(sleep_time) + await self._shutdown(log_level) def as_asgi( self, diff --git a/faststream/asgi/app.py b/faststream/asgi/app.py index c42f3c0291..c2818cdfe5 100644 --- a/faststream/asgi/app.py +++ b/faststream/asgi/app.py @@ -7,14 +7,13 @@ Any, AsyncIterator, Dict, + Literal, Optional, Sequence, Tuple, Union, ) -import anyio - from faststream._compat import HAS_UVICORN, uvicorn from faststream._internal.application import Application from faststream.exceptions import INSTALL_UVICORN @@ -182,17 +181,13 @@ async def run( await server.serve() @asynccontextmanager - async def start_lifespan_context(self) -> AsyncIterator[None]: - async with anyio.create_task_group() as tg, self.lifespan_context( - **self._run_extra_options - ): - tg.start_soon(self._startup, self._log_level, self._run_extra_options) - + async def start_lifespan_context(self) -> AsyncIterator[Literal[True]]: + async with self.lifespan_context(**self._run_extra_options): try: - yield + await self._startup(self._log_level, self._run_extra_options) + yield True finally: await self._shutdown() - tg.cancel_scope.cancel() async def lifespan(self, scope: "Scope", receive: "Receive", send: "Send") -> None: """Handle ASGI lifespan messages to start and shutdown the app.""" @@ -200,9 +195,8 @@ async def lifespan(self, scope: "Scope", receive: "Receive", send: "Send") -> No await receive() # handle `lifespan.startup` event try: - async with self.start_lifespan_context(): + async with self.start_lifespan_context() as started: await send({"type": "lifespan.startup.complete"}) - started = True await receive() # handle `lifespan.shutdown` event except BaseException: diff --git a/tests/cli/rabbit/test_app.py b/tests/cli/rabbit/test_app.py index 028fbc0224..a22a1cac72 100644 --- a/tests/cli/rabbit/test_app.py +++ b/tests/cli/rabbit/test_app.py @@ -9,6 +9,7 @@ from faststream import FastStream, TestApp from faststream._compat import IS_WINDOWS +from faststream.asgi.app import AsgiFastStream from faststream.log import logger from faststream.rabbit.testing import TestRabbitBroker @@ -406,5 +407,72 @@ async def test_run_asgi(async_mock: AsyncMock, app: FastStream): async_mock.broker_run.assert_called_once() +@pytest.mark.asyncio +@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") +@pytest.mark.parametrize( + ("failure_type"), + [ + pytest.param( + "startup", + id="startup hook failure", + ), + pytest.param( + "shutdown", + id="shutdown hook failure", + ), + pytest.param( + "lifespan_start", + id="lifespan start failure", + ), + pytest.param( + "lifespan_shutdown", + id="lifespan shutdown failure", + ), + ], +) +async def test_lifespan_exceptions(failure_type: str, async_mock: AsyncMock, broker): + @asynccontextmanager + async def lifespan(): + if f"{failure_type}" == "lifespan_start": + raise ValueError(f"Failure during {failure_type}") + yield + if f"{failure_type}" == "lifespan_shutdown": + raise ValueError(f"Failure during {failure_type}") + + app = AsgiFastStream(broker, lifespan=lifespan) + + @app.on_startup + async def start(): + if f"{failure_type}" == "startup": + raise ValueError(f"Failure during {failure_type}") + + @app.on_shutdown + async def shutdown(): + if f"{failure_type}" == "shutdown": + raise ValueError(f"Failure during {failure_type}") + + # use uvicorn directly instead of app.run since access to the server instance is needed + with patch.object(app.broker, "start", async_mock.broker_run), patch.object( + app.broker, "stop", async_mock.broker_stopped + ): + import uvicorn + + server = uvicorn.Server(uvicorn.Config(app=app)) + try: + # if startup succeeds, serve blocks forever in main loop. Hence, we cancel the task + # but need to handle the shutdown manually as cancelling does not trigger the shutdown. + with anyio.fail_after(0.1): + await server.serve() + except TimeoutError: + await server.shutdown() + + assert server.lifespan.should_exit is True + assert server.lifespan.error_occured is True + if failure_type in ["startup", "lifespan_start"]: + assert server.lifespan.startup_failed is True + if failure_type in ["shutdown", "lifespan_shutdown"]: + assert server.lifespan.shutdown_failed is True + + async def _kill(sig): os.kill(os.getpid(), sig) diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py index 323cd5471e..7bc9e19930 100644 --- a/tests/cli/test_run.py +++ b/tests/cli/test_run.py @@ -15,12 +15,11 @@ def test_run( generate_template: GenerateTemplateFactory, faststream_cli: FastStreamCLIFactory ) -> None: app_code = """ - from faststream import FastStream - from faststream.nats import NatsBroker + from unittest.mock import AsyncMock - broker = NatsBroker() + from faststream import FastStream - app = FastStream(broker) + app = FastStream(AsyncMock()) """ with generate_template(app_code) as app_path, faststream_cli( [ @@ -42,13 +41,15 @@ def test_run_asgi( ) -> None: app_code = """ import json + from contextlib import asynccontextmanager from faststream import FastStream - from faststream.nats import NatsBroker from faststream.asgi import AsgiResponse, get + from faststream.nats import NatsBroker, TestNatsBroker broker = NatsBroker() + @get async def liveness_ping(scope): return AsgiResponse(b"hello world", status_code=200) @@ -56,19 +57,32 @@ async def liveness_ping(scope): CONTEXT = {} + @get async def context(scope): return AsgiResponse(json.dumps(CONTEXT).encode(), status_code=200) - app = FastStream(broker).as_asgi( - asgi_routes=[ - ("/liveness", liveness_ping), - ("/context", context) - ], + # must use broker implementation to generate the docs + # but cannot connect to it, hence we patch it + test_broker = TestNatsBroker(broker) + + + @asynccontextmanager + async def lifespan(): + async with test_broker: + yield + + + app = FastStream( + broker, + lifespan=lifespan, + ).as_asgi( + asgi_routes=[("/liveness", liveness_ping), ("/context", context)], asyncapi_path="/docs", ) + @app.on_startup async def start(test: int, port: int): CONTEXT["test"] = test @@ -113,16 +127,15 @@ def test_run_as_asgi_with_single_worker( generate_template: GenerateTemplateFactory, faststream_cli: FastStreamCLIFactory ) -> None: app_code = """ - from faststream.asgi import AsgiFastStream, AsgiResponse, get - from faststream.nats import NatsBroker + from unittest.mock import AsyncMock - broker = NatsBroker() + from faststream.asgi import AsgiFastStream, AsgiResponse, get @get async def liveness_ping(scope): return AsgiResponse(b"hello world", status_code=200) - app = AsgiFastStream(broker, asgi_routes=[ + app = AsgiFastStream(AsyncMock(), asgi_routes=[ ("/liveness", liveness_ping), ]) """ @@ -148,12 +161,11 @@ def test_run_as_asgi_with_many_workers( workers: int, ) -> None: app_code = """ - from faststream.asgi import AsgiFastStream - from faststream.nats import NatsBroker + from unittest.mock import AsyncMock - broker = NatsBroker() + from faststream.asgi import AsgiFastStream - app = AsgiFastStream(broker) + app = AsgiFastStream(AsyncMock()) """ with generate_template(app_code) as app_path, faststream_cli( @@ -194,14 +206,12 @@ def test_run_as_asgi_mp_with_log_level( ) -> None: app_code = """ import logging + from unittest.mock import AsyncMock from faststream.asgi import AsgiFastStream from faststream.log.logging import logger - from faststream.nats import NatsBroker - - broker = NatsBroker() - app = AsgiFastStream(broker) + app = AsgiFastStream(AsyncMock()) @app.on_startup def print_log_level(): @@ -233,17 +243,16 @@ def test_run_as_factory( generate_template: GenerateTemplateFactory, faststream_cli: FastStreamCLIFactory ) -> None: app_code = """ - from faststream.asgi import AsgiFastStream, AsgiResponse, get - from faststream.nats import NatsBroker + from unittest.mock import AsyncMock - broker = NatsBroker() + from faststream.asgi import AsgiFastStream, AsgiResponse, get @get async def liveness_ping(scope): return AsgiResponse(b"hello world", status_code=200) def app_factory(): - return AsgiFastStream(broker, asgi_routes=[ + return AsgiFastStream(AsyncMock(), asgi_routes=[ ("/liveness", liveness_ping), ]) """ @@ -319,14 +328,12 @@ def test_run_as_asgi_with_log_config( ) -> None: app_code = """ import logging + from unittest.mock import AsyncMock from faststream.asgi import AsgiFastStream from faststream.log.logging import logger - from faststream.nats import NatsBroker - - broker = NatsBroker() - app = AsgiFastStream(broker) + app = AsgiFastStream(AsyncMock()) @app.on_startup def print_log_level():