diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5271baa70..cf8779a21 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,6 +13,7 @@ TExecuteStatementResp, TOperationHandle, THandleIdentifier, + TOperationState, TOperationType, ) from databricks.sql.thrift_backend import ThriftBackend @@ -23,6 +24,7 @@ from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row +from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite @@ -168,22 +170,78 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) - @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_closing_connection_closes_commands(self, mock_result_set_class): - # Test once with has_been_closed_server side, once without + @patch("databricks.sql.client.ThriftBackend") + def test_closing_connection_closes_commands(self, mock_thrift_client_class): + """Test that closing a connection properly closes commands. + + This test verifies that when a connection is closed: + 1. the active result set is marked as closed server-side + 2. The operation state is set to CLOSED + 3. backend.close_command is called only for commands that weren't already closed + + Args: + mock_thrift_client_class: Mock for ThriftBackend class + """ for closed in (True, False): with self.subTest(closed=closed): - mock_result_set_class.return_value = Mock() - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + # Set initial state based on whether the command is already closed + initial_state = ( + TOperationState.FINISHED_STATE + if not closed + else TOperationState.CLOSED_STATE + ) + + # Mock the execute response with controlled state + mock_execute_response = Mock(spec=ExecuteResponse) + mock_execute_response.status = initial_state + mock_execute_response.has_been_closed_server_side = closed + mock_execute_response.is_staging_operation = False + + # Mock the backend that will be used + mock_backend = Mock(spec=ThriftBackend) + mock_thrift_client_class.return_value = mock_backend + + # Create connection and cursor + connection = databricks.sql.connect( + server_hostname="foo", + http_path="dummy_path", + access_token="tok", + ) cursor = connection.cursor() - cursor.execute("SELECT 1;") + + # Mock execute_command to return our execute response + cursor.thrift_backend.execute_command = Mock( + return_value=mock_execute_response + ) + + # Execute a command + cursor.execute("SELECT 1") + + # Get the active result set for later assertions + active_result_set = cursor.active_result_set + + # Close the connection connection.close() - self.assertTrue( - mock_result_set_class.return_value.has_been_closed_server_side + # Verify the close logic worked: + # 1. has_been_closed_server_side should always be True after close() + assert active_result_set.has_been_closed_server_side is True + + # 2. op_state should always be CLOSED after close() + assert ( + active_result_set.op_state + == connection.thrift_backend.CLOSED_OP_STATE ) - mock_result_set_class.return_value.close.assert_called_once_with() + + # 3. Backend close_command should be called appropriately + if not closed: + # Should have called backend.close_command during the close chain + mock_backend.close_command.assert_called_once_with( + mock_execute_response.command_handle + ) + else: + # Should NOT have called backend.close_command (already closed) + mock_backend.close_command.assert_not_called() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class):