Commit 6862929

introduce SeaResultSetQueueFactory

Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>

1 parent c540987 commit 6862929

File tree

6 files changed: +194 -86 lines changed

src/databricks/sql/backend/sea_result_set.py

Lines changed: 19 additions & 19 deletions
@@ -18,7 +18,7 @@
 from databricks.sql.types import Row
 from databricks.sql.backend.types import CommandId, CommandState
 from databricks.sql.exc import Error
-from databricks.sql.utils import ResultSetQueueFactory, JsonQueue
+from databricks.sql.utils import SeaResultSetQueueFactory, JsonQueue
 
 from databricks.sql.backend.models import (
     StatementStatus,
@@ -117,9 +117,9 @@ def __init__(
 
         # Initialize queue for result data
         if self.result:
-            self.results = ResultSetQueueFactory.build_queue(
+            self.results = SeaResultSetQueueFactory.build_queue(
                 sea_result_data=self.result,
-                description=cast(Optional[List[List[Any]]], self.description)
+                description=cast(Optional[List[List[Any]]], self.description),
             )
             self._has_more_rows = True if self.result.data else False
         else:
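
To make the new wiring concrete, a minimal sketch of the initialization path this hunk changes. `ResultDataStub` is a hypothetical stand-in for the SEA `ResultData` model, exposing only the fields the factory inspects:

```python
from typing import Any, List, Optional, cast

from databricks.sql.utils import SeaResultSetQueueFactory, JsonQueue


class ResultDataStub:
    """Hypothetical stand-in for the SEA ResultData model."""

    def __init__(self, data=None, external_links=None):
        self.data = data                      # inline JSON_ARRAY rows
        self.external_links = external_links  # cloud-fetch links (unsupported)


result = ResultDataStub(data=[["x", 1]])
description = [["col_str"], ["col_int"]]  # invented column metadata

# Same shape as the __init__ body above:
results = SeaResultSetQueueFactory.build_queue(
    sea_result_data=result,
    description=cast(Optional[List[List[Any]]], description),
)
assert isinstance(results, JsonQueue)  # INLINE data maps to a JsonQueue
has_more_rows = True if result.data else False  # True: inline rows present
```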
@@ -168,7 +168,7 @@ def _fill_results_buffer(self) -> None:
         if not self.result or not self.result.data:
             self._has_more_rows = False
             return
-
+
         # For INLINE disposition, we already have all the data
         # No need to fetch more data from the backend
         self._has_more_rows = False  # No more rows to fetch for INLINE
@@ -177,21 +177,21 @@ def _convert_rows_to_arrow_table(self, rows):
         """Convert rows to Arrow table."""
         if not self.description:
             return pyarrow.Table.from_pylist([])
-
+
         # Create dict of column data
         column_data = {}
         column_names = [col[0] for col in self.description]
-
+
         for i, name in enumerate(column_names):
             column_data[name] = [row[i] for row in rows]
-
+
         return pyarrow.Table.from_pydict(column_data)
 
     def _create_empty_arrow_table(self):
         """Create an empty Arrow table with the correct schema."""
         if not self.description:
             return pyarrow.Table.from_pylist([])
-
+
         column_names = [col[0] for col in self.description]
         return pyarrow.Table.from_pydict({name: [] for name in column_names})
 
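The two Arrow helpers above pivot row-major results into a column dict for `pyarrow.Table.from_pydict`. A standalone illustration of that pivot, assuming only pyarrow itself (the column names are invented):

```python
import pyarrow

description = [["id"], ["name"]]  # first element of each entry is the column name
rows = [[1, "a"], [2, "b"]]

# Row-major -> column-major, as _convert_rows_to_arrow_table does:
column_names = [col[0] for col in description]
column_data = {name: [row[i] for row in rows] for i, name in enumerate(column_names)}
table = pyarrow.Table.from_pydict(column_data)
assert table.column_names == ["id", "name"] and table.num_rows == 2

# Empty result: same schema, zero rows, as _create_empty_arrow_table does:
empty = pyarrow.Table.from_pydict({name: [] for name in column_names})
assert empty.num_rows == 0
```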
@@ -201,9 +201,9 @@ def fetchone(self) -> Optional[Row]:
         rows = self.results.next_n_rows(1)
         if not rows:
             return None
-
+
         row = rows[0]
-
+
         # Convert to Row object
         if self.description:
             column_names = [col[0] for col in self.description]
@@ -219,13 +219,13 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
         """Fetch the next set of rows of a query result."""
         if size is None:
             size = self.arraysize
-
+
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
-
+
         if isinstance(self.results, JsonQueue):
             rows = self.results.next_n_rows(size)
-
+
             # Convert to Row objects
             if self.description:
                 column_names = [col[0] for col in self.description]
@@ -241,7 +241,7 @@ def fetchall(self) -> List[Row]:
         """Fetch all remaining rows of a query result."""
         if isinstance(self.results, JsonQueue):
             rows = self.results.remaining_rows()
-
+
             # Convert to Row objects
             if self.description:
                 column_names = [col[0] for col in self.description]
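
The three fetch methods above all reduce to `JsonQueue` slicing plus Row conversion. A small demonstration of the cursor semantics, using the `JsonQueue` this commit moves into `utils.py`:

```python
from databricks.sql.utils import JsonQueue

queue = JsonQueue([[1, "a"], [2, "b"], [3, "c"]])

assert queue.next_n_rows(1) == [[1, "a"]]            # fetchone: one row or None
assert queue.next_n_rows(2) == [[2, "b"], [3, "c"]]  # fetchmany(2)
assert queue.remaining_rows() == []                  # fetchall: queue exhausted
```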
@@ -257,25 +257,25 @@ def fetchmany_arrow(self, size: int) -> Any:
         """Fetch the next set of rows as an Arrow table."""
         if not pyarrow:
             raise ImportError("PyArrow is required for Arrow support")
-
+
         rows = self.fetchmany(size)
         if not rows:
             # Return empty Arrow table with schema
             return self._create_empty_arrow_table()
-
+
         # Convert rows to Arrow table
         return self._convert_rows_to_arrow_table(rows)
 
     def fetchall_arrow(self) -> Any:
         """Fetch all remaining rows as an Arrow table."""
         if not pyarrow:
             raise ImportError("PyArrow is required for Arrow support")
-
+
         rows = self.fetchall()
         if not rows:
             # Return empty Arrow table with schema
             return self._create_empty_arrow_table()
-
+
         # Convert rows to Arrow table
         return self._convert_rows_to_arrow_table(rows)
 
@@ -288,4 +288,4 @@ def close(self) -> None:
                 CommandId.from_sea_statement_id(self.statement_id)
             )
         except Exception as e:
-            logger.warning(f"Error closing SEA statement: {e}")
+            logger.warning(f"Error closing SEA statement: {e}")
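
The `if not pyarrow` guards in the two Arrow hunks rely on the connector's module-level optional import. The import block itself is outside this diff, so the pattern below is an assumption, not a quotation:

```python
# Assumed module-level pattern that makes the guards above work:
try:
    import pyarrow
except ImportError:
    pyarrow = None  # Arrow fetches then fail fast at call time


def require_pyarrow():
    """Mirror of the guard used by fetchmany_arrow / fetchall_arrow."""
    if not pyarrow:
        raise ImportError("PyArrow is required for Arrow support")
```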

src/databricks/sql/backend/thrift_backend.py

Lines changed: 4 additions & 4 deletions
@@ -45,7 +45,7 @@
     _bound,
     RequestErrorInfo,
     NoRetryReason,
-    ResultSetQueueFactory,
+    ThriftResultSetQueueFactory,
     convert_arrow_based_set_to_arrow_table,
     convert_decimals_in_arrow_table,
     convert_column_based_set_to_arrow_table,
@@ -792,7 +792,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
 
-            arrow_queue_opt = ResultSetQueueFactory.build_queue(
+            arrow_queue_opt = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=t_result_set_metadata_resp.resultFormat,
                 t_row_set=direct_results.resultSet.results,
                 arrow_schema_bytes=schema_bytes,
@@ -859,7 +859,7 @@ def get_execution_result(
         else:
             schema_bytes = None
 
-        queue = ResultSetQueueFactory.build_queue(
+        queue = ThriftResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
             arrow_schema_bytes=schema_bytes,
@@ -1207,7 +1207,7 @@ def fetch_results(
             )
         )
 
-        queue = ResultSetQueueFactory.build_queue(
+        queue = ThriftResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
             arrow_schema_bytes=arrow_schema_bytes,
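
All four call sites keep the same keyword arguments; only the factory name changes. A minimal sketch against the renamed factory: the commented call mirrors the call sites in this diff, and the `AssertionError` fallback comes from the end of `build_queue` in `utils.py`:

```python
from databricks.sql.utils import ThriftResultSetQueueFactory

# Real call sites pass the Thrift response fields, e.g.:
#   ThriftResultSetQueueFactory.build_queue(
#       row_set_type=resp.resultSetMetadata.resultFormat,
#       t_row_set=resp.results,
#       arrow_schema_bytes=schema_bytes,
#       ...
#   )
# With no row set information, every branch is skipped and the guard fires:
try:
    ThriftResultSetQueueFactory.build_queue()
except AssertionError as exc:
    assert str(exc) == "Row set type is not valid"
```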

src/databricks/sql/utils.py

Lines changed: 70 additions & 56 deletions
@@ -48,30 +48,7 @@ def remaining_rows(self):
         pass
 
 
-class JsonQueue(ResultSetQueue):
-    """Queue implementation for JSON_ARRAY format data."""
-
-    def __init__(self, data_array):
-        """Initialize with JSON array data."""
-        self.data_array = data_array
-        self.cur_row_index = 0
-        self.n_valid_rows = len(data_array)
-
-    def next_n_rows(self, num_rows):
-        """Get the next n rows from the data array."""
-        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
-        slice = self.data_array[self.cur_row_index:self.cur_row_index + length]
-        self.cur_row_index += length
-        return slice
-
-    def remaining_rows(self):
-        """Get all remaining rows from the data array."""
-        slice = self.data_array[self.cur_row_index:]
-        self.cur_row_index += len(slice)
-        return slice
-
-
-class ResultSetQueueFactory(ABC):
+class ThriftResultSetQueueFactory(ABC):
     @staticmethod
     def build_queue(
         row_set_type: Optional[TSparkRowSetType] = None,
@@ -81,57 +58,38 @@ def build_queue(
         ssl_options: Optional[SSLOptions] = None,
         lz4_compressed: bool = True,
         description: Optional[List[List[Any]]] = None,
-        # SEA specific parameters
-        sea_result_data: Optional[Any] = None,
     ) -> ResultSetQueue:
         """
-        Factory method to build a result set queue.
-
-        This method is extended to handle both Thrift and SEA result formats.
-        For SEA, the sea_result_data parameter is used instead of row_set_type and t_row_set.
-
+        Factory method to build a result set queue for Thrift backend.
+
         Args:
-            # Thrift parameters
             row_set_type (enum): Row set type (Arrow, Column, or URL).
             t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
-
-            # Common parameters
             arrow_schema_bytes (bytes): Bytes representing the arrow schema.
             lz4_compressed (bool): Whether result data has been lz4 compressed.
             description (List[List[Any]]): Hive table schema description.
             max_download_threads (int): Maximum number of downloader thread pool threads.
             ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
-
-            # SEA parameters
-            sea_result_data (ResultData): Result data from SEA response
-
+
         Returns:
             ResultSetQueue
         """
-        # Handle SEA result data
-        if sea_result_data is not None:
-            if sea_result_data.data:
-                # INLINE disposition with JSON_ARRAY format
-                return JsonQueue(sea_result_data.data)
-            elif sea_result_data.external_links:
-                # EXTERNAL_LINKS disposition (not implemented yet)
-                raise NotImplementedError(
-                    "EXTERNAL_LINKS disposition is not supported yet"
-                )
-            else:
-                # Empty result set
-                return JsonQueue([])
-
-        # Handle Thrift result data (existing implementation)
-        if row_set_type == TSparkRowSetType.ARROW_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None:
+        # Handle Thrift result data
+        if (
+            row_set_type == TSparkRowSetType.ARROW_BASED_SET
+            and t_row_set is not None
+            and arrow_schema_bytes is not None
+        ):
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
             )
             converted_arrow_table = convert_decimals_in_arrow_table(
                 arrow_table, description
             )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
-        elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET and t_row_set is not None:
+        elif (
+            row_set_type == TSparkRowSetType.COLUMN_BASED_SET and t_row_set is not None
+        ):
             column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns, description
             )
@@ -141,7 +99,13 @@ def build_queue(
             )
 
             return ColumnQueue(ColumnTable(converted_column_table, column_names))
-        elif row_set_type == TSparkRowSetType.URL_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None and max_download_threads is not None and ssl_options is not None:
+        elif (
+            row_set_type == TSparkRowSetType.URL_BASED_SET
+            and t_row_set is not None
+            and arrow_schema_bytes is not None
+            and max_download_threads is not None
+            and ssl_options is not None
+        ):
             return CloudFetchQueue(
                 schema_bytes=arrow_schema_bytes,
                 start_row_offset=t_row_set.startRowOffset,
@@ -155,6 +119,56 @@ def build_queue(
         raise AssertionError("Row set type is not valid")
 
 
+class SeaResultSetQueueFactory(ABC):
+    @staticmethod
+    def build_queue(
+        sea_result_data: Any,
+        description: Optional[List[List[Any]]] = None,
+    ) -> ResultSetQueue:
+        """
+        Factory method to build a result set queue for SEA backend.
+
+        Args:
+            sea_result_data (ResultData): Result data from SEA response
+            description (List[List[Any]]): Column descriptions
+
+        Returns:
+            ResultSetQueue: The appropriate queue for the result data
+        """
+        if sea_result_data.data:
+            # INLINE disposition with JSON_ARRAY format
+            return JsonQueue(sea_result_data.data)
+        elif sea_result_data.external_links:
+            # EXTERNAL_LINKS disposition (not implemented yet)
+            raise NotImplementedError("EXTERNAL_LINKS disposition is not supported yet")
+        else:
+            # Empty result set
+            return JsonQueue([])
+
+
+class JsonQueue(ResultSetQueue):
+    """Queue implementation for JSON_ARRAY format data."""
+
+    def __init__(self, data_array):
+        """Initialize with JSON array data."""
+        self.data_array = data_array
+        self.cur_row_index = 0
+        self.n_valid_rows = len(data_array)
+
+    def next_n_rows(self, num_rows):
+        """Get the next n rows from the data array."""
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        slice = self.data_array[self.cur_row_index : self.cur_row_index + length]
+        self.cur_row_index += length
+        return slice
+
+    def remaining_rows(self):
+        """Get all remaining rows from the data array."""
+        slice = self.data_array[self.cur_row_index :]
+        self.cur_row_index += len(slice)
+        return slice
+
+
 class ColumnTable:
     def __init__(self, column_table, column_names):
         self.column_table = column_table
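
Taken together with the relocated `JsonQueue`, the SEA factory has three branches: INLINE data, EXTERNAL_LINKS, and empty. A minimal sketch of the two branches not exercised earlier on this page, again using a hypothetical `ResultDataStub` in place of the SEA `ResultData` model:

```python
from databricks.sql.utils import SeaResultSetQueueFactory, JsonQueue


class ResultDataStub:
    """Hypothetical stand-in exposing the two fields the factory inspects."""

    def __init__(self, data=None, external_links=None):
        self.data = data
        self.external_links = external_links


# Empty result set -> JsonQueue over an empty list
empty = SeaResultSetQueueFactory.build_queue(ResultDataStub())
assert isinstance(empty, JsonQueue) and empty.remaining_rows() == []

# EXTERNAL_LINKS disposition -> explicit NotImplementedError for now
try:
    SeaResultSetQueueFactory.build_queue(ResultDataStub(external_links=["https://..."]))
except NotImplementedError:
    pass  # "EXTERNAL_LINKS disposition is not supported yet"
```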
