Commit dd7c410

add back SeaResultSet (to fix)
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 54c7f6d · commit dd7c410

File tree: 5 files changed (+270, -14 lines)

src/databricks/sql/backend/sea/backend.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
     ExecuteResponse,
 )
 from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
-from databricks.sql.backend.sea.utils.http_client import CustomHttpClient
+from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import SSLOptions
 from databricks.sql.utils import SeaResultSetQueueFactory

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 3 deletions

@@ -40,7 +40,6 @@
 )
 
 from databricks.sql.utils import (
-    ResultSetQueueFactory,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
@@ -1219,9 +1218,9 @@ def fetch_results(
             )
         )
 
-        from databricks.sql.utils import ResultSetQueueFactory
+        from databricks.sql.utils import ThriftResultSetQueueFactory
 
-        queue = ResultSetQueueFactory.build_queue(
+        queue = ThriftResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
             arrow_schema_bytes=arrow_schema_bytes,
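
Editor's note: the rename from ResultSetQueueFactory to ThriftResultSetQueueFactory separates Thrift-protocol queue construction from the new SEA path. The sketch below illustrates the dispatch such a factory performs; the RowSetType enum and the ArrowQueue/ColumnQueue stand-ins are assumptions for the example, not the driver's actual classes.

    # Sketch only: how a format-dispatching build_queue can be structured.
    # RowSetType, ArrowQueue, and ColumnQueue are illustrative stand-ins.
    from enum import Enum


    class RowSetType(Enum):
        ARROW_BASED_SET = 1
        COLUMN_BASED_SET = 2


    class ArrowQueue:
        def __init__(self, arrow_schema_bytes, t_row_set):
            self.arrow_schema_bytes = arrow_schema_bytes
            self.t_row_set = t_row_set


    class ColumnQueue:
        def __init__(self, t_row_set):
            self.t_row_set = t_row_set


    class ThriftResultSetQueueFactory:
        @staticmethod
        def build_queue(row_set_type, t_row_set, arrow_schema_bytes=b""):
            # Choose a queue implementation based on the server's result format.
            if row_set_type is RowSetType.ARROW_BASED_SET:
                return ArrowQueue(arrow_schema_bytes, t_row_set)
            if row_set_type is RowSetType.COLUMN_BASED_SET:
                return ColumnQueue(t_row_set)
            raise ValueError(f"Unsupported row set type: {row_set_type}")


    queue = ThriftResultSetQueueFactory.build_queue(
        RowSetType.COLUMN_BASED_SET, t_row_set=[[1, "a"], [2, "b"]]
    )
    print(type(queue).__name__)  # ColumnQueue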

src/databricks/sql/result_set.py

Lines changed: 260 additions & 3 deletions

@@ -6,6 +6,7 @@
 import pandas
 
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.cloud_fetch_queue import SeaCloudFetchQueue
 
 try:
     import pyarrow
@@ -19,7 +20,7 @@
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import Row
 from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
-from databricks.sql.utils import ColumnTable, ColumnQueue
+from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
 
 logger = logging.getLogger(__name__)
@@ -183,10 +184,10 @@ def __init__(
         # Build the results queue if t_row_set is provided
         results_queue = None
         if t_row_set and execute_response.result_format is not None:
-            from databricks.sql.utils import ResultSetQueueFactory
+            from databricks.sql.utils import ThriftResultSetQueueFactory
 
             # Create the results queue using the provided format
-            results_queue = ResultSetQueueFactory.build_queue(
+            results_queue = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=execute_response.result_format,
                 t_row_set=t_row_set,
                 arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
@@ -436,3 +437,259 @@ def map_col_type(type_):
         (column.name, map_col_type(column.datatype), None, None, None, None, None)
         for column in table_schema_message.columns
     ]
+
+class SeaResultSet(ResultSet):
+    """ResultSet implementation for SEA backend."""
+
+    def __init__(
+        self,
+        connection: "Connection",
+        execute_response: "ExecuteResponse",
+        sea_client: "SeaDatabricksClient",
+        buffer_size_bytes: int = 104857600,
+        arraysize: int = 10000,
+    ):
+        """
+        Initialize a SeaResultSet with the response from a SEA query execution.
+
+        Args:
+            connection: The parent connection
+            execute_response: Response from the execute command
+            sea_client: The SeaDatabricksClient instance for direct access
+            buffer_size_bytes: Buffer size for fetching results
+            arraysize: Default number of rows to fetch
+        """
+        # Extract and store SEA-specific properties
+        self.statement_id = (
+            execute_response.command_id.to_sea_statement_id()
+            if execute_response.command_id
+            else None
+        )
+
+        # Call parent constructor with common attributes
+        super().__init__(
+            connection=connection,
+            backend=sea_client,
+            arraysize=arraysize,
+            buffer_size_bytes=buffer_size_bytes,
+            command_id=execute_response.command_id,
+            status=execute_response.status,
+            has_been_closed_server_side=execute_response.has_been_closed_server_side,
+            has_more_rows=execute_response.has_more_rows,
+            results_queue=execute_response.results_queue,
+            description=execute_response.description,
+            is_staging_operation=execute_response.is_staging_operation,
+        )
+
+        # Initialize queue for result data if not provided
+        if not self.results:
+            self.results = JsonQueue([])
+
+    def _convert_to_row_objects(self, rows):
+        """
+        Convert raw data rows to Row objects with named columns based on description.
+
+        Args:
+            rows: List of raw data rows
+
+        Returns:
+            List of Row objects with named columns
+        """
+        if not self.description or not rows:
+            return rows
+
+        column_names = [col[0] for col in self.description]
+        ResultRow = Row(*column_names)
+        return [ResultRow(*row) for row in rows]
+
+    def _fill_results_buffer(self):
+        """Fill the results buffer from the backend."""
+        # For INLINE disposition, we already have all the data
+        # No need to fetch more data from the backend
+        self._has_more_rows = False
+
+    def _convert_rows_to_arrow_table(self, rows):
+        """Convert rows to Arrow table."""
+        if not self.description:
+            return pyarrow.Table.from_pylist([])
+
+        # Create dict of column data
+        column_data = {}
+        column_names = [col[0] for col in self.description]
+
+        for i, name in enumerate(column_names):
+            column_data[name] = [row[i] for row in rows]
+
+        return pyarrow.Table.from_pydict(column_data)
+
+    def _create_empty_arrow_table(self):
+        """Create an empty Arrow table with the correct schema."""
+        if not self.description:
+            return pyarrow.Table.from_pylist([])
+
+        column_names = [col[0] for col in self.description]
+        return pyarrow.Table.from_pydict({name: [] for name in column_names})
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+        # Note: We check for the specific queue type to maintain consistency with ThriftResultSet
+        # This pattern is maintained from the existing code
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.next_n_rows(1)
+            if not rows:
+                return None
+
+            # Convert to Row object
+            converted_rows = self._convert_to_row_objects(rows)
+            return converted_rows[0] if converted_rows else None
+        elif isinstance(self.results, SeaCloudFetchQueue):
+            # For ARROW format with EXTERNAL_LINKS disposition
+            arrow_table = self.results.next_n_rows(1)
+            if arrow_table.num_rows == 0:
+                return None
+
+            # Convert Arrow table to Row object
+            column_names = [col[0] for col in self.description]
+            ResultRow = Row(*column_names)
+
+            # Get the first row as a list of values
+            row_values = [
+                arrow_table.column(i)[0].as_py() for i in range(arrow_table.num_columns)
+            ]
+
+            # Increment the row index
+            self._next_row_index += 1
+
+            return ResultRow(*row_values)
+        else:
+            # This should not happen with current implementation
+            raise NotImplementedError("Unsupported queue type")
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        An empty sequence is returned when no more rows are available.
+        """
+        if size is None:
+            size = self.arraysize
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        # Note: We check for the specific queue type to maintain consistency with ThriftResultSet
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.next_n_rows(size)
+            self._next_row_index += len(rows)
+
+            # Convert to Row objects
+            return self._convert_to_row_objects(rows)
+        elif isinstance(self.results, SeaCloudFetchQueue):
+            # For ARROW format with EXTERNAL_LINKS disposition
+            arrow_table = self.results.next_n_rows(size)
+            if arrow_table.num_rows == 0:
+                return []
+
+            # Convert Arrow table to Row objects
+            column_names = [col[0] for col in self.description]
+            ResultRow = Row(*column_names)
+
+            # Convert each row to a Row object
+            result_rows = []
+            for i in range(arrow_table.num_rows):
+                row_values = [
+                    arrow_table.column(j)[i].as_py()
+                    for j in range(arrow_table.num_columns)
+                ]
+                result_rows.append(ResultRow(*row_values))
+
+            # Increment the row index
+            self._next_row_index += arrow_table.num_rows
+
+            return result_rows
+        else:
+            # This should not happen with current implementation
+            raise NotImplementedError("Unsupported queue type")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of rows.
+        """
+        # Note: We check for the specific queue type to maintain consistency with ThriftResultSet
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.remaining_rows()
+            self._next_row_index += len(rows)
+
+            # Convert to Row objects
+            return self._convert_to_row_objects(rows)
+        elif isinstance(self.results, SeaCloudFetchQueue):
+            # For ARROW format with EXTERNAL_LINKS disposition
+            arrow_table = self.results.remaining_rows()
+            if arrow_table.num_rows == 0:
+                return []
+
+            # Convert Arrow table to Row objects
+            column_names = [col[0] for col in self.description]
+            ResultRow = Row(*column_names)
+
+            # Convert each row to a Row object
+            result_rows = []
+            for i in range(arrow_table.num_rows):
+                row_values = [
+                    arrow_table.column(j)[i].as_py()
+                    for j in range(arrow_table.num_columns)
+                ]
+                result_rows.append(ResultRow(*row_values))
+
+            # Increment the row index
+            self._next_row_index += arrow_table.num_rows
+
+            return result_rows
+        else:
+            # This should not happen with current implementation
+            raise NotImplementedError("Unsupported queue type")
+
+    def fetchmany_arrow(self, size: int) -> Any:
+        """Fetch the next set of rows as an Arrow table."""
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow support")
+
+        if isinstance(self.results, JsonQueue):
+            rows = self.fetchmany(size)
+            if not rows:
+                # Return empty Arrow table with schema
+                return self._create_empty_arrow_table()
+
+            # Convert rows to Arrow table
+            return self._convert_rows_to_arrow_table(rows)
+        elif isinstance(self.results, SeaCloudFetchQueue):
+            # For ARROW format with EXTERNAL_LINKS disposition
+            arrow_table = self.results.next_n_rows(size)
+            self._next_row_index += arrow_table.num_rows
+            return arrow_table
+        else:
+            raise NotImplementedError("Unsupported queue type")
+
+    def fetchall_arrow(self) -> Any:
+        """Fetch all remaining rows as an Arrow table."""
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow support")
+
+        if isinstance(self.results, JsonQueue):
+            rows = self.fetchall()
+            if not rows:
+                # Return empty Arrow table with schema
+                return self._create_empty_arrow_table()
+
+            # Convert rows to Arrow table
+            return self._convert_rows_to_arrow_table(rows)
+        elif isinstance(self.results, SeaCloudFetchQueue):
+            # For ARROW format with EXTERNAL_LINKS disposition
+            arrow_table = self.results.remaining_rows()
+            self._next_row_index += arrow_table.num_rows
+            return arrow_table
+        else:
+            raise NotImplementedError("Unsupported queue type")
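
Editor's note: for the INLINE/JSON path above, the fetch methods pull batches from a JsonQueue and name the columns from the cursor description. Below is a self-contained sketch of that flow, with a minimal JsonQueue stand-in and namedtuple in place of the driver's Row factory; both are assumptions for the example, not the real classes.

    # Standalone sketch of the JsonQueue fetch path; JsonQueue here mimics the
    # next_n_rows/remaining_rows interface used by SeaResultSet above.
    from collections import namedtuple


    class JsonQueue:
        def __init__(self, rows):
            self._rows = rows
            self._pos = 0

        def next_n_rows(self, n):
            # Serve the next batch of pre-materialized rows.
            batch = self._rows[self._pos : self._pos + n]
            self._pos += len(batch)
            return batch

        def remaining_rows(self):
            # Serve everything that has not been fetched yet.
            batch = self._rows[self._pos :]
            self._pos = len(self._rows)
            return batch


    def convert_to_row_objects(description, rows):
        # Mirrors _convert_to_row_objects: field names come from the description.
        column_names = [col[0] for col in description]
        ResultRow = namedtuple("Row", column_names)
        return [ResultRow(*row) for row in rows]


    description = [("id", "int"), ("name", "string")]
    queue = JsonQueue([[1, "alice"], [2, "bob"], [3, "carol"]])

    first = convert_to_row_objects(description, queue.next_n_rows(1))[0]
    rest = convert_to_row_objects(description, queue.remaining_rows())
    print(first.name, [r.id for r in rest])  # alice [2, 3]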

tests/unit/test_result_set_queue_factories.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 """
-Tests for the ResultSetQueueFactory classes.
+Tests for the ThriftResultSetQueueFactory classes.
 """
 
 import unittest

tests/unit/test_thrift_backend.py

Lines changed: 6 additions & 6 deletions

@@ -1168,7 +1168,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
     )
     def test_execute_statement_calls_client_and_handle_execute_response(
         self, mock_build_queue, tcli_service_class
@@ -1208,7 +1208,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
     )
     def test_get_catalogs_calls_client_and_handle_execute_response(
         self, mock_build_queue, tcli_service_class
@@ -1245,7 +1245,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
    )
     def test_get_schemas_calls_client_and_handle_execute_response(
         self, mock_build_queue, tcli_service_class
@@ -1291,7 +1291,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
     )
     def test_get_tables_calls_client_and_handle_execute_response(
         self, mock_build_queue, tcli_service_class
@@ -1341,7 +1341,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
     )
     def test_get_columns_calls_client_and_handle_execute_response(
         self, mock_build_queue, tcli_service_class
@@ -2265,7 +2265,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     @patch(
-        "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
+        "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", return_value=Mock()
     )
     @patch(
         "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response",
