1- from dataclasses import dataclass , field
1+ from dataclasses import dataclass
22from types import TracebackType
33from typing import Self
44
5- from sqlalchemy .ext .asyncio import AsyncSession , AsyncSessionTransaction
5+ from sqlalchemy .ext .asyncio import AsyncSession
66
7- from ttt .application .common .ports .transaction import Transaction
8- from ttt .entities .tools .assertion import not_none
7+ from ttt .application .common .ports .transaction import (
8+ NotSerializableTransaction ,
9+ ReadonlyTransaction ,
10+ SerializableTransaction ,
11+ )
12+ from ttt .entities .tools .assertion import assert_ , not_none
13+ from ttt .infrastructure .sqlalchemy .serialization import (
14+ reraise_serialization_error ,
15+ )
916
1017
1118@dataclass
12- class InPostgresTransaction ( Transaction ):
19+ class InPostgresSerializableTransaction ( SerializableTransaction ):
1320 _session : AsyncSession
14- _transaction : AsyncSessionTransaction | None = field (
15- init = False ,
16- default = None ,
17- )
18- _nesting_counter : int = field (
19- init = False ,
20- default = 0 ,
21- )
2221
2322 async def __aenter__ (self ) -> Self :
24- self ._nesting_counter += 1
23+ assert_ (not self ._session .in_transaction ())
24+ await self ._session .connection (
25+ execution_options = {"isolation_level" : "SERIALIZABLE" },
26+ )
27+ return self
2528
26- if self ._transaction is None :
27- self ._transaction = await self ._session .begin ()
28- elif not self ._transaction .is_active :
29- await self ._session .rollback ()
30- self ._transaction = await self ._session .begin ()
29+ async def __aexit__ (
30+ self ,
31+ error_type : type [BaseException ] | None ,
32+ error : BaseException | None ,
33+ traceback : TracebackType | None ,
34+ ) -> None :
35+ transaction = not_none (self ._session .get_transaction ())
36+ with reraise_serialization_error ():
37+ await transaction .__aexit__ (error_type , error , traceback )
3138
39+ async def commit (self ) -> None :
40+ transaction = not_none (self ._session .get_transaction ())
41+ with reraise_serialization_error ():
42+ await transaction .commit ()
43+
44+
45+ @dataclass
46+ class InPostgresNotSerializableTransaction (NotSerializableTransaction ):
47+ _session : AsyncSession
48+
49+ async def __aenter__ (self ) -> Self :
50+ assert_ (not self ._session .in_transaction ())
51+ await self ._session .connection (
52+ execution_options = {"isolation_level" : "READ COMMITED" },
53+ )
3254 return self
3355
3456 async def __aexit__ (
@@ -37,10 +59,29 @@ async def __aexit__(
3759 error : BaseException | None ,
3860 traceback : TracebackType | None ,
3961 ) -> None :
40- self ._nesting_counter -= 1
62+ transaction = not_none (self ._session .get_transaction ())
63+ await transaction .__aexit__ (error_type , error , traceback )
4164
42- if self ._nesting_counter == 0 :
43- transaction = not_none (self ._transaction )
44- await transaction .__aexit__ (error_type , error , traceback )
45- self ._transaction = None
46- return
65+ async def commit (self ) -> None :
66+ transaction = not_none (self ._session .get_transaction ())
67+ await transaction .commit ()
68+
69+
70+ @dataclass
71+ class InPostgresReadonlyTransaction (ReadonlyTransaction ):
72+ _session : AsyncSession
73+
74+ async def __aenter__ (self ) -> Self :
75+ assert_ (not self ._session .in_transaction ())
76+ options = {"isolation_level" : "SERIALIZABLE" , "readonly" : True }
77+ await self ._session .connection (execution_options = options )
78+ return self
79+
80+ async def __aexit__ (
81+ self ,
82+ error_type : type [BaseException ] | None ,
83+ error : BaseException | None ,
84+ traceback : TracebackType | None ,
85+ ) -> None :
86+ transaction = not_none (self ._session .get_transaction ())
87+ await transaction .__aexit__ (error_type , error , traceback )
0 commit comments