@@ -85,6 +85,7 @@ class Connection(BaseConnection):
8585 "client_class" ,
8686 "cursor_type" ,
8787 "id" ,
88+ "_autocommit" ,
8889 )
8990
9091 def __init__ (
@@ -96,6 +97,7 @@ def __init__(
9697 api_endpoint : str ,
9798 init_parameters : Optional [Dict [str , Any ]] = None ,
9899 id : str = uuid4 ().hex ,
100+ autocommit : bool = True ,
99101 ):
100102 super ().__init__ (cursor_type )
101103 self .api_endpoint = api_endpoint
@@ -105,6 +107,7 @@ def __init__(
105107 self .id = id
106108 self ._transaction_lock : trio .Lock = trio .Lock ()
107109 self .init_parameters = init_parameters or {}
110+ self ._autocommit = autocommit
108111 if database :
109112 self .init_parameters ["database" ] = database
110113
@@ -198,14 +201,26 @@ async def cancel_async_query(self, token: str) -> None:
198201 await cursor .execute (ASYNC_QUERY_CANCEL , [async_query_info [0 ].query_id ])
199202
200203 async def _execute_query_impl (self , request : Request ) -> Response :
201- self ._add_transaction_headers (request )
204+ self ._add_transaction_params (request )
202205 response = await self ._client .send (request , stream = True )
203206 self ._handle_transaction_updates (response .headers )
204207 return response
205208
209+ async def _begin_nolock (self , request : Request ) -> None :
210+ """Begin a transaction without a lock. Used internally."""
211+ # Create a copy of the request with "BEGIN" as the body content
212+ begin_request = self ._client .build_request (
213+ request .method , request .url , content = "BEGIN"
214+ )
215+ response = await self ._client .send (begin_request , stream = True )
216+ self ._handle_transaction_updates (response .headers )
217+
206218 async def _execute_query (self , request : Request ) -> Response :
207- if self .in_transaction () :
219+ if self .in_transaction or not self . autocommit :
208220 async with self ._transaction_lock :
221+ # If autocommit is off we need to explicitly begin a transaction
222+ if not self .in_transaction :
223+ await self ._begin_nolock (request )
209224 return await self ._execute_query_impl (request )
210225 else :
211226 return await self ._execute_query_impl (request )
@@ -227,6 +242,9 @@ async def aclose(self) -> None:
227242 if self .closed :
228243 return
229244
245+ if self .in_transaction :
246+ await self .cursor ().execute ("ROLLBACK" )
247+
230248 # self._cursors is going to be changed during closing cursors
231249 # after this point no cursors would be added to _cursors, only removed since
232250 # closing lock is held, and later connection will be marked as closed
@@ -253,6 +271,7 @@ async def connect(
253271 api_endpoint : str = DEFAULT_API_URL ,
254272 disable_cache : bool = False ,
255273 url : Optional [str ] = None ,
274+ autocommit : bool = True ,
256275 additional_parameters : Dict [str , Any ] = {},
257276) -> Connection :
258277 # auth parameter is optional in function signature
@@ -280,6 +299,7 @@ async def connect(
280299 user_agent_header = user_agent_header ,
281300 database = database ,
282301 connection_url = url ,
302+ autocommit = autocommit ,
283303 )
284304 elif auth_version == FireboltAuthVersion .V2 :
285305 assert account_name is not None
@@ -292,6 +312,7 @@ async def connect(
292312 api_endpoint = api_endpoint ,
293313 connection_id = connection_id ,
294314 disable_cache = disable_cache ,
315+ autocommit = autocommit ,
295316 )
296317 elif auth_version == FireboltAuthVersion .V1 :
297318 return await connect_v1 (
@@ -317,6 +338,7 @@ async def connect_v2(
317338 engine_name : Optional [str ] = None ,
318339 api_endpoint : str = DEFAULT_API_URL ,
319340 disable_cache : bool = False ,
341+ autocommit : bool = True ,
320342) -> Connection :
321343 """Connect to Firebolt.
322344
@@ -380,6 +402,7 @@ async def connect_v2(
380402 api_endpoint ,
381403 cursor .parameters | cursor ._set_parameters ,
382404 connection_id ,
405+ autocommit ,
383406 )
384407
385408
@@ -447,6 +470,7 @@ def connect_core(
447470 user_agent_header : str ,
448471 database : Optional [str ] = None ,
449472 connection_url : Optional [str ] = None ,
473+ autocommit : bool = True ,
450474) -> Connection :
451475 """Connect to Firebolt Core.
452476
@@ -484,6 +508,7 @@ def connect_core(
484508 client = client ,
485509 cursor_type = CursorV2 ,
486510 api_endpoint = verified_url ,
511+ autocommit = autocommit ,
487512 )
488513
489514
0 commit comments