Skip to content

Commit ddcf11d

Browse files
committed
autocommit
1 parent 38b13cf commit ddcf11d

File tree

6 files changed

+68
-10
lines changed

6 files changed

+68
-10
lines changed

src/firebolt/async_db/connection.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Connection(BaseConnection):
8585
"client_class",
8686
"cursor_type",
8787
"id",
88+
"_autocommit",
8889
)
8990

9091
def __init__(
@@ -96,6 +97,7 @@ def __init__(
9697
api_endpoint: str,
9798
init_parameters: Optional[Dict[str, Any]] = None,
9899
id: str = uuid4().hex,
100+
autocommit: bool = True,
99101
):
100102
super().__init__(cursor_type)
101103
self.api_endpoint = api_endpoint
@@ -105,6 +107,7 @@ def __init__(
105107
self.id = id
106108
self._transaction_lock: trio.Lock = trio.Lock()
107109
self.init_parameters = init_parameters or {}
110+
self._autocommit = autocommit
108111
if database:
109112
self.init_parameters["database"] = database
110113

@@ -198,14 +201,26 @@ async def cancel_async_query(self, token: str) -> None:
198201
await cursor.execute(ASYNC_QUERY_CANCEL, [async_query_info[0].query_id])
199202

200203
async def _execute_query_impl(self, request: Request) -> Response:
201-
self._add_transaction_headers(request)
204+
self._add_transaction_params(request)
202205
response = await self._client.send(request, stream=True)
203206
self._handle_transaction_updates(response.headers)
204207
return response
205208

209+
async def _begin_nolock(self, request: Request) -> None:
210+
"""Begin a transaction without a lock. Used internally."""
211+
# Create a copy of the request with "BEGIN" as the body content
212+
begin_request = self._client.build_request(
213+
request.method, request.url, content="BEGIN"
214+
)
215+
response = await self._client.send(begin_request, stream=True)
216+
self._handle_transaction_updates(response.headers)
217+
206218
async def _execute_query(self, request: Request) -> Response:
207-
if self.in_transaction():
219+
if self.in_transaction or not self.autocommit:
208220
async with self._transaction_lock:
221+
# If autocommit is off we need to explicitly begin a transaction
222+
if not self.in_transaction:
223+
await self._begin_nolock(request)
209224
return await self._execute_query_impl(request)
210225
else:
211226
return await self._execute_query_impl(request)
@@ -227,6 +242,9 @@ async def aclose(self) -> None:
227242
if self.closed:
228243
return
229244

245+
if self.in_transaction:
246+
await self.cursor().execute("ROLLBACK")
247+
230248
# self._cursors is going to be changed during closing cursors
231249
# after this point no cursors would be added to _cursors, only removed since
232250
# closing lock is held, and later connection will be marked as closed
@@ -253,6 +271,7 @@ async def connect(
253271
api_endpoint: str = DEFAULT_API_URL,
254272
disable_cache: bool = False,
255273
url: Optional[str] = None,
274+
autocommit: bool = True,
256275
additional_parameters: Dict[str, Any] = {},
257276
) -> Connection:
258277
# auth parameter is optional in function signature
@@ -280,6 +299,7 @@ async def connect(
280299
user_agent_header=user_agent_header,
281300
database=database,
282301
connection_url=url,
302+
autocommit=autocommit,
283303
)
284304
elif auth_version == FireboltAuthVersion.V2:
285305
assert account_name is not None
@@ -292,6 +312,7 @@ async def connect(
292312
api_endpoint=api_endpoint,
293313
connection_id=connection_id,
294314
disable_cache=disable_cache,
315+
autocommit=autocommit,
295316
)
296317
elif auth_version == FireboltAuthVersion.V1:
297318
return await connect_v1(
@@ -317,6 +338,7 @@ async def connect_v2(
317338
engine_name: Optional[str] = None,
318339
api_endpoint: str = DEFAULT_API_URL,
319340
disable_cache: bool = False,
341+
autocommit: bool = True,
320342
) -> Connection:
321343
"""Connect to Firebolt.
322344
@@ -380,6 +402,7 @@ async def connect_v2(
380402
api_endpoint,
381403
cursor.parameters | cursor._set_parameters,
382404
connection_id,
405+
autocommit,
383406
)
384407

385408

@@ -447,6 +470,7 @@ def connect_core(
447470
user_agent_header: str,
448471
database: Optional[str] = None,
449472
connection_url: Optional[str] = None,
473+
autocommit: bool = True,
450474
) -> Connection:
451475
"""Connect to Firebolt Core.
452476
@@ -484,6 +508,7 @@ def connect_core(
484508
client=client,
485509
cursor_type=CursorV2,
486510
api_endpoint=verified_url,
511+
autocommit=autocommit,
487512
)
488513

489514

src/firebolt/common/base_connection.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self, cursor_type: Type) -> None:
8383
self._is_closed = False
8484
self._transaction_id: Optional[str] = None
8585
self._transaction_sequence_id: Optional[str] = None
86+
self._autocommit: bool = True
8687

8788
def _remove_cursor(self, cursor: Any) -> None:
8889
# This way it's atomic
@@ -91,6 +92,7 @@ def _remove_cursor(self, cursor: Any) -> None:
9192
except ValueError:
9293
pass
9394

95+
@property
9496
def in_transaction(self) -> bool:
9597
"""`True` if connection is in a transaction; `False` otherwise."""
9698
return self._transaction_id is not None
@@ -130,10 +132,10 @@ def create_transaction_params(self) -> Dict[str, str]:
130132
params[TRANSACTION_SEQUENCE_ID_SETTING] = str(self._transaction_sequence_id)
131133
return params
132134

133-
def _add_transaction_headers(self, request: Request) -> None:
135+
def _add_transaction_params(self, request: Request) -> None:
134136
transaction_params = self.create_transaction_params()
135137
for key, value in transaction_params.items():
136-
request.headers[key] = value
138+
request.url = request.url.copy_add_param(key, value)
137139

138140
def _handle_transaction_updates(self, headers: Headers) -> None:
139141
self._parse_response_headers_transaction(headers)
@@ -142,6 +144,13 @@ def _handle_transaction_updates(self, headers: Headers) -> None:
142144
if headers.get(REMOVE_PARAMETERS_HEADER):
143145
self._parse_remove_headers_transaction(headers)
144146

147+
@property
148+
def autocommit(self) -> bool:
149+
"""
150+
`True` if connection is in autocommit mode; `False` otherwise.
151+
"""
152+
return self._autocommit
153+
145154
@property
146155
def closed(self) -> bool:
147156
"""`True` if connection is closed; `False` otherwise."""

src/firebolt/common/constants.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ class ParameterStyle(Enum):
3535
# Connection level transaction management
3636
TRANSACTION_PARAMETER_LIST = [TRANSACTION_ID_SETTING, TRANSACTION_SEQUENCE_ID_SETTING]
3737
# parameters that are set by the backend and should not be set by the user
38-
IMMUTABLE_PARAMETER_LIST = (
39-
USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST + TRANSACTION_PARAMETER_LIST
40-
)
38+
IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST
4139
UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"
4240
UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters"
4341
RESET_SESSION_HEADER = "Firebolt-Reset-Session"

src/firebolt/common/cursor/base_cursor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
IMMUTABLE_PARAMETER_LIST,
1616
JSON_LINES_OUTPUT_FORMAT,
1717
JSON_OUTPUT_FORMAT,
18+
TRANSACTION_PARAMETER_LIST,
1819
USE_PARAMETER_LIST,
1920
CursorState,
2021
)
@@ -207,7 +208,7 @@ def _update_set_parameters(self, parameters: Dict[str, Any]) -> None:
207208
user_parameters = {
208209
key: value
209210
for key, value in parameters.items()
210-
if key not in IMMUTABLE_PARAMETER_LIST
211+
if key not in IMMUTABLE_PARAMETER_LIST + TRANSACTION_PARAMETER_LIST
211212
}
212213

213214
self.parameters.update(immutable_parameters)

src/firebolt/db/connection.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def connect(
5959
api_endpoint: str = DEFAULT_API_URL,
6060
disable_cache: bool = False,
6161
url: Optional[str] = None,
62+
autocommit: bool = True,
6263
additional_parameters: Dict[str, Any] = {},
6364
) -> Connection:
6465
# auth parameter is optional in function signature
@@ -87,6 +88,7 @@ def connect(
8788
user_agent_header=user_agent_header,
8889
database=database,
8990
connection_url=url,
91+
autocommit=autocommit,
9092
)
9193
elif auth_version == FireboltAuthVersion.V2:
9294
assert account_name is not None
@@ -99,6 +101,7 @@ def connect(
99101
api_endpoint=api_endpoint,
100102
connection_id=connection_id,
101103
disable_cache=disable_cache,
104+
autocommit=autocommit,
102105
)
103106
elif auth_version == FireboltAuthVersion.V1:
104107
return connect_v1(
@@ -124,6 +127,7 @@ def connect_v2(
124127
engine_name: Optional[str] = None,
125128
api_endpoint: str = DEFAULT_API_URL,
126129
disable_cache: bool = False,
130+
autocommit: bool = True,
127131
) -> Connection:
128132
"""Connect to Firebolt.
129133
@@ -186,6 +190,7 @@ def connect_v2(
186190
api_endpoint,
187191
cursor.parameters | cursor._set_parameters,
188192
connection_id,
193+
autocommit,
189194
)
190195

191196

@@ -220,6 +225,7 @@ class Connection(BaseConnection):
220225
"client_class",
221226
"cursor_type",
222227
"id",
228+
"_autocommit",
223229
)
224230

225231
def __init__(
@@ -231,6 +237,7 @@ def __init__(
231237
api_endpoint: str = DEFAULT_API_URL,
232238
init_parameters: Optional[Dict[str, Any]] = None,
233239
id: str = uuid4().hex,
240+
autocommit: bool = True,
234241
):
235242
super().__init__(cursor_type)
236243
self.api_endpoint = api_endpoint
@@ -240,6 +247,7 @@ def __init__(
240247
self.id = id
241248
self._transaction_lock: threading.Lock = threading.Lock()
242249
self.init_parameters = init_parameters or {}
250+
self._autocommit = autocommit
243251
if database:
244252
self.init_parameters["database"] = database
245253

@@ -261,6 +269,9 @@ def close(self) -> None:
261269
if self.closed:
262270
return
263271

272+
if self.in_transaction:
273+
self.cursor().execute("ROLLBACK")
274+
264275
cursors = self._cursors[:]
265276
for c in cursors:
266277
# Here c can already be closed by another thread,
@@ -270,14 +281,25 @@ def close(self) -> None:
270281
self._is_closed = True
271282

272283
def _execute_query_impl(self, request: Request) -> Response:
273-
self._add_transaction_headers(request)
284+
self._add_transaction_params(request)
274285
response = self._client.send(request, stream=True)
275286
self._handle_transaction_updates(response.headers)
276287
return response
277288

289+
def _begin_nolock(self, request: Request) -> None:
290+
"""Begin a transaction without a lock. Used internally."""
291+
begin_request = self._client.build_request(
292+
request.method, request.url, content="BEGIN"
293+
)
294+
response = self._client.send(begin_request, stream=True)
295+
self._handle_transaction_updates(response.headers)
296+
278297
def _execute_query(self, request: Request) -> Response:
279-
if self.in_transaction():
298+
if self.in_transaction or not self.autocommit:
280299
with self._transaction_lock:
300+
# If autocommit is off we need to explicitly begin a transaction
301+
if not self.in_transaction:
302+
self._begin_nolock(request)
281303
return self._execute_query_impl(request)
282304
else:
283305
return self._execute_query_impl(request)
@@ -451,6 +473,7 @@ def connect_core(
451473
user_agent_header: str,
452474
database: Optional[str] = None,
453475
connection_url: Optional[str] = None,
476+
autocommit: bool = True,
454477
) -> Connection:
455478
"""Connect to Firebolt Core.
456479
@@ -489,6 +512,7 @@ def connect_core(
489512
client=client,
490513
cursor_type=CursorV2,
491514
api_endpoint=verified_url,
515+
autocommit=autocommit,
492516
)
493517

494518

src/firebolt/db/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def _validate_set_parameter(
191191
self._append_row_set_from_response(None)
192192

193193
def _parse_response_headers(self, headers: Headers) -> None:
194+
# TODO: merge with async in base_cursor
194195
if headers.get(UPDATE_ENDPOINT_HEADER):
195196
endpoint, params = _parse_update_endpoint(
196197
headers.get(UPDATE_ENDPOINT_HEADER)

0 commit comments

Comments
 (0)