Skip to content

Commit 9422735

Browse files
committed
fix(server): Enable data platform connections
1 parent 3191a77 commit 9422735

File tree

3 files changed

+80
-58
lines changed

3 files changed

+80
-58
lines changed

src/quartz_api/cmd/main.py

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@
2727
```
2828
"""
2929

30+
import functools
3031
import importlib
3132
import importlib.metadata
3233
import logging
3334
import pathlib
3435
import sys
36+
from collections.abc import Generator
37+
from contextlib import asynccontextmanager
3538
from typing import Any
3639

3740
import uvicorn
@@ -41,7 +44,7 @@
4144
from fastapi.openapi.utils import get_openapi
4245
from grpclib.client import Channel
4346
from pydantic import BaseModel
44-
from pyhocon import ConfigFactory
47+
from pyhocon import ConfigFactory, ConfigTree
4548
from starlette.responses import FileResponse
4649
from starlette.staticfiles import StaticFiles
4750

@@ -59,6 +62,7 @@ class GetHealthResponse(BaseModel):
5962

6063
status: int
6164

65+
6266
def _custom_openapi(server: FastAPI) -> dict[str, Any]:
6367
"""Customize the OpenAPI schema for ReDoc."""
6468
if server.openapi_schema:
@@ -87,20 +91,51 @@ def _custom_openapi(server: FastAPI) -> dict[str, Any]:
8791
return openapi_schema
8892

8993

90-
def run() -> None:
91-
"""Run the API using a uvicorn server."""
92-
# Get the application configuration from the environment
93-
conf = ConfigFactory.parse_file((pathlib.Path(__file__).parent / "server.conf").as_posix())
94+
@asynccontextmanager
95+
async def _lifespan(server: FastAPI, conf: ConfigTree) -> Generator[None]:
96+
"""Configure FastAPI app instance with startup and shutdown events."""
97+
db_instance: models.DatabaseInterface | None = None
98+
grpc_channel: Channel | None = None
99+
100+
match conf.get_string("backend.source"):
101+
case "quartzdb":
102+
db_instance = QuartzClient(
103+
database_url=conf.get_string("backend.quartzdb.database_url"),
104+
)
105+
case "dummydb":
106+
db_instance = DummyClient()
107+
log.warning("disabled backend. NOT recommended for production")
108+
case "dataplatform":
109+
grpc_channel = Channel(
110+
host=conf.get_string("backend.dataplatform.host"),
111+
port=conf.get_int("backend.dataplatform.port"),
112+
)
113+
client = dp.DataPlatformDataServiceStub(channel=grpc_channel)
114+
db_instance = DataPlatformClient.from_dp(dp_client=client)
115+
case _ as backend_type:
116+
raise ValueError(f"Unknown backend: {backend_type}")
94117

95-
# Create the FastAPI server
118+
server.dependency_overrides[models.get_db_client] = lambda: db_instance
119+
120+
yield
121+
122+
if grpc_channel:
123+
grpc_channel.close()
124+
125+
126+
def _create_server(conf: ConfigTree) -> FastAPI:
127+
"""Configure FastAPI app instance with routes, dependencies, and middleware."""
96128
server = FastAPI(
97129
version=importlib.metadata.version("quartz_api"),
130+
lifespan=functools.partial(_lifespan, conf=conf),
98131
title="Quartz API",
99132
description=__doc__,
100-
openapi_tags=[{
101-
"name": "API Information",
102-
"description": "Routes providing information about the API.",
103-
}],
133+
openapi_tags=[
134+
{
135+
"name": "API Information",
136+
"description": "Routes providing information about the API.",
137+
},
138+
],
104139
docs_url="/swagger",
105140
redoc_url=None,
106141
)
@@ -126,15 +161,34 @@ def redoc_html() -> FileResponse:
126161
# Setup sentry, if configured
127162
if conf.get_string("sentry.dsn") != "":
128163
import sentry_sdk
164+
129165
sentry_sdk.init(
130166
dsn=conf.get_string("sentry.dsn"),
131167
environment=conf.get_string("sentry.environment"),
132168
traces_sample_rate=1,
133169
)
134170

135-
sentry_sdk.set_tag("app_name", "quartz_api")
171+
sentry_sdk.set_tag("server_name", "quartz_api")
136172
sentry_sdk.set_tag("version", importlib.metadata.version("quartz_api"))
137173

174+
# Add routers to the server according to configuration
175+
for r in conf.get_string("api.routers").split(","):
176+
try:
177+
mod = importlib.import_module(service.__name__ + f".{r}")
178+
server.include_router(mod.router)
179+
except ModuleNotFoundError as e:
180+
raise OSError(f"No such router router '{r}'") from e
181+
server.openapi_tags = [
182+
{
183+
"name": mod.__name__.split(".")[-1].capitalize(),
184+
"description": mod.__doc__,
185+
},
186+
*server.openapi_tags,
187+
]
188+
189+
# Customize the OpenAPI schema
190+
server.openapi = lambda: _custom_openapi(server)
191+
138192
# Override dependencies according to configuration
139193
match (conf.get_string("auth0.domain"), conf.get_string("auth0.audience")):
140194
case (_, "") | ("", _):
@@ -149,32 +203,6 @@ def redoc_html() -> FileResponse:
149203
case _:
150204
raise ValueError("Invalid Auth0 configuration")
151205

152-
db_instance: models.DatabaseInterface
153-
match conf.get_string("backend.source"):
154-
case "quartzdb":
155-
db_instance = QuartzClient(
156-
database_url=conf.get_string("backend.quartzdb.database_url"),
157-
)
158-
case "dummydb":
159-
db_instance = DummyClient()
160-
log.warning("disabled backend. NOT recommended for production")
161-
case "dataplatform":
162-
163-
channel = Channel(
164-
host=conf.get_string("backend.dataplatform.host"),
165-
port=conf.get_int("backend.dataplatform.port"),
166-
)
167-
client = dp.DataPlatformDataServiceStub(channel=channel)
168-
db_instance = DataPlatformClient.from_dp(dp_client=client)
169-
case _:
170-
raise ValueError(
171-
"Unknown backend. "
172-
f"Expected one of {list(conf.get('backend').keys())}",
173-
)
174-
175-
server.dependency_overrides[models.get_db_client] = lambda: db_instance
176-
177-
# Add IANA timezone dependency
178206
timezone: str = conf.get_string("api.timezone")
179207
server.dependency_overrides[models.get_timezone] = lambda: timezone
180208

@@ -186,34 +214,27 @@ def redoc_html() -> FileResponse:
186214
allow_methods=["*"],
187215
allow_headers=["*"],
188216
)
189-
server.add_middleware(
190-
audit.RequestLoggerMiddleware,
191-
db_client=db_instance,
192-
)
217+
server.add_middleware(audit.RequestLoggerMiddleware)
193218

194-
# Add routers to the server according to configuration
195-
for r in conf.get_string("api.routers").split(","):
196-
try:
197-
mod = importlib.import_module(service.__name__ + f".{r}")
198-
server.include_router(mod.router)
199-
except ModuleNotFoundError as e:
200-
raise OSError(f"No such router router '{r}'") from e
201-
server.openapi_tags = [{
202-
"name": mod.__name__.split(".")[-1].capitalize(),
203-
"description": mod.__doc__,
204-
}, *server.openapi_tags]
219+
return server
205220

206-
# Customize the OpenAPI schema
207-
server.openapi = lambda: _custom_openapi(server)
221+
222+
def run() -> None:
223+
"""Run the API using a uvicorn server."""
224+
# Get the application configuration from the environment
225+
conf = ConfigFactory.parse_file((pathlib.Path(__file__).parent / "server.conf").as_posix())
226+
227+
server = _create_server(conf=conf)
208228

209229
# Run the server with uvicorn
210230
uvicorn.run(
211231
server,
212-
host="0.0.0.0", # noqa: S104
232+
host="0.0.0.0", # noqa: S104
213233
port=conf.get_int("api.port"),
214234
log_level=conf.get_string("api.loglevel"),
215235
)
216236

217237

218238
if __name__ == "__main__":
219239
run()
240+

src/quartz_api/internal/middleware/audit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
class RequestLoggerMiddleware(BaseHTTPMiddleware):
1313
"""Middleware to log API requests to the database."""
1414

15-
def __init__(self, server: FastAPI, db_client: models.DatabaseInterface) -> None:
15+
def __init__(self, server: FastAPI) -> None:
1616
"""Initialize the middleware with the FastAPI server and database client."""
1717
super().__init__(server)
18-
self.db_client = db_client
1918

2019
async def dispatch(
2120
self,
@@ -38,6 +37,9 @@ async def dispatch(
3837
url += f"?{request.url.query}"
3938

4039
try:
40+
db_client: models.DatabaseInterface = getattr(request.app.state, "db_instance", None)
41+
if db_client is None:
42+
raise RuntimeError("Database client not found in app state.")
4143
await self.db_client.save_api_call_to_db(url=url, authdata=auth)
4244
except Exception as e:
4345
logging.error(f"Failed to log request to DB: {e}")

src/quartz_api/internal/service/substations/router.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
router = APIRouter(tags=[pathlib.Path(__file__).parent.stem.capitalize()])
1313

14-
1514
@router.get(
1615
"/substations",
1716
status_code=status.HTTP_200_OK,

0 commit comments

Comments
 (0)