Skip to content

Commit 71c1de6

Browse files
authored
starlette over fastapi (#5069)
* reintroduce python multipart for formdata * starlette over fastapi * fix the tests * simplify json response * use json response for all of these guys * add transformer * vendor types for future compatibility * pre-commit * we can actually just deprecate this guy * update uv lock * use api_transformer stuff * fix the tests * it's ruff out there
1 parent f3df526 commit 71c1de6

File tree

8 files changed

+196
-59
lines changed

8 files changed

+196
-59
lines changed

reflex/app.py

Lines changed: 125 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,26 @@
1313
import json
1414
import sys
1515
import traceback
16-
from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
16+
from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
1717
from datetime import datetime
1818
from pathlib import Path
1919
from timeit import default_timer as timer
2020
from types import SimpleNamespace
2121
from typing import TYPE_CHECKING, Any, BinaryIO, get_args, get_type_hints
2222

23-
from fastapi import FastAPI, HTTPException, Request
24-
from fastapi import UploadFile as FastAPIUploadFile
25-
from fastapi.middleware import cors
26-
from fastapi.responses import JSONResponse, StreamingResponse
27-
from fastapi.staticfiles import StaticFiles
23+
from fastapi import FastAPI
2824
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
29-
from socketio import ASGIApp, AsyncNamespace, AsyncServer
25+
from socketio import ASGIApp as EngineIOApp
26+
from socketio import AsyncNamespace, AsyncServer
27+
from starlette.applications import Starlette
3028
from starlette.datastructures import Headers
3129
from starlette.datastructures import UploadFile as StarletteUploadFile
30+
from starlette.exceptions import HTTPException
31+
from starlette.middleware import cors
32+
from starlette.requests import Request
33+
from starlette.responses import JSONResponse, Response, StreamingResponse
34+
from starlette.staticfiles import StaticFiles
35+
from typing_extensions import deprecated
3236

3337
from reflex import constants
3438
from reflex.admin import AdminDash
@@ -102,6 +106,7 @@
102106
)
103107
from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
104108
from reflex.utils.imports import ImportVar
109+
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
105110

106111
if TYPE_CHECKING:
107112
from reflex.vars import Var
@@ -389,7 +394,7 @@ class App(MiddlewareMixin, LifespanMixin):
389394
_stateful_pages: dict[str, None] = dataclasses.field(default_factory=dict)
390395

391396
# The backend API object.
392-
_api: FastAPI | None = None
397+
_api: Starlette | None = None
393398

394399
# The state class to use for the app.
395400
_state: type[BaseState] | None = None
@@ -424,14 +429,34 @@ class App(MiddlewareMixin, LifespanMixin):
424429
# Put the toast provider in the app wrap.
425430
toaster: Component | None = dataclasses.field(default_factory=toast.provider)
426431

432+
# Transform the ASGI app before running it.
433+
api_transformer: (
434+
Sequence[Callable[[ASGIApp], ASGIApp] | Starlette]
435+
| Callable[[ASGIApp], ASGIApp]
436+
| Starlette
437+
| None
438+
) = None
439+
440+
# FastAPI app for compatibility with FastAPI.
441+
_cached_fastapi_app: FastAPI | None = None
442+
427443
@property
428-
def api(self) -> FastAPI | None:
444+
@deprecated("Use `api_transformer=your_fastapi_app` instead.")
445+
def api(self) -> FastAPI:
429446
"""Get the backend api.
430447
431448
Returns:
432449
The backend api.
433450
"""
434-
return self._api
451+
if self._cached_fastapi_app is None:
452+
self._cached_fastapi_app = FastAPI()
453+
console.deprecate(
454+
feature_name="App.api",
455+
reason="Set `api_transformer=your_fastapi_app` instead.",
456+
deprecation_version="0.7.9",
457+
removal_version="0.8.0",
458+
)
459+
return self._cached_fastapi_app
435460

436461
@property
437462
def event_namespace(self) -> EventNamespace | None:
@@ -463,7 +488,7 @@ def __post_init__(self):
463488
set_breakpoints(self.style.pop("breakpoints"))
464489

465490
# Set up the API.
466-
self._api = FastAPI(lifespan=self._run_lifespan_tasks)
491+
self._api = Starlette(lifespan=self._run_lifespan_tasks)
467492
self._add_cors()
468493
self._add_default_endpoints()
469494

@@ -529,7 +554,7 @@ def _setup_state(self) -> None:
529554
)
530555

531556
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
532-
socket_app = ASGIApp(self.sio, socketio_path="")
557+
socket_app = EngineIOApp(self.sio, socketio_path="")
533558
namespace = config.get_event_namespace()
534559

535560
# Create the event namespace and attach the main app. Not related to any paths.
@@ -538,18 +563,16 @@ def _setup_state(self) -> None:
538563
# Register the event namespace with the socket.
539564
self.sio.register_namespace(self.event_namespace)
540565
# Mount the socket app with the API.
541-
if self.api:
566+
if self._api:
542567

543568
class HeaderMiddleware:
544569
def __init__(self, app: ASGIApp):
545570
self.app = app
546571

547-
async def __call__(
548-
self, scope: MutableMapping[str, Any], receive: Any, send: Callable
549-
):
572+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
550573
original_send = send
551574

552-
async def modified_send(message: dict):
575+
async def modified_send(message: Message):
553576
if message["type"] == "websocket.accept":
554577
if scope.get("subprotocols"):
555578
# The following *does* say "subprotocol" instead of "subprotocols", intentionally.
@@ -568,7 +591,7 @@ async def modified_send(message: dict):
568591
return await self.app(scope, receive, modified_send)
569592

570593
socket_app_with_headers = HeaderMiddleware(socket_app)
571-
self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
594+
self._api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
572595

573596
# Check the exception handlers
574597
self._validate_exception_handlers()
@@ -581,7 +604,7 @@ def __repr__(self) -> str:
581604
"""
582605
return f"<App state={self._state.__name__ if self._state else None}>"
583606

584-
def __call__(self) -> FastAPI:
607+
def __call__(self) -> ASGIApp:
585608
"""Run the backend api instance.
586609
587610
Raises:
@@ -590,8 +613,18 @@ def __call__(self) -> FastAPI:
590613
Returns:
591614
The backend api.
592615
"""
593-
if not self.api:
594-
raise ValueError("The app has not been initialized.")
616+
if self._cached_fastapi_app is not None:
617+
asgi_app = self._cached_fastapi_app
618+
619+
if not asgi_app or not self._api:
620+
raise ValueError("The app has not been initialized.")
621+
622+
asgi_app.mount("", self._api)
623+
else:
624+
asgi_app = self._api
625+
626+
if not asgi_app:
627+
raise ValueError("The app has not been initialized.")
595628

596629
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
597630
# before compiling the app in a thread to avoid event loop error (REF-2172).
@@ -608,30 +641,58 @@ def __call__(self) -> FastAPI:
608641
if is_prod_mode():
609642
compile_future.result()
610643

611-
return self.api
644+
if self.api_transformer is not None:
645+
api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
646+
[self.api_transformer]
647+
if not isinstance(self.api_transformer, Sequence)
648+
else self.api_transformer
649+
)
650+
651+
for api_transformer in api_transformers:
652+
if isinstance(api_transformer, Starlette):
653+
# Mount the api to the fastapi app.
654+
api_transformer.mount("", asgi_app)
655+
asgi_app = api_transformer
656+
else:
657+
# Transform the asgi app.
658+
asgi_app = api_transformer(asgi_app)
659+
660+
return asgi_app
612661

613662
def _add_default_endpoints(self):
614663
"""Add default api endpoints (ping)."""
615664
# To test the server.
616-
if not self.api:
665+
if not self._api:
617666
return
618667

619-
self.api.get(str(constants.Endpoint.PING))(ping)
620-
self.api.get(str(constants.Endpoint.HEALTH))(health)
668+
self._api.add_route(
669+
str(constants.Endpoint.PING),
670+
ping,
671+
methods=["GET"],
672+
)
673+
self._api.add_route(
674+
str(constants.Endpoint.HEALTH),
675+
health,
676+
methods=["GET"],
677+
)
621678

622679
def _add_optional_endpoints(self):
623680
"""Add optional api endpoints (_upload)."""
624-
if not self.api:
681+
if not self._api:
625682
return
626683
upload_is_used_marker = (
627684
prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
628685
)
629686
if Upload.is_used or upload_is_used_marker.exists():
630687
# To upload files.
631-
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
688+
self._api.add_route(
689+
str(constants.Endpoint.UPLOAD),
690+
upload(self),
691+
methods=["POST"],
692+
)
632693

633694
# To access uploaded files.
634-
self.api.mount(
695+
self._api.mount(
635696
str(constants.Endpoint.UPLOAD),
636697
StaticFiles(directory=get_upload_dir()),
637698
name="uploaded_files",
@@ -640,17 +701,19 @@ def _add_optional_endpoints(self):
640701
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
641702
upload_is_used_marker.touch()
642703
if codespaces.is_running_in_codespaces():
643-
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
644-
codespaces.auth_codespace
704+
self._api.add_route(
705+
str(constants.Endpoint.AUTH_CODESPACE),
706+
codespaces.auth_codespace,
707+
methods=["GET"],
645708
)
646709
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
647710
self.add_all_routes_endpoint()
648711

649712
def _add_cors(self):
650713
"""Add CORS middleware to the app."""
651-
if not self.api:
714+
if not self._api:
652715
return
653-
self.api.add_middleware(
716+
self._api.add_middleware(
654717
cors.CORSMiddleware,
655718
allow_credentials=True,
656719
allow_methods=["*"],
@@ -915,7 +978,7 @@ def _setup_admin_dash(self):
915978
return
916979

917980
# Get the admin dash.
918-
if not self.api:
981+
if not self._api:
919982
return
920983

921984
admin_dash = self.admin_dash
@@ -936,7 +999,7 @@ def _setup_admin_dash(self):
936999
view = admin_dash.view_overrides.get(model, ModelView)
9371000
admin.add_view(view(model))
9381001

939-
admin.mount_to(self.api)
1002+
admin.mount_to(self._api)
9401003

9411004
def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]):
9421005
"""Gets the frontend packages to be installed and filters out the unnecessary ones.
@@ -1427,12 +1490,15 @@ def _write_stateful_pages_marker(self):
14271490

14281491
def add_all_routes_endpoint(self):
14291492
"""Add an endpoint to the app that returns all the routes."""
1430-
if not self.api:
1493+
if not self._api:
14311494
return
14321495

1433-
@self.api.get(str(constants.Endpoint.ALL_ROUTES))
1434-
async def all_routes():
1435-
return list(self._unevaluated_pages.keys())
1496+
async def all_routes(_request: Request) -> Response:
1497+
return JSONResponse(list(self._unevaluated_pages.keys()))
1498+
1499+
self._api.add_route(
1500+
str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"]
1501+
)
14361502

14371503
@contextlib.asynccontextmanager
14381504
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -1687,18 +1753,24 @@ async def process(
16871753
raise
16881754

16891755

1690-
async def ping() -> str:
1756+
async def ping(_request: Request) -> Response:
16911757
"""Test API endpoint.
16921758
1759+
Args:
1760+
_request: The Starlette request object.
1761+
16931762
Returns:
16941763
The response.
16951764
"""
1696-
return "pong"
1765+
return JSONResponse("pong")
16971766

16981767

1699-
async def health() -> JSONResponse:
1768+
async def health(_request: Request) -> JSONResponse:
17001769
"""Health check endpoint to assess the status of the database and Redis services.
17011770
1771+
Args:
1772+
_request: The Starlette request object.
1773+
17021774
Returns:
17031775
JSONResponse: A JSON object with the health status:
17041776
- "status" (bool): Overall health, True if all checks pass.
@@ -1740,12 +1812,11 @@ def upload(app: App):
17401812
The upload function.
17411813
"""
17421814

1743-
async def upload_file(request: Request, files: list[FastAPIUploadFile]):
1815+
async def upload_file(request: Request):
17441816
"""Upload a file.
17451817
17461818
Args:
1747-
request: The FastAPI request object.
1748-
files: The file(s) to upload.
1819+
request: The Starlette request object.
17491820
17501821
Returns:
17511822
StreamingResponse yielding newline-delimited JSON of StateUpdate
@@ -1758,6 +1829,12 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
17581829
"""
17591830
from reflex.utils.exceptions import UploadTypeError, UploadValueError
17601831

1832+
# Get the files from the request.
1833+
files = await request.form()
1834+
files = files.getlist("files")
1835+
if not files:
1836+
raise UploadValueError("No files were uploaded.")
1837+
17611838
token = request.headers.get("reflex-client-token")
17621839
handler = request.headers.get("reflex-event-handler")
17631840

@@ -1810,6 +1887,10 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
18101887
# event is handled.
18111888
file_copies = []
18121889
for file in files:
1890+
if not isinstance(file, StarletteUploadFile):
1891+
raise UploadValueError(
1892+
"Uploaded file is not an UploadFile." + str(file)
1893+
)
18131894
content_copy = io.BytesIO()
18141895
content_copy.write(await file.read())
18151896
content_copy.seek(0)

reflex/app_mixins/lifespan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import inspect
1010
from collections.abc import Callable, Coroutine
1111

12-
from fastapi import FastAPI
12+
from starlette.applications import Starlette
1313

1414
from reflex.utils import console
1515
from reflex.utils.exceptions import InvalidLifespanTaskTypeError
@@ -27,7 +27,7 @@ class LifespanMixin(AppMixin):
2727
)
2828

2929
@contextlib.asynccontextmanager
30-
async def _run_lifespan_tasks(self, app: FastAPI):
30+
async def _run_lifespan_tasks(self, app: Starlette):
3131
running_tasks = []
3232
try:
3333
async with contextlib.AsyncExitStack() as stack:

reflex/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,11 @@ async def _shutdown(*args, **kwargs) -> None:
322322
return _shutdown
323323

324324
def _start_backend(self, port: int = 0):
325-
if self.app_instance is None or self.app_instance.api is None:
325+
if self.app_instance is None or self.app_instance._api is None:
326326
raise RuntimeError("App was not initialized.")
327327
self.backend = uvicorn.Server(
328328
uvicorn.Config(
329-
app=self.app_instance.api,
329+
app=self.app_instance._api,
330330
host="127.0.0.1",
331331
port=port,
332332
)

0 commit comments

Comments
 (0)