Skip to content

Commit d882c6e

Browse files
Merge branch 'sea-migration' into ext-links-sea
2 parents 1920375 + 922c448 commit d882c6e

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def remaining_rows(self) -> List[List[str]]:
9797
self.cur_row_index += len(slice)
9898
return slice
9999

100+
def close(self):
101+
return
102+
100103

101104
class SeaCloudFetchQueue(CloudFetchQueue):
102105
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""

src/databricks/sql/result_set.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def close(self) -> None:
169169
been closed on the server for some other reason, issue a request to the server to close it.
170170
"""
171171
try:
172+
self.results.close()
172173
if (
173174
self.status != CommandState.CLOSED
174175
and not self.has_been_closed_server_side

src/databricks/sql/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def next_n_rows(self, num_rows: int):
4747
def remaining_rows(self):
4848
pass
4949

50+
@abstractmethod
51+
def close(self):
52+
pass
53+
5054

5155
class ThriftResultSetQueueFactory(ABC):
5256
@staticmethod
@@ -159,6 +163,9 @@ def remaining_rows(self):
159163
self.cur_row_index += slice.num_rows
160164
return slice
161165

166+
def close(self):
167+
return
168+
162169

163170
class ArrowQueue(ResultSetQueue):
164171
def __init__(
@@ -196,6 +203,9 @@ def remaining_rows(self) -> "pyarrow.Table":
196203
self.cur_row_index += slice.num_rows
197204
return slice
198205

206+
def close(self):
207+
return
208+
199209

200210
class CloudFetchQueue(ResultSetQueue, ABC):
201211
"""Base class for cloud fetch queues that handle EXTERNAL_LINKS disposition with ARROW format."""
@@ -326,6 +336,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
326336
return pyarrow.Table.from_pydict({})
327337
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
328338

339+
def close(self):
340+
self.download_manager._shutdown_manager()
341+
329342

330343
class ThriftCloudFetchQueue(CloudFetchQueue):
331344
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""

tests/unit/test_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,16 @@ def test_arraysize_buffer_size_passthrough(
188188
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
189189
mock_connection = Mock()
190190
mock_backend = Mock()
191+
mock_results = Mock()
191192
mock_backend.fetch_results.return_value = (Mock(), False)
192193

193194
result_set = ThriftResultSet(
194195
connection=mock_connection,
195196
execute_response=Mock(),
196197
thrift_client=mock_backend,
197198
)
199+
result_set.results = mock_results
200+
198201
# Setup session mock on the mock_connection
199202
mock_session = Mock()
200203
mock_session.open = False
@@ -204,12 +207,14 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
204207

205208
self.assertFalse(mock_backend.close_command.called)
206209
self.assertTrue(result_set.has_been_closed_server_side)
210+
mock_results.close.assert_called_once()
207211

208212
def test_closing_result_set_hard_closes_commands(self):
209213
mock_results_response = Mock()
210214
mock_results_response.has_been_closed_server_side = False
211215
mock_connection = Mock()
212216
mock_thrift_backend = Mock()
217+
mock_results = Mock()
213218
# Setup session mock on the mock_connection
214219
mock_session = Mock()
215220
mock_session.open = True
@@ -219,12 +224,14 @@ def test_closing_result_set_hard_closes_commands(self):
219224
result_set = ThriftResultSet(
220225
mock_connection, mock_results_response, mock_thrift_backend
221226
)
227+
result_set.results = mock_results
222228

223229
result_set.close()
224230

225231
mock_thrift_backend.close_command.assert_called_once_with(
226232
mock_results_response.command_id
227233
)
234+
mock_results.close.assert_called_once()
228235

229236
def test_executing_multiple_commands_uses_the_most_recent_command(self):
230237
mock_result_sets = [Mock(), Mock()]

0 commit comments

Comments
 (0)