Skip to content

Commit 962e4a6

Browse files
fix some tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 120dfc0 commit 962e4a6

File tree

4 files changed

+18
-36
lines changed

4 files changed

+18
-36
lines changed

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
108108
mock_execute_response.status = (
109109
CommandState.SUCCEEDED if not closed else CommandState.CLOSED
110110
)
111-
mock_execute_response.has_been_closed_server_side = closed
112111
mock_execute_response.is_staging_operation = False
113112
mock_execute_response.description = []
114113

@@ -127,6 +126,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
127126
real_result_set = ThriftResultSet(
128127
connection=connection,
129128
execute_response=mock_execute_response,
129+
has_been_closed_server_side=closed,
130130
)
131131

132132
# Mock execute_command to return our real result set

tests/unit/test_fetches.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def make_dummy_result_set_from_initial_results(initial_results):
6363
execute_response=ExecuteResponse(
6464
command_id=None,
6565
status=None,
66-
has_been_closed_server_side=True,
6766
description=description,
6867
lz4_compressed=True,
6968
is_staging_operation=False,
7069
),
7170
t_row_set=None,
71+
has_been_closed_server_side=True,
7272
)
7373
return rs
7474

@@ -117,11 +117,11 @@ def fetch_results(
117117
execute_response=ExecuteResponse(
118118
command_id=None,
119119
status=None,
120-
has_been_closed_server_side=False,
121120
description=description,
122121
lz4_compressed=True,
123122
is_staging_operation=False,
124123
),
124+
has_been_closed_server_side=False,
125125
)
126126
return rs
127127

tests/unit/test_sea_result_set.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -141,29 +141,6 @@ def test_close(self, mock_connection, execute_response):
141141
mock_connection.session.backend.close_command.assert_called_once_with(
142142
result_set.command_id
143143
)
144-
assert result_set.has_been_closed_server_side is True
145-
assert result_set.status == CommandState.CLOSED
146-
147-
def test_close_when_already_closed_server_side(
148-
self, mock_connection, execute_response
149-
):
150-
"""Test closing a result set that has already been closed server-side."""
151-
result_set = SeaResultSet(
152-
connection=mock_connection,
153-
execute_response=execute_response,
154-
result_data=ResultData(data=[]),
155-
manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY),
156-
buffer_size_bytes=1000,
157-
arraysize=100,
158-
)
159-
result_set.has_been_closed_server_side = True
160-
161-
# Close the result set
162-
result_set.close()
163-
164-
# Verify the backend's close_command was NOT called
165-
mock_connection.session.backend.close_command.assert_not_called()
166-
assert result_set.has_been_closed_server_side is True
167144
assert result_set.status == CommandState.CLOSED
168145

169146
def test_close_when_connection_closed(self, mock_connection, execute_response):
@@ -183,7 +160,6 @@ def test_close_when_connection_closed(self, mock_connection, execute_response):
183160

184161
# Verify the backend's close_command was NOT called
185162
mock_connection.session.backend.close_command.assert_not_called()
186-
assert result_set.has_been_closed_server_side is True
187163
assert result_set.status == CommandState.CLOSED
188164

189165
def test_init_with_result_data(self, result_set_with_data, sample_data):

tests/unit/test_thrift_backend.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(
649649
ssl_options=SSLOptions(),
650650
)
651651

652-
execute_response, _ = thrift_backend._handle_execute_response(
652+
execute_response, _, _, _ = thrift_backend._handle_execute_response(
653653
t_execute_resp, Mock()
654654
)
655655
self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
@@ -892,6 +892,8 @@ def test_handle_execute_response_can_handle_without_direct_results(
892892
(
893893
execute_response,
894894
_,
895+
_,
896+
_
895897
) = thrift_backend._handle_execute_response(execute_resp, Mock())
896898
self.assertEqual(
897899
execute_response.status,
@@ -965,11 +967,11 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
965967
)
966968
)
967969
thrift_backend = self._make_fake_thrift_backend()
968-
execute_response, _ = thrift_backend._handle_execute_response(
970+
execute_response, _, _, arrow_schema_bytes = thrift_backend._handle_execute_response(
969971
t_execute_resp, Mock()
970972
)
971973

972-
self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock)
974+
self.assertEqual(arrow_schema_bytes, arrow_schema_mock)
973975

974976
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
975977
def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
@@ -997,7 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
997999
)
9981000
)
9991001
thrift_backend = self._make_fake_thrift_backend()
1000-
_, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
1002+
_, _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
10011003

10021004
self.assertEqual(
10031005
hive_schema_mock,
@@ -1046,6 +1048,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10461048
(
10471049
execute_response,
10481050
has_more_rows_result,
1051+
_,
1052+
_
10491053
) = thrift_backend._handle_execute_response(execute_resp, Mock())
10501054

10511055
self.assertEqual(is_direct_results, has_more_rows_result)
@@ -1179,7 +1183,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11791183
ssl_options=SSLOptions(),
11801184
)
11811185
thrift_backend._handle_execute_response = Mock()
1182-
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
1186+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock())
11831187
cursor_mock = Mock()
11841188

11851189
result = thrift_backend.execute_command(
@@ -1215,7 +1219,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
12151219
ssl_options=SSLOptions(),
12161220
)
12171221
thrift_backend._handle_execute_response = Mock()
1218-
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
1222+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock())
12191223
cursor_mock = Mock()
12201224

12211225
result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
@@ -1248,7 +1252,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12481252
ssl_options=SSLOptions(),
12491253
)
12501254
thrift_backend._handle_execute_response = Mock()
1251-
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
1255+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock())
12521256
cursor_mock = Mock()
12531257

12541258
result = thrift_backend.get_schemas(
@@ -1290,7 +1294,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
12901294
ssl_options=SSLOptions(),
12911295
)
12921296
thrift_backend._handle_execute_response = Mock()
1293-
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
1297+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock())
12941298
cursor_mock = Mock()
12951299

12961300
result = thrift_backend.get_tables(
@@ -1336,7 +1340,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
13361340
ssl_options=SSLOptions(),
13371341
)
13381342
thrift_backend._handle_execute_response = Mock()
1339-
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
1343+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock())
13401344
cursor_mock = Mock()
13411345

13421346
result = thrift_backend.get_columns(
@@ -2254,6 +2258,8 @@ def test_execute_command_sets_complex_type_fields_correctly(
22542258
mock_handle_execute_response.return_value = (
22552259
mock_execute_response,
22562260
mock_arrow_schema,
2261+
Mock(),
2262+
Mock()
22572263
)
22582264

22592265
# Iterate through each possible combination of native types (True, False and unset)

0 commit comments

Comments
 (0)