Skip to content

Commit 3102413

Browse files
committed
Improve error handling in connection and cursor context managers to ensure proper closure during exceptions, including KeyboardInterrupt. Add tests for nested cursor management and verify operation closure on server-side errors.
1 parent 1232b3c commit 3102413

File tree

3 files changed

+154
-7
lines changed

3 files changed

+154
-7
lines changed

src/databricks/sql/client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,13 @@ def __enter__(self) -> "Connection":
315315
return self
316316

317317
def __exit__(self, exc_type, exc_value, traceback):
318-
self.close()
318+
try:
319+
self.close()
320+
except BaseException as e:
321+
logger.warning(f"Exception during connection close in __exit__: {e}")
322+
if exc_type is None:
323+
raise
324+
return False
319325

320326
def __del__(self):
321327
if self.open:
@@ -459,11 +465,9 @@ def __exit__(self, exc_type, exc_value, traceback):
459465
try:
460466
logger.debug("Cursor context manager exiting, calling close()")
461467
self.close()
462-
except Exception as e:
468+
except BaseException as e:
463469
logger.warning(f"Exception during cursor close in __exit__: {e}")
464-
# Don't suppress the original exception if there was one
465470
if exc_type is None:
466-
# Only raise our new exception if there wasn't already one in progress
467471
raise
468472
return False
469473

tests/e2e/test_driver.py

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin
5252

53-
from databricks.sql.exc import SessionAlreadyClosedError
53+
from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError
5454

5555
log = logging.getLogger(__name__)
5656

@@ -813,7 +813,6 @@ def test_close_connection_closes_cursors(self):
813813
ars = cursor.active_result_set
814814

815815
# We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True
816-
817816
# Cursor op state should be open before connection is closed
818817
status_request = ttypes.TGetOperationStatusReq(
819818
operationHandle=ars.command_id, getProgressUpdate=False
@@ -840,9 +839,106 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
840839
with self.connection() as conn:
841840
# First .close() call is explicit here
842841
conn.close()
843-
844842
assert "Session appears to have been closed already" in caplog.text
845843

844+
# --- Integrated KeyboardInterrupt test ---
845+
conn = None
846+
try:
847+
with pytest.raises(KeyboardInterrupt):
848+
with self.connection() as c:
849+
conn = c
850+
raise KeyboardInterrupt("Simulated interrupt")
851+
finally:
852+
if conn is not None:
853+
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
854+
855+
def test_cursor_close_properly_closes_operation(self):
856+
"""Test that Cursor.close() properly closes the active operation handle on the server."""
857+
with self.connection() as conn:
858+
cursor = conn.cursor()
859+
try:
860+
cursor.execute("SELECT 1 AS test")
861+
assert cursor.active_op_handle is not None
862+
cursor.close()
863+
assert cursor.active_op_handle is None
864+
assert not cursor.open
865+
finally:
866+
if cursor.open:
867+
cursor.close()
868+
869+
# --- Integrated KeyboardInterrupt test ---
870+
conn = None
871+
cursor = None
872+
try:
873+
with self.connection() as c:
874+
conn = c
875+
with pytest.raises(KeyboardInterrupt):
876+
with conn.cursor() as cur:
877+
cursor = cur
878+
raise KeyboardInterrupt("Simulated interrupt")
879+
finally:
880+
if cursor is not None:
881+
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
882+
883+
def test_nested_cursor_context_managers(self):
884+
"""Test that nested cursor context managers properly close operations on the server."""
885+
with self.connection() as conn:
886+
with conn.cursor() as cursor1:
887+
cursor1.execute("SELECT 1 AS test1")
888+
assert cursor1.active_op_handle is not None
889+
890+
with conn.cursor() as cursor2:
891+
cursor2.execute("SELECT 2 AS test2")
892+
assert cursor2.active_op_handle is not None
893+
894+
# After inner context manager exit, cursor2 should be not open
895+
assert not cursor2.open
896+
assert cursor2.active_op_handle is None
897+
898+
# After outer context manager exit, cursor1 should be not open
899+
assert not cursor1.open
900+
assert cursor1.active_op_handle is None
901+
902+
def test_cursor_error_handling(self):
903+
"""Test that cursor close handles errors properly to prevent orphaned operations."""
904+
with self.connection() as conn:
905+
cursor = conn.cursor()
906+
907+
cursor.execute("SELECT 1 AS test")
908+
909+
op_handle = cursor.active_op_handle
910+
911+
assert op_handle is not None
912+
913+
# Manually close the operation to simulate server-side closure
914+
conn.thrift_backend.close_command(op_handle)
915+
916+
cursor.close()
917+
918+
assert not cursor.open
919+
920+
def test_result_set_close(self):
921+
"""Test that ResultSet.close() properly closes operations on the server and handles state correctly."""
922+
with self.connection() as conn:
923+
cursor = conn.cursor()
924+
try:
925+
cursor.execute("SELECT * FROM RANGE(10)")
926+
927+
result_set = cursor.active_result_set
928+
assert result_set is not None
929+
930+
initial_op_state = result_set.op_state
931+
932+
result_set.close()
933+
934+
assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
935+
assert result_set.op_state != initial_op_state
936+
937+
# Closing the result set again should be a no-op and not raise exceptions
938+
result_set.close()
939+
finally:
940+
cursor.close()
941+
846942

847943
# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
848944
# the 429/503 subsuites separate since they execute under different circumstances.
@@ -875,3 +971,32 @@ def test_initial_namespace(self):
875971
assert cursor.fetchone()[0] == self.arguments["catalog"]
876972
cursor.execute("select current_database()")
877973
assert cursor.fetchone()[0] == table_name
974+
975+
976+
class TestContextManagerInterrupts(PySQLPytestTestCase):
977+
def test_connection_context_manager_handles_keyboard_interrupt(self):
978+
# This test ensures that a KeyboardInterrupt inside a connection context propagates and closes the connection
979+
conn = None
980+
try:
981+
with pytest.raises(KeyboardInterrupt):
982+
with self.connection() as c:
983+
conn = c
984+
raise KeyboardInterrupt("Simulated interrupt")
985+
finally:
986+
if conn is not None:
987+
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
988+
989+
def test_cursor_context_manager_handles_keyboard_interrupt(self):
990+
# This test ensures that a KeyboardInterrupt inside a cursor context propagates and closes the cursor
991+
conn = None
992+
cursor = None
993+
try:
994+
with self.connection() as c:
995+
conn = c
996+
with pytest.raises(KeyboardInterrupt):
997+
with conn.cursor() as cur:
998+
cursor = cur
999+
raise KeyboardInterrupt("Simulated interrupt")
1000+
finally:
1001+
if cursor is not None:
1002+
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"

tests/unit/test_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,15 @@ def test_context_manager_closes_cursor(self):
284284
cursor.close = mock_close
285285
mock_close.assert_called_once_with()
286286

287+
cursor = client.Cursor(Mock(), Mock())
288+
cursor.close = Mock()
289+
try:
290+
with self.assertRaises(KeyboardInterrupt):
291+
with cursor:
292+
raise KeyboardInterrupt("Simulated interrupt")
293+
finally:
294+
cursor.close.assert_called()
295+
287296
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
288297
def test_context_manager_closes_connection(self, mock_client_class):
289298
instance = mock_client_class.return_value
@@ -299,6 +308,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
299308
close_session_id = instance.close_session.call_args[0][0].sessionId
300309
self.assertEqual(close_session_id, b"\x22")
301310

311+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
312+
connection.close = Mock()
313+
try:
314+
with self.assertRaises(KeyboardInterrupt):
315+
with connection:
316+
raise KeyboardInterrupt("Simulated interrupt")
317+
finally:
318+
connection.close.assert_called()
319+
302320
def dict_product(self, dicts):
303321
"""
304322
Generate cartesion product of values in input dictionary, outputting a dictionary

0 commit comments

Comments
 (0)