Skip to content

Commit 4bde5a6

Browse files
committed
fix(infrastructure): replace taskiq with remote_funcs (#62)
1 parent ecd4795 commit 4bde5a6

20 files changed

+304
-470
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ dependencies = [
2525
"openai==1.97.0",
2626
"structlog==25.4.0",
2727
"structlog-sentry==2.2.1",
28-
"taskiq==0.11.18",
2928
]
3029

3130
[dependency-groups]
@@ -117,7 +116,6 @@ ignore = [
117116

118117
[tool.ruff.lint.per-file-ignores]
119118
"src/ttt/infrastructure/sqlalchemy/tables/__init__.py" = ["F401"]
120-
"src/ttt/infrastructure/taskiq/tasks/__init__.py" = ["F401"]
121119
"src/ttt/infrastructure/adapters/*" = ["ARG002"]
122120
"src/ttt/presentation/adapters/*" = ["ARG002"]
123121
"src/ttt/presentation/*" = ["RUF001"]

src/ttt/infrastructure/adapters/game_log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def already_completed_game_to_make_ai_move(
8888
/,
8989
) -> None:
9090
await self._logger.ainfo(
91-
"already_completed_game_to_make_move",
91+
"already_completed_game_to_make_ai_move",
9292
ai_id=ai_id.hex,
9393
game_id=game.id.hex,
9494
)

src/ttt/infrastructure/adapters/game_tasks.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from uuid import UUID
33

44
from ttt.application.game.game.ports.game_tasks import GameTasks
5-
from ttt.infrastructure.taskiq.tasks.make_ai_move_in_game_task import (
6-
make_ai_move_in_game_task,
5+
from ttt.infrastructure.remote_funcs.make_ai_move_in_game import (
6+
make_ai_move_in_game_remotely,
77
)
88

99

1010
@dataclass
11-
class TaskiqGameTasks(GameTasks):
11+
class NatsRemoteFuncGameTasks(GameTasks):
1212
async def make_ai_move(
1313
self,
1414
user_id: int,
1515
game_id: UUID,
1616
ai_id: UUID,
1717
/,
1818
) -> None:
19-
await make_ai_move_in_game_task.kiq(user_id, game_id, ai_id)
19+
await make_ai_move_in_game_remotely(
20+
user_id=user_id, game_id=game_id.hex, ai_id=ai_id.hex,
21+
)

src/ttt/infrastructure/adapters/stars_purchase_tasks.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
StarsPurchaseTasks,
66
)
77
from ttt.entities.finance.payment.success import PaymentSuccess
8-
from ttt.infrastructure.taskiq.tasks.complete_stars_purchase_payment_task import ( # noqa: E501
9-
complete_stars_purchase_payment_task,
8+
from ttt.infrastructure.remote_funcs.complete_stars_purchase_payment import (
9+
complete_stars_purchase_payment_remotely,
1010
)
1111

1212

1313
@dataclass
14-
class TaskiqStarsPurchaseTasks(StarsPurchaseTasks):
14+
class NatsRemoteFuncStarsPurchaseTasks(StarsPurchaseTasks):
1515
async def complete_stars_purchase_payment(
1616
self,
1717
purchase_id: UUID,
1818
success: PaymentSuccess,
1919
/,
2020
) -> None:
21-
await complete_stars_purchase_payment_task.kiq(
22-
purchase_id,
23-
success.id,
24-
success.gateway_id,
21+
await complete_stars_purchase_payment_remotely(
22+
purchase_id=purchase_id.hex,
23+
payment_success_id=success.id,
24+
payment_success_gateway_id=success.gateway_id,
2525
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from asyncio import gather
2+
from collections.abc import AsyncIterator
3+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
4+
5+
6+
@asynccontextmanager
7+
async def multi_asynccontextmanager[T](
8+
*managers: AbstractAsyncContextManager[T],
9+
) -> AsyncIterator[list[T]]:
10+
result = await gather(*(manager.__aenter__() for manager in managers)) # noqa: PLC2801
11+
12+
try:
13+
yield result
14+
except BaseException as error: # noqa: BLE001
15+
await gather(*(
16+
manager.__aexit__(type(error), error, error.__traceback__)
17+
for manager in managers
18+
))
19+
else:
20+
await gather(*(
21+
manager.__aexit__(None, None, None)
22+
for manager in managers
23+
))

src/ttt/infrastructure/processors/processors.py

Lines changed: 0 additions & 14 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from uuid import UUID
2+
3+
from dishka import AsyncContainer
4+
5+
from ttt.application.stars_purchase.complete_stars_purchase_payment import (
6+
CompleteStarsPurchasePayment,
7+
)
8+
from ttt.entities.finance.payment.success import PaymentSuccess
9+
from ttt.infrastructure.remote_funcs.nats_remote_func import nats_remote
10+
from ttt.infrastructure.retrier import Retrier
11+
12+
13+
@nats_remote(
14+
subject="stars_purchase.stars_purchase.complete_stars_purchase_payment",
15+
pull_subscribe=lambda js, subject: js.pull_subscribe(
16+
subject,
17+
"ttt-stars_purchase-stars_purchase-complete_stars_purchase_payment",
18+
stream="STARS_PURCHASE",
19+
),
20+
)
21+
async def complete_stars_purchase_payment_remotely(
22+
container: AsyncContainer,
23+
*,
24+
purchase_id: str,
25+
payment_success_id: str,
26+
payment_success_gateway_id: str,
27+
) -> None:
28+
payment_success = PaymentSuccess(
29+
payment_success_id, payment_success_gateway_id,
30+
)
31+
32+
retrier = await container.get(Retrier)
33+
complete_stars_purchase_payment = await container.get(
34+
CompleteStarsPurchasePayment,
35+
)
36+
await retrier(
37+
complete_stars_purchase_payment,
38+
UUID(hex=purchase_id),
39+
payment_success,
40+
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from uuid import UUID
2+
3+
from dishka import AsyncContainer
4+
5+
from ttt.application.game.game.make_ai_move_in_game import MakeAiMoveInGame
6+
from ttt.infrastructure.remote_funcs.nats_remote_func import nats_remote
7+
from ttt.infrastructure.retrier import Retrier
8+
9+
10+
@nats_remote(
11+
subject="game.game.make_ai_move_in_game",
12+
pull_subscribe=lambda js, subject: js.pull_subscribe(
13+
subject,
14+
durable="ttt-game-game-make_ai_move_in_game",
15+
stream="GAME",
16+
),
17+
)
18+
async def make_ai_move_in_game_remotely(
19+
container: AsyncContainer,
20+
*,
21+
user_id: int,
22+
game_id: str,
23+
ai_id: str,
24+
) -> None:
25+
retrier = await container.get(Retrier)
26+
make_ai_move_in_game = await container.get(MakeAiMoveInGame)
27+
28+
await retrier(
29+
make_ai_move_in_game, user_id, UUID(hex=game_id), UUID(hex=ai_id),
30+
)
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import json
2+
from asyncio import (
3+
AbstractEventLoop,
4+
Semaphore,
5+
Task,
6+
gather,
7+
get_event_loop,
8+
)
9+
from collections.abc import (
10+
AsyncIterator,
11+
Callable,
12+
)
13+
from contextlib import asynccontextmanager
14+
from dataclasses import dataclass, field
15+
from typing import Any, Protocol, Self
16+
17+
from dishka import AsyncContainer
18+
from nats.aio.msg import Msg as NatsMessage
19+
from nats.errors import TimeoutError as NatsTimeoutError
20+
from nats.js import JetStreamContext
21+
from structlog.types import FilteringBoundLogger
22+
23+
from ttt.entities.tools.assertion import not_none
24+
from ttt.infrastructure.dishka.next_container import NextContainer
25+
from ttt.infrastructure.structlog.logger import (
26+
unexpected_error_log,
27+
)
28+
29+
30+
class PullSubscribe(Protocol):
31+
async def __call__(
32+
self, js: JetStreamContext, subject: str,
33+
) -> JetStreamContext.PullSubscription: ...
34+
35+
36+
class NatsRemoteFuncBody[**PmT](Protocol):
37+
async def __call__(
38+
self,
39+
container: AsyncContainer,
40+
*args: PmT.args,
41+
**kwargs: PmT.kwargs,
42+
) -> Any: ... # noqa: ANN401
43+
44+
45+
@dataclass
46+
class NatsRemoteFunc[**PmT = ...]:
47+
body: NatsRemoteFuncBody[PmT]
48+
_pull_subscribe: PullSubscribe
49+
_subject: str
50+
51+
_max_workers: int | None = None
52+
_pull_consume_batch: int | None = None
53+
_pull_consume_timeout: float | None = None
54+
55+
_js: JetStreamContext = field(init=False)
56+
_pull_subscription: JetStreamContext.PullSubscription = field(init=False)
57+
_workers: set[Task[None]] = field(init=False, default_factory=set)
58+
_semaphore: Semaphore = field(init=False, default_factory=Semaphore)
59+
_container: AsyncContainer = field(init=False)
60+
_loop: AbstractEventLoop = field(init=False)
61+
62+
@asynccontextmanager
63+
async def startup(
64+
self,
65+
js: JetStreamContext,
66+
max_workers: int | None = None,
67+
pull_consume_batch: int | None = None,
68+
pull_consume_timeout: float | None = None,
69+
) -> AsyncIterator[Self]:
70+
self._loop = get_event_loop()
71+
self._max_workers = max_workers or self._max_workers or 1000
72+
self._pull_consume_batch = (
73+
pull_consume_batch or self._pull_consume_batch or 1
74+
)
75+
self._pull_consume_timeout = (
76+
pull_consume_timeout or self._pull_consume_timeout or 5
77+
)
78+
79+
self._js = js
80+
self._pull_subscription = await self._pull_subscribe(
81+
self._js,
82+
self._subject,
83+
)
84+
self._semaphore = Semaphore(self._max_workers)
85+
86+
try:
87+
yield self
88+
finally:
89+
await gather(*self._workers)
90+
91+
async def __call__(self, *args: PmT.args, **kwargs: PmT.kwargs) -> None:
92+
payload = {"args": args, "kwargs": kwargs}
93+
94+
await not_none(self._js).publish(
95+
self._subject,
96+
payload=json.dumps(payload).encode(),
97+
)
98+
99+
async def processor(self, next_container: NextContainer) -> None:
100+
while True:
101+
try:
102+
messages = await self._pull_subscription.fetch(
103+
batch=not_none(self._pull_consume_batch),
104+
timeout=self._pull_consume_timeout,
105+
)
106+
except NatsTimeoutError:
107+
continue
108+
109+
for message in messages:
110+
await self._create_worker(message, next_container)
111+
112+
async def _create_worker(
113+
self, message: NatsMessage, next_container: NextContainer,
114+
) -> None:
115+
await self._semaphore.acquire()
116+
task = self._loop.create_task(self._worker(message, next_container))
117+
self._workers.add(task)
118+
task.add_done_callback(self._workers.discard)
119+
120+
async def _worker(
121+
self, message: NatsMessage, next_container: NextContainer,
122+
) -> None:
123+
async with next_container() as container:
124+
try:
125+
json_str = message.data.decode()
126+
json_ = json.loads(json_str)
127+
await self.body(container, *json_["args"], **json_["kwargs"])
128+
except Exception as error: # noqa: BLE001
129+
await message.nak()
130+
131+
logger = await container.get(FilteringBoundLogger)
132+
await unexpected_error_log(logger, error)
133+
134+
self._semaphore.release()
135+
except BaseException as error:
136+
await message.nak()
137+
self._semaphore.release()
138+
raise error from error
139+
else:
140+
await message.ack()
141+
self._semaphore.release()
142+
143+
144+
def nats_remote[**PmT](
145+
subject: str,
146+
pull_subscribe: PullSubscribe,
147+
max_workers: int | None = None,
148+
pull_consume_batch: int | None = None,
149+
pull_consume_timeout: float | None = None,
150+
) -> Callable[[NatsRemoteFuncBody[PmT]], NatsRemoteFunc[PmT]]:
151+
def decorator(body: NatsRemoteFuncBody[PmT]) -> NatsRemoteFunc[PmT]:
152+
return NatsRemoteFunc(
153+
body,
154+
pull_subscribe,
155+
subject,
156+
max_workers,
157+
pull_consume_batch,
158+
pull_consume_timeout,
159+
)
160+
161+
return decorator

0 commit comments

Comments
 (0)