Skip to content

Commit b14caaf

Browse files
committed
Add context management and improve tests
1 parent b586345 commit b14caaf

File tree

9 files changed

+717
-322
lines changed

9 files changed

+717
-322
lines changed

src/firebolt/async_db/connection.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
from ssl import SSLContext
45
from types import TracebackType
56
from typing import Any, Dict, List, Optional, Type, Union
@@ -46,6 +47,8 @@
4647
validate_engine_name_and_url_v1,
4748
)
4849

50+
logger = logging.getLogger(__name__)
51+
4952

5053
class Connection(BaseConnection):
5154
"""
@@ -243,8 +246,13 @@ async def aclose(self) -> None:
243246
if self.closed:
244247
return
245248

246-
if self.in_transaction:
247-
await self.cursor().execute("ROLLBACK")
249+
# Only rollback if we have a transaction and autocommit is off
250+
if self.in_transaction and not self.autocommit:
251+
try:
252+
await self.cursor().execute("ROLLBACK")
253+
except Exception:
254+
# If rollback fails during close, continue closing
255+
logger.warning("Rollback failed during close")
248256

249257
# self._cursors is going to be changed during closing cursors
250258
# after this point no cursors would be added to _cursors, only removed since
@@ -260,6 +268,10 @@ async def aclose(self) -> None:
260268
async def __aexit__(
261269
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
262270
) -> None:
271+
# If exiting normally (no exception) and we have a transaction with
272+
# autocommit=False, commit the transaction before closing
273+
if exc_type is None and not self.autocommit and self.in_transaction:
274+
await self.commit()
263275
await self.aclose()
264276

265277

src/firebolt/db/connection.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,13 @@ def close(self) -> None:
269269
if self.closed:
270270
return
271271

272-
if self.in_transaction:
273-
self.cursor().execute("ROLLBACK")
272+
# Only rollback if we have a transaction and autocommit is off
273+
if self.in_transaction and not self.autocommit:
274+
try:
275+
self.cursor().execute("ROLLBACK")
276+
except Exception:
277+
# If rollback fails during close, continue closing
278+
logger.warning("Rollback failed during close")
274279

275280
cursors = self._cursors[:]
276281
for c in cursors:
@@ -402,6 +407,10 @@ def __enter__(self) -> Connection:
402407
def __exit__(
403408
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
404409
) -> None:
410+
# If exiting normally (no exception) and we have a transaction with
411+
# autocommit=False, commit the transaction before closing
412+
if exc_type is None and not self.autocommit and self.in_transaction:
413+
self.commit()
405414
self.close()
406415

407416
def __del__(self) -> None:

tests/integration/dbapi/async/V2/conftest.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,54 +11,41 @@
1111
from tests.integration.conftest import Secret
1212

1313

14-
@fixture(params=["remote", "core"])
14+
@fixture
1515
async def connection(
16-
engine_name: str,
17-
database_name: str,
18-
auth: Auth,
19-
core_auth: Auth,
20-
account_name: str,
21-
api_endpoint: str,
22-
core_url: str,
23-
request: Any,
16+
connection_factory: Callable[..., Connection],
2417
) -> Connection:
25-
if request.param == "core":
26-
kwargs = {
27-
"database": "firebolt",
28-
"auth": core_auth,
29-
"url": core_url,
30-
}
31-
else:
32-
kwargs = {
33-
"engine_name": engine_name,
34-
"database": database_name,
35-
"auth": auth,
36-
"account_name": account_name,
37-
"api_endpoint": api_endpoint,
38-
}
39-
async with await connect(
40-
**kwargs,
41-
) as connection:
18+
async with await connection_factory() as connection:
4219
yield connection
4320

4421

45-
@fixture
22+
@fixture(params=["remote", "core"])
4623
async def connection_factory(
4724
engine_name: str,
4825
database_name: str,
4926
auth: Auth,
27+
core_auth: Auth,
5028
account_name: str,
5129
api_endpoint: str,
30+
core_url: str,
31+
request: Any,
5232
) -> Callable[..., Connection]:
5333
async def factory(**kwargs: Any) -> Connection:
54-
return await connect(
55-
engine_name=engine_name,
56-
database=database_name,
57-
auth=auth,
58-
account_name=account_name,
59-
api_endpoint=api_endpoint,
60-
**kwargs,
61-
)
34+
if request.param == "core":
35+
base_kwargs = {
36+
"database": "firebolt",
37+
"auth": core_auth,
38+
"url": core_url,
39+
}
40+
else:
41+
base_kwargs = {
42+
"engine_name": engine_name,
43+
"database": database_name,
44+
"auth": auth,
45+
"account_name": account_name,
46+
"api_endpoint": api_endpoint,
47+
}
48+
return await connect(**base_kwargs, **kwargs)
6249

6350
return factory
6451

0 commit comments

Comments
 (0)