Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/branch_ci.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# yaml-language-server: $schema=https://www.schemastore.org/github-workflow.json

name: Branch Push CI (Python)

on:
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/bump_tag.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# yaml-language-server: $schema=https://www.schemastore.org/github-workflow.json

name: Merged CI
run-name: 'Bump tag with merge #${{ github.event.number }} "${{ github.event.pull_request.title }}"'

Expand All @@ -9,4 +11,4 @@ on:
jobs:
bump-tag:
uses: openclimatefix/.github/.github/workflows/bump_tag.yml@main
secrets: inherit
secrets: inherit
4 changes: 3 additions & 1 deletion .github/workflows/tagged_ci.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# yaml-language-server: $schema=https://json.schemastore.org/github-workflow

name: Tagged CI
run-name: 'Tagged CI for ${{ github.ref_name }} by ${{ github.actor }}'

Expand All @@ -11,4 +13,4 @@ jobs:
secrets: inherit
with:
containerfile: Containerfile
enable_pypi: false
enable_pypi: false
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dev = [
]

[tool.uv.sources]
dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.14.0/dp_sdk-0.14.0-py3-none-any.whl" }
dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.16.0/dp_sdk-0.16.0-py3-none-any.whl" }

[project.urls]
repository = "https://github.com/openclimatefix/quartz-api"
Expand Down
139 changes: 83 additions & 56 deletions src/quartz_api/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
```
"""

import functools
import importlib
import importlib.metadata
import logging
import pathlib
import sys
from importlib.metadata import version
from collections.abc import Generator
from contextlib import asynccontextmanager
from typing import Any

import uvicorn
Expand All @@ -40,14 +44,13 @@
from fastapi.openapi.utils import get_openapi
from grpclib.client import Channel
from pydantic import BaseModel
from pyhocon import ConfigFactory
from pyhocon import ConfigFactory, ConfigTree
from starlette.responses import FileResponse
from starlette.staticfiles import StaticFiles

from quartz_api.internal import models, service
from quartz_api.internal.backends import DataPlatformClient, DummyClient, QuartzClient
from quartz_api.internal.middleware import audit, auth
from quartz_api.internal.models import DatabaseInterface, get_db_client
from quartz_api.internal.service import regions, sites

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
Expand All @@ -59,6 +62,7 @@ class GetHealthResponse(BaseModel):

status: int


def _custom_openapi(server: FastAPI) -> dict[str, Any]:
"""Customize the OpenAPI schema for ReDoc."""
if server.openapi_schema:
Expand Down Expand Up @@ -87,20 +91,51 @@ def _custom_openapi(server: FastAPI) -> dict[str, Any]:
return openapi_schema


def run() -> None:
"""Run the API using a uvicorn server."""
# Get the application configuration from the environment
conf = ConfigFactory.parse_file((pathlib.Path(__file__).parent / "server.conf").as_posix())
@asynccontextmanager
async def _lifespan(server: FastAPI, conf: ConfigTree) -> Generator[None]:
"""Configure FastAPI app instance with startup and shutdown events."""
db_instance: models.DatabaseInterface | None = None
grpc_channel: Channel | None = None

# Create the FastAPI server
match conf.get_string("backend.source"):
case "quartzdb":
db_instance = QuartzClient(
database_url=conf.get_string("backend.quartzdb.database_url"),
)
case "dummydb":
db_instance = DummyClient()
log.warning("disabled backend. NOT recommended for production")
case "dataplatform":
grpc_channel = Channel(
host=conf.get_string("backend.dataplatform.host"),
port=conf.get_int("backend.dataplatform.port"),
)
client = dp.DataPlatformDataServiceStub(channel=grpc_channel)
db_instance = DataPlatformClient.from_dp(dp_client=client)
case _ as backend_type:
raise ValueError(f"Unknown backend: {backend_type}")

server.dependency_overrides[models.get_db_client] = lambda: db_instance

yield

if grpc_channel:
grpc_channel.close()


def _create_server(conf: ConfigTree) -> FastAPI:
"""Configure FastAPI app instance with routes, dependencies, and middleware."""
server = FastAPI(
version=version("quartz_api"),
version=importlib.metadata.version("quartz_api"),
lifespan=functools.partial(_lifespan, conf=conf),
title="Quartz API",
description=__doc__,
openapi_tags=[{
"name": "API Information",
"description": "Routes providing information about the API.",
}],
openapi_tags=[
{
"name": "API Information",
"description": "Routes providing information about the API.",
},
],
docs_url="/swagger",
redoc_url=None,
)
Expand All @@ -126,14 +161,33 @@ def redoc_html() -> FileResponse:
# Setup sentry, if configured
if conf.get_string("sentry.dsn") != "":
import sentry_sdk

sentry_sdk.init(
dsn=conf.get_string("sentry.dsn"),
environment=conf.get_string("sentry.environment"),
traces_sample_rate=1,
)

sentry_sdk.set_tag("app_name", "quartz_api")
sentry_sdk.set_tag("version", version("quartz_api"))
sentry_sdk.set_tag("server_name", "quartz_api")
sentry_sdk.set_tag("version", importlib.metadata.version("quartz_api"))

# Add routers to the server according to configuration
for r in conf.get_string("api.routers").split(","):
try:
mod = importlib.import_module(service.__name__ + f".{r}")
server.include_router(mod.router)
except ModuleNotFoundError as e:
raise OSError(f"No such router router '{r}'") from e
server.openapi_tags = [
{
"name": mod.__name__.split(".")[-1].capitalize(),
"description": mod.__doc__,
},
*server.openapi_tags,
]

# Customize the OpenAPI schema
server.openapi = lambda: _custom_openapi(server)

# Override dependencies according to configuration
match (conf.get_string("auth0.domain"), conf.get_string("auth0.audience")):
Expand All @@ -149,30 +203,8 @@ def redoc_html() -> FileResponse:
case _:
raise ValueError("Invalid Auth0 configuration")

db_instance: DatabaseInterface
match conf.get_string("backend.source"):
case "quartzdb":
db_instance = QuartzClient(
database_url=conf.get_string("backend.quartzdb.database_url"),
)
case "dummydb":
db_instance = DummyClient()
log.warning("disabled backend. NOT recommended for production")
case "dataplatform":

channel = Channel(
host=conf.get_string("backend.dataplatform.host"),
port=conf.get_int("backend.dataplatform.port"),
)
client = dp.DataPlatformDataServiceStub(channel=channel)
db_instance = DataPlatformClient.from_dp(dp_client=client)
case _:
raise ValueError(
"Unknown backend. "
f"Expected one of {list(conf.get('backend').keys())}",
)

server.dependency_overrides[get_db_client] = lambda: db_instance
timezone: str = conf.get_string("api.timezone")
server.dependency_overrides[models.get_timezone] = lambda: timezone

# Add middlewares
server.add_middleware(
Expand All @@ -182,32 +214,27 @@ def redoc_html() -> FileResponse:
allow_methods=["*"],
allow_headers=["*"],
)
server.add_middleware(
audit.RequestLoggerMiddleware,
db_client=db_instance,
)
server.add_middleware(audit.RequestLoggerMiddleware)

# Add routers to the server according to configuration
for router_module in [sites, regions]:
if conf.get_string("api.router") == router_module.__name__.split(".")[-1]:
server.include_router(router_module.router)
server.openapi_tags = [{
"name": router_module.__name__.split(".")[-1].capitalize(),
"description": router_module.__doc__,
}, *server.openapi_tags]
break
return server

# Customize the OpenAPI schema
server.openapi = lambda: _custom_openapi(server)

def run() -> None:
"""Run the API using a uvicorn server."""
# Get the application configuration from the environment
conf = ConfigFactory.parse_file((pathlib.Path(__file__).parent / "server.conf").as_posix())

server = _create_server(conf=conf)

# Run the server with uvicorn
uvicorn.run(
server,
host="0.0.0.0", # noqa: S104
host="0.0.0.0", # noqa: S104
port=conf.get_int("api.port"),
log_level=conf.get_string("api.loglevel"),
)


if __name__ == "__main__":
run()

11 changes: 8 additions & 3 deletions src/quartz_api/cmd/server.conf
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ api {
port = ${?PORT}
loglevel = "debug"
loglevel = ${?LOGLEVEL}
// Which router to serve requests through
router = "none"
router = ${?ROUTER}
// Comma seperated list of routers to enable
// Supported routers: "sites", "regions", "substations"
routers = ""
routers = ${?ROUTERS}
origins = "*"
origins = ${?ORIGINS}
// The IANA timezone string to use for date/time operations
timezone = "UTC"
timezone = ${?TZ}
}

// The backend to use for the service
// Supported backends: "dummydb", "quartzdb", "dataplatform"
backend {
source = "dummydb"
source = ${?SOURCE}
Expand Down
11 changes: 1 addition & 10 deletions src/quartz_api/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
from .models import (
ActualPower,
PredictedPower,
Site,
SiteProperties,
DBClientDependency,
ForecastHorizon,
DatabaseInterface,
get_db_client
)

Loading