Skip to content

Commit a1d4ea8

Browse files
committed
fix(adapters): implement new transactions (#62)
1 parent 0a12c0a commit a1d4ea8

File tree

2 files changed

+84
-25
lines changed

2 files changed

+84
-25
lines changed
Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,56 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass
22
from types import TracebackType
33
from typing import Self
44

5-
from sqlalchemy.ext.asyncio import AsyncSession, AsyncSessionTransaction
5+
from sqlalchemy.ext.asyncio import AsyncSession
66

7-
from ttt.application.common.ports.transaction import Transaction
8-
from ttt.entities.tools.assertion import not_none
7+
from ttt.application.common.ports.transaction import (
8+
NotSerializableTransaction,
9+
ReadonlyTransaction,
10+
SerializableTransaction,
11+
)
12+
from ttt.entities.tools.assertion import assert_, not_none
13+
from ttt.infrastructure.sqlalchemy.serialization import (
14+
reraise_serialization_error,
15+
)
916

1017

1118
@dataclass
12-
class InPostgresTransaction(Transaction):
19+
class InPostgresSerializableTransaction(SerializableTransaction):
1320
_session: AsyncSession
14-
_transaction: AsyncSessionTransaction | None = field(
15-
init=False,
16-
default=None,
17-
)
18-
_nesting_counter: int = field(
19-
init=False,
20-
default=0,
21-
)
2221

2322
async def __aenter__(self) -> Self:
24-
self._nesting_counter += 1
23+
assert_(not self._session.in_transaction())
24+
await self._session.connection(
25+
execution_options={"isolation_level": "SERIALIZABLE"},
26+
)
27+
return self
2528

26-
if self._transaction is None:
27-
self._transaction = await self._session.begin()
28-
elif not self._transaction.is_active:
29-
await self._session.rollback()
30-
self._transaction = await self._session.begin()
29+
async def __aexit__(
30+
self,
31+
error_type: type[BaseException] | None,
32+
error: BaseException | None,
33+
traceback: TracebackType | None,
34+
) -> None:
35+
transaction = not_none(self._session.get_transaction())
36+
with reraise_serialization_error():
37+
await transaction.__aexit__(error_type, error, traceback)
3138

39+
async def commit(self) -> None:
40+
transaction = not_none(self._session.get_transaction())
41+
with reraise_serialization_error():
42+
await transaction.commit()
43+
44+
45+
@dataclass
46+
class InPostgresNotSerializableTransaction(NotSerializableTransaction):
47+
_session: AsyncSession
48+
49+
async def __aenter__(self) -> Self:
50+
assert_(not self._session.in_transaction())
51+
await self._session.connection(
52+
execution_options={"isolation_level": "READ COMMITED"},
53+
)
3254
return self
3355

3456
async def __aexit__(
@@ -37,10 +59,29 @@ async def __aexit__(
3759
error: BaseException | None,
3860
traceback: TracebackType | None,
3961
) -> None:
40-
self._nesting_counter -= 1
62+
transaction = not_none(self._session.get_transaction())
63+
await transaction.__aexit__(error_type, error, traceback)
4164

42-
if self._nesting_counter == 0:
43-
transaction = not_none(self._transaction)
44-
await transaction.__aexit__(error_type, error, traceback)
45-
self._transaction = None
46-
return
65+
async def commit(self) -> None:
66+
transaction = not_none(self._session.get_transaction())
67+
await transaction.commit()
68+
69+
70+
@dataclass
71+
class InPostgresReadonlyTransaction(ReadonlyTransaction):
72+
_session: AsyncSession
73+
74+
async def __aenter__(self) -> Self:
75+
assert_(not self._session.in_transaction())
76+
options = {"isolation_level": "SERIALIZABLE", "readonly": True}
77+
await self._session.connection(execution_options=options)
78+
return self
79+
80+
async def __aexit__(
81+
self,
82+
error_type: type[BaseException] | None,
83+
error: BaseException | None,
84+
traceback: TracebackType | None,
85+
) -> None:
86+
transaction = not_none(self._session.get_transaction())
87+
await transaction.__aexit__(error_type, error, traceback)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from collections.abc import Iterator
2+
from contextlib import contextmanager
3+
4+
from psycopg.errors import SerializationFailure
5+
from sqlalchemy.exc import OperationalError
6+
7+
from ttt.application.common.errors.serialization_error import SerializationError
8+
9+
10+
@contextmanager
11+
def reraise_serialization_error() -> Iterator[None]:
12+
try:
13+
yield
14+
except OperationalError as error:
15+
if isinstance(error.orig, SerializationFailure):
16+
raise SerializationError from error
17+
18+
raise error from error

0 commit comments

Comments
 (0)