Skip to content

Commit f96936a

Browse files
use pytest parametrize instead of unittest subtests, switch to pytest instead of unittest primitives
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 2f666fb commit f96936a

File tree

1 file changed

+71
-64
lines changed

1 file changed

+71
-64
lines changed

tests/unit/test_client.py

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33
import sys
44
import unittest
5+
import pytest
56
from unittest.mock import patch, MagicMock, Mock, PropertyMock
67
import itertools
78
from decimal import Decimal
@@ -30,6 +31,76 @@
3031
from tests.unit.test_arrow_queue import ArrowQueueSuite
3132

3233

34+
@pytest.mark.parametrize("closed", [True, False])
35+
@patch("databricks.sql.client.ThriftBackend")
36+
def test_closing_connection_closes_commands(mock_thrift_client_class, closed):
37+
"""Test that closing a connection properly closes commands.
38+
39+
This test verifies that when a connection is closed:
40+
1. All result sets are marked as closed server-side
41+
2. The operation state is set to CLOSED
42+
3. Backend.close_command is called only for commands that weren't already closed
43+
44+
Args:
45+
mock_thrift_client_class: Mock for ThriftBackend class
46+
closed: Parameter indicating if the command is already closed
47+
"""
48+
# Set initial state based on whether the command is already closed
49+
initial_state = (
50+
TOperationState.FINISHED_STATE if not closed else TOperationState.CLOSED_STATE
51+
)
52+
53+
# Mock the execute response with controlled state
54+
mock_execute_response = Mock(spec=ExecuteResponse)
55+
mock_execute_response.status = initial_state
56+
mock_execute_response.has_been_closed_server_side = closed
57+
mock_execute_response.is_staging_operation = False
58+
59+
# Mock the backend that will be used
60+
mock_backend = Mock(spec=ThriftBackend)
61+
mock_thrift_client_class.return_value = mock_backend
62+
63+
# Create connection and cursor
64+
connection = databricks.sql.connect(
65+
server_hostname="foo",
66+
http_path="dummy_path",
67+
access_token="tok",
68+
)
69+
cursor = connection.cursor()
70+
71+
# Verify initial state
72+
assert mock_execute_response.has_been_closed_server_side == closed
73+
assert mock_execute_response.status == initial_state
74+
75+
# Mock execute_command to return our execute response
76+
cursor.thrift_backend.execute_command = Mock(return_value=mock_execute_response)
77+
cursor.execute("SELECT 1")
78+
79+
# Verify that cursor.execute() set up the result set correctly
80+
active_result_set = cursor.active_result_set
81+
assert active_result_set.has_been_closed_server_side == closed
82+
83+
# Close the connection
84+
connection.close()
85+
86+
# Verify the close logic worked:
87+
# 1. has_been_closed_server_side should always be True after close()
88+
assert active_result_set.has_been_closed_server_side is True
89+
90+
# 2. op_state should always be CLOSED after close()
91+
assert active_result_set.op_state == connection.thrift_backend.CLOSED_OP_STATE
92+
93+
# 3. Backend close_command should be called appropriately
94+
if not closed:
95+
# Should have called backend.close_command during the close chain
96+
mock_backend.close_command.assert_called_once_with(
97+
mock_execute_response.command_handle
98+
)
99+
else:
100+
# Should NOT have called backend.close_command (already closed)
101+
mock_backend.close_command.assert_not_called()
102+
103+
33104
class ThriftBackendMockFactory:
34105
@classmethod
35106
def new(cls):
@@ -170,70 +241,6 @@ def test_useragent_header(self, mock_client_class):
170241
http_headers = mock_client_class.call_args[0][3]
171242
self.assertIn(user_agent_header_with_entry, http_headers)
172243

173-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
174-
def test_closing_connection_closes_commands(self, mock_thrift_client_class):
175-
# Test once with has_been_closed_server side, once without
176-
for closed in (True, False):
177-
with self.subTest(closed=closed):
178-
initial_state = (
179-
TOperationState.FINISHED_STATE
180-
if not closed
181-
else TOperationState.CLOSED_STATE
182-
)
183-
184-
# Mock the execute response with controlled state
185-
mock_execute_response = Mock(spec=ExecuteResponse)
186-
mock_execute_response.status = initial_state
187-
mock_execute_response.has_been_closed_server_side = closed
188-
mock_execute_response.is_staging_operation = False
189-
190-
# Mock the backend that will be used
191-
mock_backend = Mock(spec=ThriftBackend)
192-
mock_thrift_client_class.return_value = mock_backend
193-
194-
# Create connection and cursor
195-
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
196-
cursor = connection.cursor()
197-
198-
# Verify initial state
199-
self.assertEqual(
200-
mock_execute_response.has_been_closed_server_side, closed
201-
)
202-
self.assertEqual(mock_execute_response.status, initial_state)
203-
204-
# Mock execute_command to return our execute response
205-
cursor.thrift_backend.execute_command = Mock(
206-
return_value=mock_execute_response
207-
)
208-
cursor.execute("SELECT 1")
209-
210-
# Verify that cursor.execute() set up the result set correctly
211-
active_result_set = cursor.active_result_set
212-
self.assertEqual(active_result_set.has_been_closed_server_side, closed)
213-
214-
# Close the connection
215-
connection.close()
216-
217-
# Verify the close logic worked:
218-
# 1. has_been_closed_server_side should always be True after close()
219-
self.assertTrue(active_result_set.has_been_closed_server_side)
220-
221-
# 2. op_state should always be CLOSED after close()
222-
self.assertEqual(
223-
active_result_set.op_state,
224-
connection.thrift_backend.CLOSED_OP_STATE,
225-
)
226-
227-
# 3. Backend close_command should be called appropriately
228-
if not closed:
229-
# Should have called backend.close_command during the close chain
230-
mock_backend.close_command.assert_called_once_with(
231-
mock_execute_response.command_handle
232-
)
233-
else:
234-
# Should NOT have called backend.close_command (already closed)
235-
mock_backend.close_command.assert_not_called()
236-
237244
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
238245
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
239246
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)

0 commit comments

Comments
 (0)