2727```
2828"""
2929
30+ import functools
3031import importlib
3132import importlib .metadata
3233import logging
3334import pathlib
3435import sys
36+ from collections .abc import Generator
37+ from contextlib import asynccontextmanager
3538from typing import Any
3639
3740import uvicorn
4144from fastapi .openapi .utils import get_openapi
4245from grpclib .client import Channel
4346from pydantic import BaseModel
44- from pyhocon import ConfigFactory
47+ from pyhocon import ConfigFactory , ConfigTree
4548from starlette .responses import FileResponse
4649from starlette .staticfiles import StaticFiles
4750
@@ -59,6 +62,7 @@ class GetHealthResponse(BaseModel):
5962
6063 status : int
6164
65+
6266def _custom_openapi (server : FastAPI ) -> dict [str , Any ]:
6367 """Customize the OpenAPI schema for ReDoc."""
6468 if server .openapi_schema :
@@ -87,20 +91,51 @@ def _custom_openapi(server: FastAPI) -> dict[str, Any]:
8791 return openapi_schema
8892
8993
90- def run () -> None :
91- """Run the API using a uvicorn server."""
92- # Get the application configuration from the environment
93- conf = ConfigFactory .parse_file ((pathlib .Path (__file__ ).parent / "server.conf" ).as_posix ())
94+ @asynccontextmanager
95+ async def _lifespan (server : FastAPI , conf : ConfigTree ) -> Generator [None ]:
96+ """Configure FastAPI app instance with startup and shutdown events."""
97+ db_instance : models .DatabaseInterface | None = None
98+ grpc_channel : Channel | None = None
99+
100+ match conf .get_string ("backend.source" ):
101+ case "quartzdb" :
102+ db_instance = QuartzClient (
103+ database_url = conf .get_string ("backend.quartzdb.database_url" ),
104+ )
105+ case "dummydb" :
106+ db_instance = DummyClient ()
107+ log .warning ("disabled backend. NOT recommended for production" )
108+ case "dataplatform" :
109+ grpc_channel = Channel (
110+ host = conf .get_string ("backend.dataplatform.host" ),
111+ port = conf .get_int ("backend.dataplatform.port" ),
112+ )
113+ client = dp .DataPlatformDataServiceStub (channel = grpc_channel )
114+ db_instance = DataPlatformClient .from_dp (dp_client = client )
115+ case _ as backend_type :
116+ raise ValueError (f"Unknown backend: { backend_type } " )
94117
95- # Create the FastAPI server
118+ server .dependency_overrides [models .get_db_client ] = lambda : db_instance
119+
120+ yield
121+
122+ if grpc_channel :
123+ grpc_channel .close ()
124+
125+
126+ def _create_server (conf : ConfigTree ) -> FastAPI :
127+ """Configure FastAPI app instance with routes, dependencies, and middleware."""
96128 server = FastAPI (
97129 version = importlib .metadata .version ("quartz_api" ),
130+ lifespan = functools .partial (_lifespan , conf = conf ),
98131 title = "Quartz API" ,
99132 description = __doc__ ,
100- openapi_tags = [{
101- "name" : "API Information" ,
102- "description" : "Routes providing information about the API." ,
103- }],
133+ openapi_tags = [
134+ {
135+ "name" : "API Information" ,
136+ "description" : "Routes providing information about the API." ,
137+ },
138+ ],
104139 docs_url = "/swagger" ,
105140 redoc_url = None ,
106141 )
@@ -126,15 +161,34 @@ def redoc_html() -> FileResponse:
126161 # Setup sentry, if configured
127162 if conf .get_string ("sentry.dsn" ) != "" :
128163 import sentry_sdk
164+
129165 sentry_sdk .init (
130166 dsn = conf .get_string ("sentry.dsn" ),
131167 environment = conf .get_string ("sentry.environment" ),
132168 traces_sample_rate = 1 ,
133169 )
134170
135- sentry_sdk .set_tag ("app_name " , "quartz_api" )
171+ sentry_sdk .set_tag ("server_name " , "quartz_api" )
136172 sentry_sdk .set_tag ("version" , importlib .metadata .version ("quartz_api" ))
137173
174+ # Add routers to the server according to configuration
175+ for r in conf .get_string ("api.routers" ).split ("," ):
176+ try :
177+ mod = importlib .import_module (service .__name__ + f".{ r } " )
178+ server .include_router (mod .router )
179+ except ModuleNotFoundError as e :
180+ raise OSError (f"No such router router '{ r } '" ) from e
181+ server .openapi_tags = [
182+ {
183+ "name" : mod .__name__ .split ("." )[- 1 ].capitalize (),
184+ "description" : mod .__doc__ ,
185+ },
186+ * server .openapi_tags ,
187+ ]
188+
189+ # Customize the OpenAPI schema
190+ server .openapi = lambda : _custom_openapi (server )
191+
138192 # Override dependencies according to configuration
139193 match (conf .get_string ("auth0.domain" ), conf .get_string ("auth0.audience" )):
140194 case (_, "" ) | ("" , _):
@@ -149,32 +203,6 @@ def redoc_html() -> FileResponse:
149203 case _:
150204 raise ValueError ("Invalid Auth0 configuration" )
151205
152- db_instance : models .DatabaseInterface
153- match conf .get_string ("backend.source" ):
154- case "quartzdb" :
155- db_instance = QuartzClient (
156- database_url = conf .get_string ("backend.quartzdb.database_url" ),
157- )
158- case "dummydb" :
159- db_instance = DummyClient ()
160- log .warning ("disabled backend. NOT recommended for production" )
161- case "dataplatform" :
162-
163- channel = Channel (
164- host = conf .get_string ("backend.dataplatform.host" ),
165- port = conf .get_int ("backend.dataplatform.port" ),
166- )
167- client = dp .DataPlatformDataServiceStub (channel = channel )
168- db_instance = DataPlatformClient .from_dp (dp_client = client )
169- case _:
170- raise ValueError (
171- "Unknown backend. "
172- f"Expected one of { list (conf .get ('backend' ).keys ())} " ,
173- )
174-
175- server .dependency_overrides [models .get_db_client ] = lambda : db_instance
176-
177- # Add IANA timezone dependency
178206 timezone : str = conf .get_string ("api.timezone" )
179207 server .dependency_overrides [models .get_timezone ] = lambda : timezone
180208
@@ -186,34 +214,27 @@ def redoc_html() -> FileResponse:
186214 allow_methods = ["*" ],
187215 allow_headers = ["*" ],
188216 )
189- server .add_middleware (
190- audit .RequestLoggerMiddleware ,
191- db_client = db_instance ,
192- )
217+ server .add_middleware (audit .RequestLoggerMiddleware )
193218
194- # Add routers to the server according to configuration
195- for r in conf .get_string ("api.routers" ).split ("," ):
196- try :
197- mod = importlib .import_module (service .__name__ + f".{ r } " )
198- server .include_router (mod .router )
199- except ModuleNotFoundError as e :
200- raise OSError (f"No such router router '{ r } '" ) from e
201- server .openapi_tags = [{
202- "name" : mod .__name__ .split ("." )[- 1 ].capitalize (),
203- "description" : mod .__doc__ ,
204- }, * server .openapi_tags ]
219+ return server
205220
206- # Customize the OpenAPI schema
207- server .openapi = lambda : _custom_openapi (server )
221+
222+ def run () -> None :
223+ """Run the API using a uvicorn server."""
224+ # Get the application configuration from the environment
225+ conf = ConfigFactory .parse_file ((pathlib .Path (__file__ ).parent / "server.conf" ).as_posix ())
226+
227+ server = _create_server (conf = conf )
208228
209229 # Run the server with uvicorn
210230 uvicorn .run (
211231 server ,
212- host = "0.0.0.0" , # noqa: S104
232+ host = "0.0.0.0" , # noqa: S104
213233 port = conf .get_int ("api.port" ),
214234 log_level = conf .get_string ("api.loglevel" ),
215235 )
216236
217237
218238if __name__ == "__main__" :
219239 run ()
240+
0 commit comments