Skip to content

Commit 7d7a85e

Browse files
committed
✨ Refactor database connection mapping and introduce atomic transaction helpers
1 parent 0e4606c commit 7d7a85e

File tree

2 files changed

+69
-28
lines changed

2 files changed

+69
-28
lines changed

src/database/config/__init__.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tortoise import Tortoise
99

1010
from src.config import config
11-
from src.config.models import DbExtraApp, DbParams
11+
from src.config.models import DbParams
1212

1313
logger = getLogger("bot").getChild("database")
1414

@@ -25,44 +25,63 @@ def apply_params(uri: str, params: DbParams | None) -> str:
2525
return uri
2626

2727

28-
app_url_mapping: dict[str, str] = {
29-
app_name: apply_params(app_config.url or config.db.url, app_config.params)
30-
for app_name, app_config in config.db.extra_apps.items()
31-
}
28+
def get_app_url_mapping() -> dict[str, str]:
29+
return {
30+
app_name: apply_params(app_config.url or config.db.url, app_config.params)
31+
for app_name, app_config in config.db.extra_apps.items()
32+
}
3233

33-
url_apps_mapping: dict[str, list[str]] = defaultdict(list)
3434

35-
for app_name, app_url in app_url_mapping.items():
36-
url_apps_mapping[app_url].append(app_name)
35+
def get_url_apps_mapping() -> dict[str, list[str]]:
36+
app_url_mapping: dict[str, str] = get_app_url_mapping()
37+
mapping: dict[str, list[str]] = defaultdict(list)
3738

38-
app_connections_mapping: dict[str, str] = {}
39-
connection_url_mapping: dict[str, str] = {}
39+
for app_name, app_url in app_url_mapping.items():
40+
mapping[app_url].append(app_name)
4041

41-
i: int = 0
42+
return mapping
4243

43-
for url, apps in url_apps_mapping.items():
44-
connection_name = f"connection_{i}"
45-
connection_url_mapping[connection_name] = url
46-
for app in apps:
47-
app_connections_mapping[app] = connection_name
48-
i += 1 # noqa: SIM113 # Incompatible with .items()
4944

50-
app_connections_mapping["models"] = "default"
51-
connection_url_mapping["default"] = apply_params(config.db.url, config.db.params)
45+
def parse_url_apps_mapping(url_apps_mapping: dict[str, list[str]]) -> tuple[dict[str, str], dict[str, str]]:
46+
app_connection: dict[str, str] = {}
47+
connection_url: dict[str, str] = {}
5248

53-
config.db.extra_apps["models"] = DbExtraApp(
54-
models=["src.database.models", "aerich.models"],
55-
)
49+
for i, (url, apps) in enumerate(url_apps_mapping.items()):
50+
connection_name = f"connection_{i}"
51+
connection_url[connection_name] = url
52+
for app in apps:
53+
app_connection[app] = connection_name
5654

57-
TORTOISE_ORM = {
58-
"connections": connection_url_mapping,
59-
"apps": {
55+
app_connection["models"] = "default"
56+
connection_url["default"] = apply_params(config.db.url, config.db.params)
57+
58+
return app_connection, connection_url
59+
60+
61+
APP_CONNECTION_MAPPING: dict[str, str]
62+
CONNECTION_URL_MAPPING: dict[str, str]
63+
64+
APP_CONNECTION_MAPPING, CONNECTION_URL_MAPPING = parse_url_apps_mapping(get_url_apps_mapping()) # pyright: ignore[reportConstantRedefinition]
65+
66+
67+
def get_apps() -> dict[str, dict[str, list[str] | str]]:
68+
apps = {
6069
app_name: {
6170
"models": app.models,
62-
"default_connection": app_connections_mapping[app_name],
71+
"default_connection": APP_CONNECTION_MAPPING[app_name],
6372
}
6473
for app_name, app in config.db.extra_apps.items()
65-
},
74+
}
75+
apps["models"] = {
76+
"models": ["src.database.models", "aerich.models"],
77+
"default_connection": "default",
78+
}
79+
return apps
80+
81+
82+
TORTOISE_ORM = {
83+
"connections": CONNECTION_URL_MAPPING,
84+
"apps": get_apps(),
6685
}
6786

6887

@@ -82,4 +101,4 @@ async def shutdown() -> None:
82101
await Tortoise.close_connections()
83102

84103

85-
__all__ = ["init", "shutdown"]
104+
__all__ = ["APP_CONNECTION_MAPPING", "init", "shutdown"]

src/database/utils/atomics.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) NiceBots
2+
# SPDX-License-Identifier: MIT
3+
4+
from collections.abc import Callable
5+
from typing import Any
6+
7+
from tortoise.backends.base.client import TransactionContext
8+
from tortoise.transactions import atomic as t_atomic
9+
from tortoise.transactions import in_transaction as t_in_transaction
10+
11+
from src.database.config import APP_CONNECTION_MAPPING
12+
13+
14+
def atomic[F: Callable[..., Any]](connection_name: str | None = None) -> Callable[[F], F]:
15+
return t_atomic(APP_CONNECTION_MAPPING[connection_name or "models"])
16+
17+
18+
def in_transaction(app_name: str | None = None) -> TransactionContext: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
19+
return t_in_transaction(APP_CONNECTION_MAPPING[app_name or "models"]) # pyright: ignore[reportUnknownVariableType]
20+
21+
22+
__all__ = ("atomic", "in_transaction")

0 commit comments

Comments
 (0)