11from collections import namedtuple
22from typing import Any , Dict , List , Optional , Tuple , Type
33
4+ from httpx import Headers , Request
5+
46from firebolt .client .auth .base import Auth
57from firebolt .common ._types import ColType
8+ from firebolt .common .constants import (
9+ REMOVE_PARAMETERS_HEADER ,
10+ RESET_SESSION_HEADER ,
11+ TRANSACTION_ID_SETTING ,
12+ TRANSACTION_SEQUENCE_ID_SETTING ,
13+ UPDATE_PARAMETERS_HEADER ,
14+ )
615from firebolt .utils .cache import (
716 ConnectionInfo ,
817 EngineInfo ,
918 SecureCacheKey ,
1019 _firebolt_cache ,
1120)
12- from firebolt .utils .exception import ConnectionClosedError , FireboltError
21+ from firebolt .utils .exception import FireboltError
1322from firebolt .utils .usage_tracker import (
1423 get_cache_tracking_params ,
1524 get_user_agent_header ,
1625)
26+ from firebolt .utils .util import (
27+ _parse_remove_parameters ,
28+ _parse_update_parameters ,
29+ )
1730
1831ASYNC_QUERY_STATUS_RUNNING = "RUNNING"
1932ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY"
@@ -68,6 +81,8 @@ def __init__(self, cursor_type: Type) -> None:
6881 self .cursor_type = cursor_type
6982 self ._cursors : List [Any ] = []
7083 self ._is_closed = False
84+ self ._transaction_id : Optional [str ] = None
85+ self ._transaction_sequence_id : Optional [str ] = None
7186
7287 def _remove_cursor (self , cursor : Any ) -> None :
7388 # This way it's atomic
@@ -76,17 +91,62 @@ def _remove_cursor(self, cursor: Any) -> None:
7691 except ValueError :
7792 pass
7893
94+ def in_transaction (self ) -> bool :
95+ """`True` if connection is in a transaction; `False` otherwise."""
96+ return self ._transaction_id is not None
97+
98+ def _parse_response_headers_transaction (self , headers : Headers ) -> None :
99+ parameters_header = headers .get (UPDATE_PARAMETERS_HEADER )
100+ if not parameters_header :
101+ return
102+ parameters = _parse_update_parameters (parameters_header )
103+ transaction_id = parameters .get (TRANSACTION_ID_SETTING )
104+ if transaction_id :
105+ self ._transaction_id = transaction_id
106+ sequence_id = parameters .get (TRANSACTION_SEQUENCE_ID_SETTING )
107+ if sequence_id :
108+ self ._transaction_sequence_id = sequence_id
109+
110+ def _parse_remove_headers_transaction (self , headers : Headers ) -> None :
111+ parameters_header = headers .get (REMOVE_PARAMETERS_HEADER )
112+ if not parameters_header :
113+ return
114+ parameters = _parse_remove_parameters (parameters_header )
115+ for param in parameters :
116+ if param == TRANSACTION_ID_SETTING :
117+ self ._transaction_id = None
118+ elif param == TRANSACTION_SEQUENCE_ID_SETTING :
119+ self ._transaction_sequence_id = None
120+
121+ def _reset_transaction_state (self ) -> None :
122+ self ._transaction_id = None
123+ self ._transaction_sequence_id = None
124+
125+ def create_transaction_params (self ) -> Dict [str , str ]:
126+ params : Dict [str , str ] = {}
127+ if self ._transaction_id :
128+ params [TRANSACTION_ID_SETTING ] = self ._transaction_id
129+ if self ._transaction_sequence_id is not None :
130+ params [TRANSACTION_SEQUENCE_ID_SETTING ] = str (self ._transaction_sequence_id )
131+ return params
132+
133+ def _add_transaction_headers (self , request : Request ) -> None :
134+ transaction_params = self .create_transaction_params ()
135+ for key , value in transaction_params .items ():
136+ request .headers [key ] = value
137+
138+ def _handle_transaction_updates (self , headers : Headers ) -> None :
139+ self ._parse_response_headers_transaction (headers )
140+ if headers .get (RESET_SESSION_HEADER ):
141+ self ._reset_transaction_state ()
142+ if headers .get (REMOVE_PARAMETERS_HEADER ):
143+ self ._parse_remove_headers_transaction (headers )
144+
79145 @property
80146 def closed (self ) -> bool :
81147 """`True` if connection is closed; `False` otherwise."""
82148 return self ._is_closed
83149
84- def commit (self ) -> None :
85- """Does nothing since Firebolt doesn't have transactions."""
86-
87- if self .closed :
88- raise ConnectionClosedError ("Unable to commit: Connection closed." )
89-
90150
91151def get_cached_system_engine_info (
92152 auth : Auth ,
0 commit comments