Skip to content

Commit c22a53f

Browse files
remove pytest parametrize, move back to unittest subTest
to allow keeping the test inside ClientTestSuite Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 5521f21 commit c22a53f

File tree

1 file changed

+74
-67
lines changed

1 file changed

+74
-67
lines changed

tests/unit/test_client.py

Lines changed: 74 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -31,73 +31,6 @@
3131
from tests.unit.test_arrow_queue import ArrowQueueSuite
3232

3333

34-
@pytest.mark.parametrize("closed", [True, False])
35-
@patch("databricks.sql.client.ThriftBackend")
36-
def test_closing_connection_closes_commands(mock_thrift_client_class, closed):
37-
"""Test that closing a connection properly closes commands.
38-
39-
This test verifies that when a connection is closed:
40-
1. All result sets are marked as closed server-side
41-
2. The operation state is set to CLOSED
42-
3. Backend.close_command is called only for commands that weren't already closed
43-
44-
Args:
45-
mock_thrift_client_class: Mock for ThriftBackend class
46-
closed: Parameter indicating if the command is already closed
47-
"""
48-
# Set initial state based on whether the command is already closed
49-
initial_state = (
50-
TOperationState.FINISHED_STATE if not closed else TOperationState.CLOSED_STATE
51-
)
52-
53-
# Mock the execute response with controlled state
54-
mock_execute_response = Mock(spec=ExecuteResponse)
55-
mock_execute_response.status = initial_state
56-
mock_execute_response.has_been_closed_server_side = closed
57-
mock_execute_response.is_staging_operation = False
58-
59-
# Mock the backend that will be used
60-
mock_backend = Mock(spec=ThriftBackend)
61-
mock_thrift_client_class.return_value = mock_backend
62-
63-
# Create connection and cursor
64-
connection = databricks.sql.connect(
65-
server_hostname="foo",
66-
http_path="dummy_path",
67-
access_token="tok",
68-
)
69-
cursor = connection.cursor()
70-
71-
# Mock execute_command to return our execute response
72-
cursor.thrift_backend.execute_command = Mock(return_value=mock_execute_response)
73-
74-
# Execute a command
75-
cursor.execute("SELECT 1")
76-
77-
# Get the active result set for later assertions
78-
active_result_set = cursor.active_result_set
79-
80-
# Close the connection
81-
connection.close()
82-
83-
# Verify the close logic worked:
84-
# 1. has_been_closed_server_side should always be True after close()
85-
assert active_result_set.has_been_closed_server_side is True
86-
87-
# 2. op_state should always be CLOSED after close()
88-
assert active_result_set.op_state == connection.thrift_backend.CLOSED_OP_STATE
89-
90-
# 3. Backend close_command should be called appropriately
91-
if not closed:
92-
# Should have called backend.close_command during the close chain
93-
mock_backend.close_command.assert_called_once_with(
94-
mock_execute_response.command_handle
95-
)
96-
else:
97-
# Should NOT have called backend.close_command (already closed)
98-
mock_backend.close_command.assert_not_called()
99-
100-
10134
class ThriftBackendMockFactory:
10235
@classmethod
10336
def new(cls):
@@ -238,6 +171,80 @@ def test_useragent_header(self, mock_client_class):
238171
http_headers = mock_client_class.call_args[0][3]
239172
self.assertIn(user_agent_header_with_entry, http_headers)
240173

174+
@patch("databricks.sql.client.ThriftBackend")
175+
def test_closing_connection_closes_commands(self, mock_thrift_client_class):
176+
"""Test that closing a connection properly closes commands.
177+
178+
This test verifies that when a connection is closed:
179+
1. All result sets are marked as closed server-side
180+
2. The operation state is set to CLOSED
181+
3. Backend.close_command is called only for commands that weren't already closed
182+
183+
Args:
184+
mock_thrift_client_class: Mock for ThriftBackend class
185+
closed: Parameter indicating if the command is already closed
186+
"""
187+
for closed in (True, False):
188+
with self.subTest(closed=closed):
189+
# Set initial state based on whether the command is already closed
190+
initial_state = (
191+
TOperationState.FINISHED_STATE
192+
if not closed
193+
else TOperationState.CLOSED_STATE
194+
)
195+
196+
# Mock the execute response with controlled state
197+
mock_execute_response = Mock(spec=ExecuteResponse)
198+
mock_execute_response.status = initial_state
199+
mock_execute_response.has_been_closed_server_side = closed
200+
mock_execute_response.is_staging_operation = False
201+
202+
# Mock the backend that will be used
203+
mock_backend = Mock(spec=ThriftBackend)
204+
mock_thrift_client_class.return_value = mock_backend
205+
206+
# Create connection and cursor
207+
connection = databricks.sql.connect(
208+
server_hostname="foo",
209+
http_path="dummy_path",
210+
access_token="tok",
211+
)
212+
cursor = connection.cursor()
213+
214+
# Mock execute_command to return our execute response
215+
cursor.thrift_backend.execute_command = Mock(
216+
return_value=mock_execute_response
217+
)
218+
219+
# Execute a command
220+
cursor.execute("SELECT 1")
221+
222+
# Get the active result set for later assertions
223+
active_result_set = cursor.active_result_set
224+
225+
# Close the connection
226+
connection.close()
227+
228+
# Verify the close logic worked:
229+
# 1. has_been_closed_server_side should always be True after close()
230+
assert active_result_set.has_been_closed_server_side is True
231+
232+
# 2. op_state should always be CLOSED after close()
233+
assert (
234+
active_result_set.op_state
235+
== connection.thrift_backend.CLOSED_OP_STATE
236+
)
237+
238+
# 3. Backend close_command should be called appropriately
239+
if not closed:
240+
# Should have called backend.close_command during the close chain
241+
mock_backend.close_command.assert_called_once_with(
242+
mock_execute_response.command_handle
243+
)
244+
else:
245+
# Should NOT have called backend.close_command (already closed)
246+
mock_backend.close_command.assert_not_called()
247+
241248
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
242249
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
243250
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)

0 commit comments

Comments
 (0)