Skip to content

Commit 4ce646a

Browse files
committed
add unit tests
1 parent b14caaf commit 4ce646a

File tree

9 files changed

+1210
-23
lines changed

9 files changed

+1210
-23
lines changed

src/firebolt/async_db/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ async def aclose(self) -> None:
249249
# Only rollback if we have a transaction and autocommit is off
250250
if self.in_transaction and not self.autocommit:
251251
try:
252-
await self.cursor().execute("ROLLBACK")
252+
await self.rollback()
253253
except Exception:
254254
# If rollback fails during close, continue closing
255255
logger.warning("Rollback failed during close")

src/firebolt/db/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def close(self) -> None:
272272
# Only rollback if we have a transaction and autocommit is off
273273
if self.in_transaction and not self.autocommit:
274274
try:
275-
self.cursor().execute("ROLLBACK")
275+
self.rollback()
276276
except Exception:
277277
# If rollback fails during close, continue closing
278278
logger.warning("Rollback failed during close")

tests/unit/async_db/conftest.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import AsyncGenerator
2+
13
from pytest import fixture
24

35
import firebolt.async_db
@@ -14,7 +16,7 @@ async def connection(
1416
engine_name: str,
1517
account_name: str,
1618
mock_connection_flow: Callable,
17-
) -> Connection:
19+
) -> AsyncGenerator[Connection, None]:
1820
mock_connection_flow()
1921
async with (
2022
await connect(
@@ -35,6 +37,32 @@ async def cursor(connection: Connection) -> Cursor:
3537
return connection.cursor()
3638

3739

40+
@fixture
41+
async def connection_autocommit_off(
42+
api_endpoint: str,
43+
db_name: str,
44+
auth: Auth,
45+
engine_name: str,
46+
account_name: str,
47+
mock_connection_flow: Callable,
48+
) -> AsyncGenerator[Connection, None]:
49+
"""Connection fixture with autocommit=False for transaction testing."""
50+
mock_connection_flow()
51+
async with (
52+
await connect(
53+
engine_name=engine_name,
54+
database=db_name,
55+
auth=auth,
56+
account_name=account_name,
57+
api_endpoint=api_endpoint,
58+
autocommit=False,
59+
)
60+
) as connection:
61+
# cache account_id for tests
62+
await connection._client.account_id
63+
yield connection
64+
65+
3866
@fixture
3967
def fb_numeric_paramstyle():
4068
"""Fixture that sets paramstyle to fb_numeric and resets it after the test."""

tests/unit/async_db/test_connection.py

Lines changed: 287 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -286,15 +286,6 @@ async def test_connect_system_engine_404(
286286
await connection.cursor().execute("select*")
287287

288288

289-
async def test_connection_commit(connection: Connection):
290-
# nothing happens
291-
connection.commit()
292-
293-
await connection.aclose()
294-
with raises(ConnectionClosedError):
295-
connection.commit()
296-
297-
298289
@mark.nofakefs
299290
async def test_connection_token_caching(
300291
db_name: str,
@@ -817,3 +808,290 @@ async def test_use_engine_update_parameters_propagation(
817808
for param_name, expected_value in test_update_parameters.items():
818809
assert param_name in cursor._set_parameters
819810
assert cursor._set_parameters[param_name] == expected_value
811+
812+
813+
# Transaction tests
814+
815+
816+
async def test_connection_commit(
817+
connection: Connection, httpx_mock: HTTPXMock, simple_commit_callback: Callable
818+
):
819+
# Mock the COMMIT query
820+
httpx_mock.add_callback(
821+
simple_commit_callback,
822+
method="POST",
823+
)
824+
# nothing happens
825+
await connection.commit()
826+
827+
await connection.aclose()
828+
with raises(ConnectionClosedError):
829+
await connection.commit()
830+
831+
832+
async def test_connection_autocommit_property(connection: Connection):
833+
"""Test autocommit property getter and that setter fails."""
834+
# Should default to True
835+
assert connection.autocommit is True, "Connection should default to autocommit=True"
836+
837+
# Autocommit should be read-only - setting it should fail
838+
with raises(AttributeError):
839+
connection.autocommit = False
840+
841+
# Close connection to satisfy async requirement
842+
await connection.aclose()
843+
844+
845+
async def test_in_transaction_property_reflects_transaction_state(
846+
connection: Connection,
847+
):
848+
"""Test in_transaction property."""
849+
# Should not be in transaction initially
850+
assert (
851+
connection.in_transaction is False
852+
), "Connection should not be in transaction initially"
853+
854+
# Mock being in transaction
855+
connection._transaction_id = "test_id"
856+
assert (
857+
connection.in_transaction is True
858+
), "Connection should be in transaction when _transaction_id is set"
859+
860+
# Clear transaction
861+
connection._transaction_id = None
862+
assert (
863+
connection.in_transaction is False
864+
), "Connection should not be in transaction when _transaction_id is None"
865+
866+
# Close connection to satisfy async requirement
867+
await connection.aclose()
868+
869+
870+
async def test_transaction_id_parsed_from_server_response_headers(
871+
httpx_mock: HTTPXMock,
872+
connection_autocommit_off: Connection,
873+
begin_transaction_callback: Callable,
874+
transaction_query_callback: Callable,
875+
commit_transaction_callback: Callable,
876+
transaction_id: str,
877+
):
878+
"""Test that transaction headers are parsed correctly."""
879+
httpx_mock.add_callback(
880+
begin_transaction_callback,
881+
method="POST",
882+
)
883+
httpx_mock.add_callback(
884+
transaction_query_callback,
885+
method="POST",
886+
)
887+
httpx_mock.add_callback(
888+
commit_transaction_callback,
889+
method="POST",
890+
)
891+
892+
cursor = connection_autocommit_off.cursor()
893+
await cursor.execute("SELECT 1") # This should implicitly start transaction
894+
895+
# Check that transaction_id was parsed from header
896+
assert connection_autocommit_off._transaction_id == transaction_id
897+
assert connection_autocommit_off.in_transaction is True
898+
899+
900+
async def test_transaction_sequence_id_parsed_from_server_response_headers(
901+
httpx_mock: HTTPXMock,
902+
connection_autocommit_off: Connection,
903+
begin_transaction_callback: Callable,
904+
select_one_query_callback: Callable,
905+
transaction_query_callback: Callable,
906+
commit_transaction_callback: Callable,
907+
transaction_id: str,
908+
transaction_sequence_id: int,
909+
):
910+
"""Test that transaction sequence id is parsed from headers."""
911+
# Start transaction implicitly - connection will send BEGIN then SELECT 1
912+
httpx_mock.add_callback(begin_transaction_callback, method="POST")
913+
httpx_mock.add_callback(select_one_query_callback, method="POST")
914+
httpx_mock.add_callback(commit_transaction_callback, method="POST")
915+
916+
cursor = connection_autocommit_off.cursor()
917+
await cursor.execute("SELECT 1") # This should implicitly start transaction
918+
919+
assert connection_autocommit_off._transaction_id == transaction_id
920+
assert connection_autocommit_off._transaction_sequence_id is None
921+
922+
# Execute query in transaction - should get sequence id
923+
httpx_mock.reset()
924+
httpx_mock.add_callback(transaction_query_callback, method="POST")
925+
httpx_mock.add_callback(commit_transaction_callback, method="POST")
926+
927+
await cursor.execute("SELECT * FROM table")
928+
929+
# Sequence id should be incremented
930+
assert connection_autocommit_off._transaction_sequence_id == str(
931+
transaction_sequence_id + 1
932+
)
933+
934+
935+
async def test_transaction_params_included_in_subsequent_requests(
936+
httpx_mock: HTTPXMock,
937+
connection_autocommit_off: Connection,
938+
begin_transaction_callback: Callable,
939+
transaction_query_callback: Callable,
940+
commit_transaction_callback: Callable,
941+
):
942+
"""Test that transaction parameters are added to requests."""
943+
# Start transaction implicitly
944+
httpx_mock.add_callback(
945+
begin_transaction_callback,
946+
method="POST",
947+
)
948+
httpx_mock.add_callback(
949+
transaction_query_callback,
950+
method="POST",
951+
)
952+
httpx_mock.add_callback(
953+
commit_transaction_callback,
954+
method="POST",
955+
)
956+
957+
cursor = connection_autocommit_off.cursor()
958+
await cursor.execute("SELECT 1") # This should implicitly start transaction
959+
960+
# Execute second query in transaction - callback will verify parameters are present
961+
httpx_mock.reset()
962+
httpx_mock.add_callback(
963+
transaction_query_callback,
964+
method="POST",
965+
)
966+
httpx_mock.add_callback(
967+
commit_transaction_callback,
968+
method="POST",
969+
)
970+
971+
await cursor.execute(
972+
"SELECT 1"
973+
) # This will fail if transaction params aren't passed
974+
975+
976+
async def test_reset_session_header_clears_transaction_state(
977+
httpx_mock: HTTPXMock,
978+
connection_autocommit_off: Connection,
979+
begin_transaction_callback: Callable,
980+
transaction_query_callback: Callable,
981+
commit_transaction_callback: Callable,
982+
transaction_id: str,
983+
):
984+
"""Test that reset session header clears transaction state."""
985+
# Start transaction implicitly
986+
httpx_mock.add_callback(
987+
begin_transaction_callback,
988+
method="POST",
989+
)
990+
httpx_mock.add_callback(
991+
transaction_query_callback,
992+
method="POST",
993+
)
994+
995+
cursor = connection_autocommit_off.cursor()
996+
await cursor.execute("SELECT 1") # This should implicitly start transaction
997+
998+
assert connection_autocommit_off._transaction_id == transaction_id
999+
assert connection_autocommit_off.in_transaction is True
1000+
1001+
# Execute COMMIT using connection method which should reset session
1002+
httpx_mock.reset()
1003+
httpx_mock.add_callback(
1004+
commit_transaction_callback,
1005+
method="POST",
1006+
)
1007+
1008+
await connection_autocommit_off.commit()
1009+
1010+
# Transaction should be cleared
1011+
assert connection_autocommit_off._transaction_id is None
1012+
assert connection_autocommit_off._transaction_sequence_id is None
1013+
assert connection_autocommit_off.in_transaction is False
1014+
1015+
1016+
async def test_remove_parameters_header_clears_transaction_state(
1017+
httpx_mock: HTTPXMock,
1018+
connection_autocommit_off: Connection,
1019+
transaction_with_remove_params_callback: Callable,
1020+
):
1021+
"""Test that remove parameters header clears transaction parameters."""
1022+
# Set up transaction state manually
1023+
connection_autocommit_off._transaction_id = "test_id"
1024+
connection_autocommit_off._transaction_sequence_id = "5"
1025+
1026+
assert connection_autocommit_off.in_transaction is True
1027+
1028+
httpx_mock.add_callback(
1029+
transaction_with_remove_params_callback,
1030+
method="POST",
1031+
)
1032+
1033+
cursor = connection_autocommit_off.cursor()
1034+
await cursor.execute("SELECT 1")
1035+
1036+
# Transaction parameters should be cleared
1037+
assert connection_autocommit_off._transaction_id is None
1038+
assert connection_autocommit_off._transaction_sequence_id is None
1039+
assert connection_autocommit_off.in_transaction is False
1040+
1041+
1042+
async def test_connection_context_manager_handles_transaction_cleanup(
1043+
httpx_mock: HTTPXMock,
1044+
auth: Auth,
1045+
account_name: str,
1046+
api_endpoint: str,
1047+
db_name: str,
1048+
engine_name: str,
1049+
engine_url: str,
1050+
mock_connection_flow: Callable,
1051+
begin_transaction_callback: Callable,
1052+
transaction_query_callback: Callable,
1053+
rollback_transaction_callback: Callable,
1054+
transaction_id: str,
1055+
):
1056+
"""Test that connection context manager handles transactions properly."""
1057+
mock_connection_flow()
1058+
1059+
# Mock queries with transaction_id parameter
1060+
httpx_mock.add_callback(
1061+
begin_transaction_callback,
1062+
method="POST",
1063+
)
1064+
httpx_mock.add_callback(
1065+
transaction_query_callback,
1066+
method="POST",
1067+
)
1068+
httpx_mock.add_callback(
1069+
rollback_transaction_callback,
1070+
method="POST",
1071+
)
1072+
1073+
try:
1074+
async with await connect(
1075+
auth=auth,
1076+
account_name=account_name,
1077+
api_endpoint=api_endpoint,
1078+
database=db_name,
1079+
engine_name=engine_name,
1080+
autocommit=False,
1081+
) as connection:
1082+
cursor = connection.cursor()
1083+
await cursor.execute("BEGIN")
1084+
1085+
# Verify transaction state
1086+
assert connection._transaction_id == transaction_id
1087+
assert connection._transaction_sequence_id == 1
1088+
1089+
# Execute another query to test transaction parameters
1090+
await cursor.execute("SELECT 1")
1091+
1092+
except Exception:
1093+
pass # Context manager should handle rollback
1094+
1095+
# Verify transaction was cleared
1096+
assert connection._transaction_id is None
1097+
assert connection._transaction_sequence_id is None

0 commit comments

Comments
 (0)