Skip to content

Commit d3200c4

Browse files
move Queue construction to ResultSet
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 27158b1 commit d3200c4

File tree

6 files changed

+148
-67
lines changed

6 files changed

+148
-67
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import math
55
import time
6-
import uuid
76
import threading
87
from typing import List, Union, Any, TYPE_CHECKING
98

@@ -728,7 +727,7 @@ def _col_to_description(col):
728727
else:
729728
precision, scale = None, None
730729

731-
return col.columnName, cleaned_type, None, None, precision, scale, None
730+
return [col.columnName, cleaned_type, None, None, precision, scale, None]
732731

733732
@staticmethod
734733
def _hive_schema_to_description(t_table_schema):
@@ -778,23 +777,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
778777
schema_bytes = None
779778

780779
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
781-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
782-
if direct_results and direct_results.resultSet:
783-
assert direct_results.resultSet.results.startRowOffset == 0
784-
assert direct_results.resultSetMetadata
785-
786-
arrow_queue_opt = ResultSetQueueFactory.build_queue(
787-
row_set_type=t_result_set_metadata_resp.resultFormat,
788-
t_row_set=direct_results.resultSet.results,
789-
arrow_schema_bytes=schema_bytes,
790-
max_download_threads=self.max_download_threads,
791-
lz4_compressed=lz4_compressed,
792-
description=description,
793-
ssl_options=self._ssl_options,
794-
)
795-
else:
796-
arrow_queue_opt = None
797-
798780
command_id = CommandId.from_thrift_handle(resp.operationHandle)
799781

800782
status = CommandState.from_thrift_state(operation_state)
@@ -806,11 +788,11 @@ def _results_message_to_execute_response(self, resp, operation_state):
806788
status=status,
807789
description=description,
808790
has_more_rows=has_more_rows,
809-
results_queue=arrow_queue_opt,
810791
has_been_closed_server_side=has_been_closed_server_side,
811792
lz4_compressed=lz4_compressed,
812-
is_staging_operation=is_staging_operation,
793+
is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
813794
arrow_schema_bytes=schema_bytes,
795+
result_format=t_result_set_metadata_resp.resultFormat,
814796
)
815797

816798
def get_execution_result(
@@ -837,9 +819,6 @@ def get_execution_result(
837819

838820
t_result_set_metadata_resp = resp.resultSetMetadata
839821

840-
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
841-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
842-
has_more_rows = resp.hasMoreRows
843822
description = self._hive_schema_to_description(
844823
t_result_set_metadata_resp.schema
845824
)
@@ -854,15 +833,9 @@ def get_execution_result(
854833
else:
855834
schema_bytes = None
856835

857-
queue = ResultSetQueueFactory.build_queue(
858-
row_set_type=resp.resultSetMetadata.resultFormat,
859-
t_row_set=resp.results,
860-
arrow_schema_bytes=schema_bytes,
861-
max_download_threads=self.max_download_threads,
862-
lz4_compressed=lz4_compressed,
863-
description=description,
864-
ssl_options=self._ssl_options,
865-
)
836+
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
837+
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
838+
has_more_rows = resp.hasMoreRows
866839

867840
status = self.get_query_state(command_id)
868841

@@ -871,11 +844,11 @@ def get_execution_result(
871844
status=status,
872845
description=description,
873846
has_more_rows=has_more_rows,
874-
results_queue=queue,
875847
has_been_closed_server_side=False,
876848
lz4_compressed=lz4_compressed,
877849
is_staging_operation=is_staging_operation,
878850
arrow_schema_bytes=schema_bytes,
851+
result_format=t_result_set_metadata_resp.resultFormat,
879852
)
880853

881854
return ThriftResultSet(
@@ -885,6 +858,9 @@ def get_execution_result(
885858
buffer_size_bytes=cursor.buffer_size_bytes,
886859
arraysize=cursor.arraysize,
887860
use_cloud_fetch=cursor.connection.use_cloud_fetch,
861+
t_row_set=resp.results,
862+
max_download_threads=self.max_download_threads,
863+
ssl_options=self._ssl_options,
888864
)
889865

890866
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -999,13 +975,20 @@ def execute_command(
999975
else:
1000976
execute_response = self._handle_execute_response(resp, cursor)
1001977

978+
t_row_set = None
979+
if resp.directResults and resp.directResults.resultSet:
980+
t_row_set = resp.directResults.resultSet.results
981+
1002982
return ThriftResultSet(
1003983
connection=cursor.connection,
1004984
execute_response=execute_response,
1005985
thrift_client=self,
1006986
buffer_size_bytes=max_bytes,
1007987
arraysize=max_rows,
1008988
use_cloud_fetch=use_cloud_fetch,
989+
t_row_set=t_row_set,
990+
max_download_threads=self.max_download_threads,
991+
ssl_options=self._ssl_options,
1009992
)
1010993

1011994
def get_catalogs(
@@ -1029,13 +1012,20 @@ def get_catalogs(
10291012

10301013
execute_response = self._handle_execute_response(resp, cursor)
10311014

1015+
t_row_set = None
1016+
if resp.directResults and resp.directResults.resultSet:
1017+
t_row_set = resp.directResults.resultSet.results
1018+
10321019
return ThriftResultSet(
10331020
connection=cursor.connection,
10341021
execute_response=execute_response,
10351022
thrift_client=self,
10361023
buffer_size_bytes=max_bytes,
10371024
arraysize=max_rows,
10381025
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1026+
t_row_set=t_row_set,
1027+
max_download_threads=self.max_download_threads,
1028+
ssl_options=self._ssl_options,
10391029
)
10401030

10411031
def get_schemas(
@@ -1063,13 +1053,20 @@ def get_schemas(
10631053

10641054
execute_response = self._handle_execute_response(resp, cursor)
10651055

1056+
t_row_set = None
1057+
if resp.directResults and resp.directResults.resultSet:
1058+
t_row_set = resp.directResults.resultSet.results
1059+
10661060
return ThriftResultSet(
10671061
connection=cursor.connection,
10681062
execute_response=execute_response,
10691063
thrift_client=self,
10701064
buffer_size_bytes=max_bytes,
10711065
arraysize=max_rows,
10721066
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1067+
t_row_set=t_row_set,
1068+
max_download_threads=self.max_download_threads,
1069+
ssl_options=self._ssl_options,
10731070
)
10741071

10751072
def get_tables(
@@ -1101,13 +1098,20 @@ def get_tables(
11011098

11021099
execute_response = self._handle_execute_response(resp, cursor)
11031100

1101+
t_row_set = None
1102+
if resp.directResults and resp.directResults.resultSet:
1103+
t_row_set = resp.directResults.resultSet.results
1104+
11041105
return ThriftResultSet(
11051106
connection=cursor.connection,
11061107
execute_response=execute_response,
11071108
thrift_client=self,
11081109
buffer_size_bytes=max_bytes,
11091110
arraysize=max_rows,
11101111
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1112+
t_row_set=t_row_set,
1113+
max_download_threads=self.max_download_threads,
1114+
ssl_options=self._ssl_options,
11111115
)
11121116

11131117
def get_columns(
@@ -1139,13 +1143,20 @@ def get_columns(
11391143

11401144
execute_response = self._handle_execute_response(resp, cursor)
11411145

1146+
t_row_set = None
1147+
if resp.directResults and resp.directResults.resultSet:
1148+
t_row_set = resp.directResults.resultSet.results
1149+
11421150
return ThriftResultSet(
11431151
connection=cursor.connection,
11441152
execute_response=execute_response,
11451153
thrift_client=self,
11461154
buffer_size_bytes=max_bytes,
11471155
arraysize=max_rows,
11481156
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1157+
t_row_set=t_row_set,
1158+
max_download_threads=self.max_download_threads,
1159+
ssl_options=self._ssl_options,
11491160
)
11501161

11511162
def _handle_execute_response(self, resp, cursor):
@@ -1203,6 +1214,8 @@ def fetch_results(
12031214
)
12041215
)
12051216

1217+
from databricks.sql.utils import ResultSetQueueFactory
1218+
12061219
queue = ResultSetQueueFactory.build_queue(
12071220
row_set_type=resp.resultSetMetadata.resultFormat,
12081221
t_row_set=resp.results,

src/databricks/sql/backend/types.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,12 +423,10 @@ class ExecuteResponse:
423423

424424
command_id: CommandId
425425
status: CommandState
426-
description: Optional[
427-
List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
428-
] = None
426+
description: Optional[List[List[Any]]] = None
429427
has_more_rows: bool = False
430-
results_queue: Optional[Any] = None
431428
has_been_closed_server_side: bool = False
432429
lz4_compressed: bool = True
433430
is_staging_operation: bool = False
434431
arrow_schema_bytes: Optional[bytes] = None
432+
result_format: Optional[Any] = None

src/databricks/sql/result_set.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def __init__(
157157
buffer_size_bytes: int = 104857600,
158158
arraysize: int = 10000,
159159
use_cloud_fetch: bool = True,
160+
t_row_set=None,
161+
max_download_threads: int = 10,
162+
ssl_options=None,
160163
):
161164
"""
162165
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -168,12 +171,31 @@ def __init__(
168171
buffer_size_bytes: Buffer size for fetching results
169172
arraysize: Default number of rows to fetch
170173
use_cloud_fetch: Whether to use cloud fetch for retrieving results
174+
t_row_set: The TRowSet containing result data (if available)
175+
max_download_threads: Maximum number of download threads for cloud fetch
176+
ssl_options: SSL options for cloud fetch
171177
"""
172178
# Initialize ThriftResultSet-specific attributes
173179
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
174180
self._use_cloud_fetch = use_cloud_fetch
175181
self.lz4_compressed = execute_response.lz4_compressed
176182

183+
# Build the results queue if t_row_set is provided
184+
results_queue = None
185+
if t_row_set and execute_response.result_format is not None:
186+
from databricks.sql.utils import ResultSetQueueFactory
187+
188+
# Create the results queue using the provided format
189+
results_queue = ResultSetQueueFactory.build_queue(
190+
row_set_type=execute_response.result_format,
191+
t_row_set=t_row_set,
192+
arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
193+
max_download_threads=max_download_threads,
194+
lz4_compressed=execute_response.lz4_compressed,
195+
description=execute_response.description,
196+
ssl_options=ssl_options,
197+
)
198+
177199
# Call parent constructor with common attributes
178200
super().__init__(
179201
connection=connection,
@@ -184,7 +206,7 @@ def __init__(
184206
status=execute_response.status,
185207
has_been_closed_server_side=execute_response.has_been_closed_server_side,
186208
has_more_rows=execute_response.has_more_rows,
187-
results_queue=execute_response.results_queue,
209+
results_queue=results_queue,
188210
description=execute_response.description,
189211
is_staging_operation=execute_response.is_staging_operation,
190212
)

tests/unit/test_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
104104
# Mock the backend that will be used by the real ThriftResultSet
105105
mock_backend = Mock(spec=ThriftDatabricksClient)
106106
mock_backend.staging_allowed_local_path = None
107+
mock_backend.fetch_results.return_value = (Mock(), False)
107108

108109
# Configure the decorator's mock to return our specific mock_backend
109110
mock_thrift_client_class.return_value = mock_backend
@@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough(
184185
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
185186
mock_connection = Mock()
186187
mock_backend = Mock()
188+
mock_backend.fetch_results.return_value = (Mock(), False)
187189

188190
result_set = ThriftResultSet(
189191
connection=mock_connection,
@@ -210,6 +212,8 @@ def test_closing_result_set_hard_closes_commands(self):
210212
mock_session.open = True
211213
type(mock_connection).session = PropertyMock(return_value=mock_session)
212214

215+
mock_thrift_backend.fetch_results.return_value = (Mock(), False)
216+
213217
result_set = ThriftResultSet(
214218
mock_connection, mock_results_response, mock_thrift_backend
215219
)
@@ -254,7 +258,10 @@ def test_closed_cursor_doesnt_allow_operations(self):
254258
self.assertIn("closed", e.msg)
255259

256260
def test_negative_fetch_throws_exception(self):
257-
result_set = ThriftResultSet(Mock(), Mock(), Mock())
261+
mock_backend = Mock()
262+
mock_backend.fetch_results.return_value = (Mock(), False)
263+
264+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
258265

259266
with self.assertRaises(ValueError) as e:
260267
result_set.fetchmany(-1)

tests/unit/test_fetches.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,34 @@ def make_dummy_result_set_from_initial_results(initial_results):
4040
# If the initial results have been set, then we should never try and fetch more
4141
schema, arrow_table = FetchTests.make_arrow_table(initial_results)
4242
arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0)
43+
44+
# Create a mock backend that will return the queue when _fill_results_buffer is called
45+
mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
46+
mock_thrift_backend.fetch_results.return_value = (arrow_queue, False)
47+
48+
num_cols = len(initial_results[0]) if initial_results else 0
49+
description = [
50+
(f"col{col_id}", "integer", None, None, None, None, None)
51+
for col_id in range(num_cols)
52+
]
53+
4354
rs = ThriftResultSet(
4455
connection=Mock(),
4556
execute_response=ExecuteResponse(
4657
command_id=None,
4758
status=None,
4859
has_been_closed_server_side=True,
4960
has_more_rows=False,
50-
description=Mock(),
51-
lz4_compressed=Mock(),
52-
results_queue=arrow_queue,
61+
description=description,
62+
lz4_compressed=True,
5363
is_staging_operation=False,
5464
),
55-
thrift_client=None,
65+
thrift_client=mock_thrift_backend,
66+
t_row_set=None,
5667
)
57-
num_cols = len(initial_results[0]) if initial_results else 0
58-
rs.description = [
59-
(f"col{col_id}", "integer", None, None, None, None, None)
60-
for col_id in range(num_cols)
61-
]
68+
69+
# Replace the results queue with our arrow_queue
70+
rs.results = arrow_queue
6271
return rs
6372

6473
@staticmethod
@@ -85,19 +94,20 @@ def fetch_results(
8594
mock_thrift_backend.fetch_results = fetch_results
8695
num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
8796

97+
description = [
98+
(f"col{col_id}", "integer", None, None, None, None, None)
99+
for col_id in range(num_cols)
100+
]
101+
88102
rs = ThriftResultSet(
89103
connection=Mock(),
90104
execute_response=ExecuteResponse(
91105
command_id=None,
92106
status=None,
93107
has_been_closed_server_side=False,
94108
has_more_rows=True,
95-
description=[
96-
(f"col{col_id}", "integer", None, None, None, None, None)
97-
for col_id in range(num_cols)
98-
],
99-
lz4_compressed=Mock(),
100-
results_queue=None,
109+
description=description,
110+
lz4_compressed=True,
101111
is_staging_operation=False,
102112
),
103113
thrift_client=mock_thrift_backend,

0 commit comments

Comments
 (0)