Skip to content

Commit 9aba918

Browse files
committed
chore(repo): Typing
1 parent 4aefd85 commit 9aba918

File tree

21 files changed

+228
-185
lines changed

21 files changed

+228
-185
lines changed

src/quartz_api/cmd/main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@
4545
from starlette.responses import FileResponse
4646
from starlette.staticfiles import StaticFiles
4747

48-
from quartz_api.internal import service
48+
from quartz_api.internal import models, service
4949
from quartz_api.internal.backends import DataPlatformClient, DummyClient, QuartzClient
5050
from quartz_api.internal.middleware import audit, auth
51-
from quartz_api.internal.models import DatabaseInterface, get_db_client
5251

5352
log = logging.getLogger(__name__)
5453
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
@@ -150,7 +149,7 @@ def redoc_html() -> FileResponse:
150149
case _:
151150
raise ValueError("Invalid Auth0 configuration")
152151

153-
db_instance: DatabaseInterface
152+
db_instance: models.DatabaseInterface
154153
match conf.get_string("backend.source"):
155154
case "quartzdb":
156155
db_instance = QuartzClient(
@@ -173,7 +172,11 @@ def redoc_html() -> FileResponse:
173172
f"Expected one of {list(conf.get('backend').keys())}",
174173
)
175174

176-
server.dependency_overrides[get_db_client] = lambda: db_instance
175+
server.dependency_overrides[models.get_db_client] = lambda: db_instance
176+
177+
# Add IANA timezone dependency
178+
timezone: str = conf.get_string("api.timezone")
179+
server.dependency_overrides[models.get_timezone] = lambda: timezone
177180

178181
# Add middlewares
179182
server.add_middleware(

src/quartz_api/cmd/server.conf

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ api {
1313
routers = ${?ROUTERS}
1414
origins = "*"
1515
origins = ${?ORIGINS}
16+
// The IANA timezone string to use for date/time operations
17+
timezone = "UTC"
18+
timezone = ${?TZ}
1619
}
1720

1821
// The backend to use for the service

src/quartz_api/internal/backends/dataplatform/client.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def get_predicted_solar_power_production_for_location(
3636
smooth_flag: bool = True,
3737
) -> list[models.PredictedPower]:
3838
values = await self._get_predicted_power_production_for_location(
39-
location=location,
39+
location_uuid=UUID(location),
4040
energy_source=dp.EnergySource.SOLAR,
4141
forecast_horizon=forecast_horizon,
4242
forecast_horizon_minutes=forecast_horizon_minutes,
@@ -54,7 +54,7 @@ async def get_predicted_wind_power_production_for_location(
5454
smooth_flag: bool = True,
5555
) -> list[models.PredictedPower]:
5656
values = await self._get_predicted_power_production_for_location(
57-
location=location,
57+
location_uuid=UUID(location),
5858
energy_source=dp.EnergySource.WIND,
5959
forecast_horizon=forecast_horizon,
6060
forecast_horizon_minutes=forecast_horizon_minutes,
@@ -69,7 +69,7 @@ async def get_actual_solar_power_production_for_location(
6969
location: str,
7070
) -> list[models.ActualPower]:
7171
values = await self._get_actual_power_production_for_location(
72-
location,
72+
UUID(location),
7373
dp.EnergySource.SOLAR,
7474
oauth_id=None,
7575
)
@@ -81,7 +81,7 @@ async def get_actual_wind_power_production_for_location(
8181
location: str,
8282
) -> list[models.ActualPower]:
8383
values = await self._get_actual_power_production_for_location(
84-
location,
84+
UUID(location),
8585
dp.EnergySource.WIND,
8686
oauth_id=None,
8787
)
@@ -133,7 +133,7 @@ async def get_sites(self, authdata: dict[str, str]) -> list[models.Site]:
133133
@override
134134
async def put_site(
135135
self,
136-
site_uuid: str,
136+
site_uuid: UUID,
137137
site_properties: models.SiteProperties,
138138
authdata: dict[str, str],
139139
) -> models.Site:
@@ -142,7 +142,7 @@ async def put_site(
142142
@override
143143
async def get_site_forecast(
144144
self,
145-
site_uuid: str,
145+
site_uuid: UUID,
146146
authdata: dict[str, str],
147147
) -> list[models.PredictedPower]:
148148
forecast = await self._get_predicted_power_production_for_location(
@@ -155,7 +155,7 @@ async def get_site_forecast(
155155
@override
156156
async def get_site_generation(
157157
self,
158-
site_uuid: str,
158+
site_uuid: UUID,
159159
authdata: dict[str, str],
160160
) -> list[models.ActualPower]:
161161
generation = await self._get_actual_power_production_for_location(
@@ -168,7 +168,7 @@ async def get_site_generation(
168168
@override
169169
async def post_site_generation(
170170
self,
171-
site_uuid: str,
171+
site_uuid: UUID,
172172
generation: list[models.ActualPower],
173173
authdata: dict[str, str],
174174
) -> None:
@@ -196,7 +196,7 @@ async def get_substations(
196196
substation_name=loc.location_name,
197197
substation_type="primary"
198198
if loc.location_type == dp.LocationType.PRIMARY_SUBSTATION
199-
else "unknown",
199+
else "secondary",
200200
capacity_kw=loc.effective_capacity_watts // 1000.0,
201201
latitude=loc.latlng.latitude,
202202
longitude=loc.latlng.longitude,
@@ -288,22 +288,22 @@ async def get_substation(
288288

289289
async def _get_actual_power_production_for_location(
290290
self,
291-
location: str,
291+
location_uuid: UUID,
292292
energy_source: dp.EnergySource,
293293
oauth_id: str | None,
294294
) -> list[models.ActualPower]:
295295
"""Local function to retrieve actual values regardless of energy type."""
296296
if oauth_id is not None:
297297
await self._check_user_access(
298-
location,
298+
location_uuid,
299299
energy_source,
300300
dp.LocationType.SITE,
301301
oauth_id,
302302
)
303303

304304
start, end = get_window()
305305
req = dp.GetObservationsAsTimeseriesRequest(
306-
location_uuid=location,
306+
location_uuid=location_uuid,
307307
observer_name="ruvnl",
308308
energy_source=energy_source,
309309
time_window=dp.TimeWindow(
@@ -324,7 +324,7 @@ async def _get_actual_power_production_for_location(
324324

325325
async def _get_predicted_power_production_for_location(
326326
self,
327-
location: str,
327+
location_uuid: UUID,
328328
energy_source: dp.EnergySource,
329329
oauth_id: str | None,
330330
forecast_horizon: models.ForecastHorizon = models.ForecastHorizon.latest,
@@ -334,7 +334,7 @@ async def _get_predicted_power_production_for_location(
334334
"""Local function to retrieve predicted values regardless of energy type."""
335335
if oauth_id is not None:
336336
_ = await self._check_user_access(
337-
location,
337+
location_uuid,
338338
energy_source,
339339
dp.LocationType.SITE,
340340
oauth_id,
@@ -354,7 +354,7 @@ async def _get_predicted_power_production_for_location(
354354
# taking into account the desired horizon.
355355
# * At some point, we may want to allow the user to specify a particular forecaster.
356356
req = dp.GetLatestForecastsRequest(
357-
location_uuid=location,
357+
location_uuid=location_uuid,
358358
energy_source=energy_source,
359359
pivot_timestamp_utc=start - dt.timedelta(minutes=forecast_horizon_minutes),
360360
)
@@ -368,7 +368,7 @@ async def _get_predicted_power_production_for_location(
368368
forecaster = resp.forecasts[0].forecaster
369369

370370
req = dp.GetForecastAsTimeseriesRequest(
371-
location_uuid=location,
371+
location_uuid=location_uuid,
372372
energy_source=energy_source,
373373
horizon_mins=forecast_horizon_minutes,
374374
time_window=dp.TimeWindow(
@@ -391,14 +391,14 @@ async def _get_predicted_power_production_for_location(
391391

392392
async def _check_user_access(
393393
self,
394-
location: str,
394+
location_uuid: UUID,
395395
energy_source: dp.EnergySource,
396396
location_type: dp.LocationType,
397397
oauth_id: str,
398398
) -> bool:
399399
"""Check if a user has access to a given location."""
400400
req = dp.ListLocationsRequest(
401-
location_uuids_filter=[location],
401+
location_uuids_filter=[location_uuid],
402402
energy_source_filter=energy_source,
403403
location_type_filter=location_type,
404404
user_oauth_id_filter=oauth_id,
@@ -407,6 +407,6 @@ async def _check_user_access(
407407
if len(resp.locations) == 0:
408408
raise HTTPException(
409409
status_code=404,
410-
detail=f"No location found for UUID {location} and OAuth ID {oauth_id}",
410+
detail=f"No location found for UUID {location_uuid} and OAuth ID {oauth_id}",
411411
)
412412
return True

src/quartz_api/internal/backends/dataplatform/test_client.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,20 @@ async def test_get_site_forecast(self, client_mock: dp.DataPlatformDataServiceSt
133133
@dataclasses.dataclass
134134
class TestCase:
135135
name: str
136-
site_uuid: str
136+
site_uuid: uuid.UUID
137137
authdata: dict[str, str]
138138
should_error: bool
139139

140140
testcases: list[TestCase] = [
141141
TestCase(
142142
name="Should return forecast when user has access",
143-
site_uuid=str(uuid.uuid4()),
143+
site_uuid=uuid.uuid4(),
144144
authdata={"sub": "access_user"},
145145
should_error=False,
146146
),
147147
TestCase(
148148
name="Should raise HTTPException when user has no access",
149-
site_uuid=str(uuid.uuid4()),
149+
site_uuid=uuid.uuid4(),
150150
authdata={"sub": "no_access_user"},
151151
should_error=True,
152152
),
@@ -180,20 +180,20 @@ async def test_get_site_generation(
180180
@dataclasses.dataclass
181181
class TestCase:
182182
name: str
183-
site_uuid: str
183+
site_uuid: uuid.UUID
184184
authdata: dict[str, str]
185185
should_error: bool
186186

187187
testcases: list[TestCase] = [
188188
TestCase(
189189
name="Should return generation when user has access",
190-
site_uuid=str(uuid.uuid4()),
190+
site_uuid=uuid.uuid4(),
191191
authdata={"sub": "access_user"},
192192
should_error=False,
193193
),
194194
TestCase(
195195
name="Should raise HTTPException when user has no access",
196-
site_uuid=str(uuid.uuid4()),
196+
site_uuid=uuid.uuid4(),
197197
authdata={"sub": "no_access_user"},
198198
should_error=True,
199199
),
@@ -260,20 +260,20 @@ async def test_get_substation(
260260
@dataclasses.dataclass
261261
class TestCase:
262262
name: str
263-
location_uuid: str
263+
location_uuid: uuid.UUID
264264
authdata: dict[str, str]
265265
should_error: bool
266266

267267
testcases: list[TestCase] = [
268268
TestCase(
269269
name="Should return substation when user has access",
270-
location_uuid=str(uuid.uuid4()),
270+
location_uuid=uuid.uuid4(),
271271
authdata={"sub": "access_user"},
272272
should_error=False,
273273
),
274274
TestCase(
275275
name="Should raise HTTPException when user has no access",
276-
location_uuid=str(uuid.uuid4()),
276+
location_uuid=uuid.uuid4(),
277277
authdata={"sub": "no_access_user"},
278278
should_error=True,
279279
),
@@ -305,15 +305,15 @@ async def test_get_substation_forecast(
305305
@dataclasses.dataclass
306306
class TestCase:
307307
name: str
308-
substation_uuid: str
308+
substation_uuid: uuid.UUID
309309
authdata: dict[str, str]
310310
expected_values: list[float]
311311
should_error: bool
312312

313313
testcases: list[TestCase] = [
314314
TestCase(
315315
name="Should return GSP-scaled forecast when user has access",
316-
substation_uuid=str(uuid.uuid4()),
316+
substation_uuid=uuid.uuid4(),
317317
authdata={"sub": "access_user"},
318318
# The forecast returns 5e5 watts for every value, and the substation's
319319
# effective capacity is 1e5 watts (10% of the GSP's 1e6 watts), so
@@ -323,7 +323,7 @@ class TestCase:
323323
),
324324
TestCase(
325325
name="Should raise HTTPException when user has no access",
326-
substation_uuid=str(uuid.uuid4()),
326+
substation_uuid=uuid.uuid4(),
327327
authdata={"sub": "no_access_user"},
328328
expected_values=[],
329329
should_error=True,

0 commit comments

Comments
 (0)