1313import json
1414import sys
1515import traceback
16- from collections .abc import AsyncIterator , Callable , Coroutine , MutableMapping
16+ from collections .abc import AsyncIterator , Callable , Coroutine , Sequence
1717from datetime import datetime
1818from pathlib import Path
1919from timeit import default_timer as timer
2020from types import SimpleNamespace
2121from typing import TYPE_CHECKING , Any , BinaryIO , get_args , get_type_hints
2222
23- from fastapi import FastAPI , HTTPException , Request
24- from fastapi import UploadFile as FastAPIUploadFile
25- from fastapi .middleware import cors
26- from fastapi .responses import JSONResponse , StreamingResponse
27- from fastapi .staticfiles import StaticFiles
23+ from fastapi import FastAPI
2824from rich .progress import MofNCompleteColumn , Progress , TimeElapsedColumn
29- from socketio import ASGIApp , AsyncNamespace , AsyncServer
25+ from socketio import ASGIApp as EngineIOApp
26+ from socketio import AsyncNamespace , AsyncServer
27+ from starlette .applications import Starlette
3028from starlette .datastructures import Headers
3129from starlette .datastructures import UploadFile as StarletteUploadFile
30+ from starlette .exceptions import HTTPException
31+ from starlette .middleware import cors
32+ from starlette .requests import Request
33+ from starlette .responses import JSONResponse , Response , StreamingResponse
34+ from starlette .staticfiles import StaticFiles
35+ from typing_extensions import deprecated
3236
3337from reflex import constants
3438from reflex .admin import AdminDash
102106)
103107from reflex .utils .exec import get_compile_context , is_prod_mode , is_testing_env
104108from reflex .utils .imports import ImportVar
109+ from reflex .utils .types import ASGIApp , Message , Receive , Scope , Send
105110
106111if TYPE_CHECKING :
107112 from reflex .vars import Var
@@ -389,7 +394,7 @@ class App(MiddlewareMixin, LifespanMixin):
389394 _stateful_pages : dict [str , None ] = dataclasses .field (default_factory = dict )
390395
391396 # The backend API object.
392- _api : FastAPI | None = None
397+ _api : Starlette | None = None
393398
394399 # The state class to use for the app.
395400 _state : type [BaseState ] | None = None
@@ -424,14 +429,34 @@ class App(MiddlewareMixin, LifespanMixin):
424429 # Put the toast provider in the app wrap.
425430 toaster : Component | None = dataclasses .field (default_factory = toast .provider )
426431
432+ # Transform the ASGI app before running it.
433+ api_transformer : (
434+ Sequence [Callable [[ASGIApp ], ASGIApp ] | Starlette ]
435+ | Callable [[ASGIApp ], ASGIApp ]
436+ | Starlette
437+ | None
438+ ) = None
439+
440+ # FastAPI app for compatibility with FastAPI.
441+ _cached_fastapi_app : FastAPI | None = None
442+
427443 @property
428- def api (self ) -> FastAPI | None :
444+ @deprecated ("Use `api_transformer=your_fastapi_app` instead." )
445+ def api (self ) -> FastAPI :
429446 """Get the backend api.
430447
431448 Returns:
432449 The backend api.
433450 """
434- return self ._api
451+ if self ._cached_fastapi_app is None :
452+ self ._cached_fastapi_app = FastAPI ()
453+ console .deprecate (
454+ feature_name = "App.api" ,
455+ reason = "Set `api_transformer=your_fastapi_app` instead." ,
456+ deprecation_version = "0.7.9" ,
457+ removal_version = "0.8.0" ,
458+ )
459+ return self ._cached_fastapi_app
435460
436461 @property
437462 def event_namespace (self ) -> EventNamespace | None :
@@ -463,7 +488,7 @@ def __post_init__(self):
463488 set_breakpoints (self .style .pop ("breakpoints" ))
464489
465490 # Set up the API.
466- self ._api = FastAPI (lifespan = self ._run_lifespan_tasks )
491+ self ._api = Starlette (lifespan = self ._run_lifespan_tasks )
467492 self ._add_cors ()
468493 self ._add_default_endpoints ()
469494
@@ -529,7 +554,7 @@ def _setup_state(self) -> None:
529554 )
530555
531556 # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
532- socket_app = ASGIApp (self .sio , socketio_path = "" )
557+ socket_app = EngineIOApp (self .sio , socketio_path = "" )
533558 namespace = config .get_event_namespace ()
534559
535560 # Create the event namespace and attach the main app. Not related to any paths.
@@ -538,18 +563,16 @@ def _setup_state(self) -> None:
538563 # Register the event namespace with the socket.
539564 self .sio .register_namespace (self .event_namespace )
540565 # Mount the socket app with the API.
541- if self .api :
566+ if self ._api :
542567
543568 class HeaderMiddleware :
544569 def __init__ (self , app : ASGIApp ):
545570 self .app = app
546571
547- async def __call__ (
548- self , scope : MutableMapping [str , Any ], receive : Any , send : Callable
549- ):
572+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ):
550573 original_send = send
551574
552- async def modified_send (message : dict ):
575+ async def modified_send (message : Message ):
553576 if message ["type" ] == "websocket.accept" :
554577 if scope .get ("subprotocols" ):
555578 # The following *does* say "subprotocol" instead of "subprotocols", intentionally.
@@ -568,7 +591,7 @@ async def modified_send(message: dict):
568591 return await self .app (scope , receive , modified_send )
569592
570593 socket_app_with_headers = HeaderMiddleware (socket_app )
571- self .api .mount (str (constants .Endpoint .EVENT ), socket_app_with_headers )
594+ self ._api .mount (str (constants .Endpoint .EVENT ), socket_app_with_headers )
572595
573596 # Check the exception handlers
574597 self ._validate_exception_handlers ()
@@ -581,7 +604,7 @@ def __repr__(self) -> str:
581604 """
582605 return f"<App state={ self ._state .__name__ if self ._state else None } >"
583606
584- def __call__ (self ) -> FastAPI :
607+ def __call__ (self ) -> ASGIApp :
585608 """Run the backend api instance.
586609
587610 Raises:
@@ -590,8 +613,18 @@ def __call__(self) -> FastAPI:
590613 Returns:
591614 The backend api.
592615 """
593- if not self .api :
594- raise ValueError ("The app has not been initialized." )
616+ if self ._cached_fastapi_app is not None :
617+ asgi_app = self ._cached_fastapi_app
618+
619+ if not asgi_app or not self ._api :
620+ raise ValueError ("The app has not been initialized." )
621+
622+ asgi_app .mount ("" , self ._api )
623+ else :
624+ asgi_app = self ._api
625+
626+ if not asgi_app :
627+ raise ValueError ("The app has not been initialized." )
595628
596629 # For py3.9 compatibility when redis is used, we MUST add any decorator pages
597630 # before compiling the app in a thread to avoid event loop error (REF-2172).
@@ -608,30 +641,58 @@ def __call__(self) -> FastAPI:
608641 if is_prod_mode ():
609642 compile_future .result ()
610643
611- return self .api
644+ if self .api_transformer is not None :
645+ api_transformers : Sequence [Starlette | Callable [[ASGIApp ], ASGIApp ]] = (
646+ [self .api_transformer ]
647+ if not isinstance (self .api_transformer , Sequence )
648+ else self .api_transformer
649+ )
650+
651+ for api_transformer in api_transformers :
652+ if isinstance (api_transformer , Starlette ):
653+ # Mount the api to the fastapi app.
654+ api_transformer .mount ("" , asgi_app )
655+ asgi_app = api_transformer
656+ else :
657+ # Transform the asgi app.
658+ asgi_app = api_transformer (asgi_app )
659+
660+ return asgi_app
612661
613662 def _add_default_endpoints (self ):
614663 """Add default api endpoints (ping)."""
615664 # To test the server.
616- if not self .api :
665+ if not self ._api :
617666 return
618667
619- self .api .get (str (constants .Endpoint .PING ))(ping )
620- self .api .get (str (constants .Endpoint .HEALTH ))(health )
668+ self ._api .add_route (
669+ str (constants .Endpoint .PING ),
670+ ping ,
671+ methods = ["GET" ],
672+ )
673+ self ._api .add_route (
674+ str (constants .Endpoint .HEALTH ),
675+ health ,
676+ methods = ["GET" ],
677+ )
621678
622679 def _add_optional_endpoints (self ):
623680 """Add optional api endpoints (_upload)."""
624- if not self .api :
681+ if not self ._api :
625682 return
626683 upload_is_used_marker = (
627684 prerequisites .get_backend_dir () / constants .Dirs .UPLOAD_IS_USED
628685 )
629686 if Upload .is_used or upload_is_used_marker .exists ():
630687 # To upload files.
631- self .api .post (str (constants .Endpoint .UPLOAD ))(upload (self ))
688+ self ._api .add_route (
689+ str (constants .Endpoint .UPLOAD ),
690+ upload (self ),
691+ methods = ["POST" ],
692+ )
632693
633694 # To access uploaded files.
634- self .api .mount (
695+ self ._api .mount (
635696 str (constants .Endpoint .UPLOAD ),
636697 StaticFiles (directory = get_upload_dir ()),
637698 name = "uploaded_files" ,
@@ -640,17 +701,19 @@ def _add_optional_endpoints(self):
640701 upload_is_used_marker .parent .mkdir (parents = True , exist_ok = True )
641702 upload_is_used_marker .touch ()
642703 if codespaces .is_running_in_codespaces ():
643- self .api .get (str (constants .Endpoint .AUTH_CODESPACE ))(
644- codespaces .auth_codespace
704+ self ._api .add_route (
705+ str (constants .Endpoint .AUTH_CODESPACE ),
706+ codespaces .auth_codespace ,
707+ methods = ["GET" ],
645708 )
646709 if environment .REFLEX_ADD_ALL_ROUTES_ENDPOINT .get ():
647710 self .add_all_routes_endpoint ()
648711
649712 def _add_cors (self ):
650713 """Add CORS middleware to the app."""
651- if not self .api :
714+ if not self ._api :
652715 return
653- self .api .add_middleware (
716+ self ._api .add_middleware (
654717 cors .CORSMiddleware ,
655718 allow_credentials = True ,
656719 allow_methods = ["*" ],
@@ -915,7 +978,7 @@ def _setup_admin_dash(self):
915978 return
916979
917980 # Get the admin dash.
918- if not self .api :
981+ if not self ._api :
919982 return
920983
921984 admin_dash = self .admin_dash
@@ -936,7 +999,7 @@ def _setup_admin_dash(self):
936999 view = admin_dash .view_overrides .get (model , ModelView )
9371000 admin .add_view (view (model ))
9381001
939- admin .mount_to (self .api )
1002+ admin .mount_to (self ._api )
9401003
9411004 def _get_frontend_packages (self , imports : dict [str , set [ImportVar ]]):
9421005 """Gets the frontend packages to be installed and filters out the unnecessary ones.
@@ -1427,12 +1490,15 @@ def _write_stateful_pages_marker(self):
14271490
14281491 def add_all_routes_endpoint (self ):
14291492 """Add an endpoint to the app that returns all the routes."""
1430- if not self .api :
1493+ if not self ._api :
14311494 return
14321495
1433- @self .api .get (str (constants .Endpoint .ALL_ROUTES ))
1434- async def all_routes ():
1435- return list (self ._unevaluated_pages .keys ())
1496+ async def all_routes (_request : Request ) -> Response :
1497+ return JSONResponse (list (self ._unevaluated_pages .keys ()))
1498+
1499+ self ._api .add_route (
1500+ str (constants .Endpoint .ALL_ROUTES ), all_routes , methods = ["GET" ]
1501+ )
14361502
14371503 @contextlib .asynccontextmanager
14381504 async def modify_state (self , token : str ) -> AsyncIterator [BaseState ]:
@@ -1687,18 +1753,24 @@ async def process(
16871753 raise
16881754
16891755
1690- async def ping () -> str :
1756+ async def ping (_request : Request ) -> Response :
16911757 """Test API endpoint.
16921758
1759+ Args:
1760+ _request: The Starlette request object.
1761+
16931762 Returns:
16941763 The response.
16951764 """
1696- return "pong"
1765+ return JSONResponse ( "pong" )
16971766
16981767
1699- async def health () -> JSONResponse :
1768+ async def health (_request : Request ) -> JSONResponse :
17001769 """Health check endpoint to assess the status of the database and Redis services.
17011770
1771+ Args:
1772+ _request: The Starlette request object.
1773+
17021774 Returns:
17031775 JSONResponse: A JSON object with the health status:
17041776 - "status" (bool): Overall health, True if all checks pass.
@@ -1740,12 +1812,11 @@ def upload(app: App):
17401812 The upload function.
17411813 """
17421814
1743- async def upload_file (request : Request , files : list [ FastAPIUploadFile ] ):
1815+ async def upload_file (request : Request ):
17441816 """Upload a file.
17451817
17461818 Args:
1747- request: The FastAPI request object.
1748- files: The file(s) to upload.
1819+ request: The Starlette request object.
17491820
17501821 Returns:
17511822 StreamingResponse yielding newline-delimited JSON of StateUpdate
@@ -1758,6 +1829,12 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
17581829 """
17591830 from reflex .utils .exceptions import UploadTypeError , UploadValueError
17601831
1832+ # Get the files from the request.
1833+ files = await request .form ()
1834+ files = files .getlist ("files" )
1835+ if not files :
1836+ raise UploadValueError ("No files were uploaded." )
1837+
17611838 token = request .headers .get ("reflex-client-token" )
17621839 handler = request .headers .get ("reflex-event-handler" )
17631840
@@ -1810,6 +1887,10 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
18101887 # event is handled.
18111888 file_copies = []
18121889 for file in files :
1890+ if not isinstance (file , StarletteUploadFile ):
1891+ raise UploadValueError (
1892+ "Uploaded file is not an UploadFile." + str (file )
1893+ )
18131894 content_copy = io .BytesIO ()
18141895 content_copy .write (await file .read ())
18151896 content_copy .seek (0 )
0 commit comments