
Commit f0d9c65

init fetch phase JSON + INLINE
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 15a8efc commit f0d9c65

4 files changed: +425 -52 lines changed


examples/experimental/sea_connector_test.py

Lines changed: 94 additions & 1 deletion
@@ -7,6 +7,96 @@
 logger = logging.getLogger(__name__)
 
 
+def test_sea_result_set_json_array_inline():
+    """
+    Test the SEA result set implementation with JSON_ARRAY format and INLINE disposition.
+
+    This function connects to a Databricks SQL endpoint using the SEA backend,
+    executes a query that returns a small result set (which will use INLINE disposition),
+    and tests the various fetch methods to verify the result set implementation works correctly.
+    """
+    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
+    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
+    access_token = os.environ.get("DATABRICKS_TOKEN")
+    catalog = os.environ.get("DATABRICKS_CATALOG")
+
+    if not all([server_hostname, http_path, access_token]):
+        logger.error("Missing required environment variables.")
+        logger.error(
+            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
+        )
+        sys.exit(1)
+
+    try:
+        # Create connection with SEA backend
+        logger.info("Creating connection with SEA backend...")
+        connection = Connection(
+            server_hostname=server_hostname,
+            http_path=http_path,
+            access_token=access_token,
+            catalog=catalog,
+            schema="default",
+            use_sea=True,
+            user_agent_entry="SEA-Test-Client",
+        )
+
+        logger.info(
+            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        )
+
+        # Create cursor
+        cursor = connection.cursor()
+
+        # Execute a query that returns a small result set (will use INLINE disposition)
+        logger.info("Executing query: SELECT * FROM range(1, 10) AS id")
+        cursor.execute("SELECT * FROM range(1, 10) AS id")
+
+        # Test fetchone
+        logger.info("Testing fetchone...")
+        row = cursor.fetchone()
+        logger.info(f"First row: {row}")
+
+        # Test fetchmany
+        logger.info("Testing fetchmany(3)...")
+        rows = cursor.fetchmany(3)
+        logger.info(f"Next 3 rows: {rows}")
+
+        # Test fetchall
+        logger.info("Testing fetchall...")
+        remaining_rows = cursor.fetchall()
+        logger.info(f"Remaining rows: {remaining_rows}")
+
+        # Execute another query to test arrow fetch methods
+        logger.info("Executing query for Arrow testing: SELECT * FROM range(1, 5) AS id, range(101, 105) AS value")
+        cursor.execute("SELECT * FROM range(1, 5) AS id, range(101, 105) AS value")
+
+        # Test fetchmany_arrow
+        logger.info("Testing fetchmany_arrow(2)...")
+        arrow_batch = cursor.fetchmany_arrow(2)
+        logger.info(f"Arrow batch num rows: {arrow_batch.num_rows}")
+        logger.info(f"Arrow batch columns: {arrow_batch.column_names}")
+        logger.info(f"Arrow batch data: {arrow_batch.to_pydict()}")
+
+        # Test fetchall_arrow
+        logger.info("Testing fetchall_arrow...")
+        remaining_arrow_batch = cursor.fetchall_arrow()
+        logger.info(f"Remaining arrow batch num rows: {remaining_arrow_batch.num_rows}")
+        logger.info(f"Remaining arrow batch data: {remaining_arrow_batch.to_pydict()}")
+
+        # Close cursor and connection
+        cursor.close()
+        connection.close()
+        logger.info("Successfully closed SEA session")
+
+    except Exception as e:
+        logger.error(f"Error during SEA result set test: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        sys.exit(1)
+
+    logger.info("SEA result set test with JSON_ARRAY format and INLINE disposition completed successfully")
+
+
 def test_sea_query_execution_with_compression():
     """
     Test executing a query using the SEA backend with result compression.
@@ -159,4 +249,7 @@ def test_sea_session():
     test_sea_session()
 
     # Test query execution with compression
-    test_sea_query_execution_with_compression()
+    test_sea_query_execution_with_compression()
+
+    # Test result set implementation
+    test_sea_result_set_json_array_inline()
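A minimal setup sketch for the new test: the hostname, HTTP path, and token values below are placeholders (not values from this commit), and calling the test directly assumes the example module above has been imported.

    # Hypothetical environment setup for test_sea_result_set_json_array_inline();
    # replace the placeholder values with real workspace credentials.
    import os

    os.environ["DATABRICKS_SERVER_HOSTNAME"] = "your-workspace.cloud.databricks.com"
    os.environ["DATABRICKS_HTTP_PATH"] = "/sql/1.0/warehouses/<warehouse-id>"
    os.environ["DATABRICKS_TOKEN"] = "<personal-access-token>"
    os.environ["DATABRICKS_CATALOG"] = "main"  # optional; the test passes it through

    # With range(1, 10) being end-exclusive (9 rows), the fetch sequence in the test
    # is expected to return 1 row (fetchone), 3 rows (fetchmany(3)), then 5 rows (fetchall).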

src/databricks/sql/backend/sea_result_set.py

Lines changed: 119 additions & 16 deletions
@@ -7,12 +7,18 @@
 
 import json
 import logging
-from typing import Optional, List, Any, Dict, Tuple
+from typing import Optional, List, Any, Dict, Tuple, cast
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 
 from databricks.sql.result_set import ResultSet
 from databricks.sql.types import Row
 from databricks.sql.backend.types import CommandId, CommandState
 from databricks.sql.exc import Error
+from databricks.sql.utils import ResultSetQueueFactory, JsonQueue
 
 from databricks.sql.backend.models import (
     StatementStatus,
@@ -106,14 +112,19 @@ def __init__(
 
         # Initialize other properties
        self._is_staging_operation = False  # SEA doesn't have staging operations
-        self._rows_buffer = []
-        self._current_row_index = 0
-        self._has_more_rows = True
+        self._has_more_rows = False
         self._current_chunk_index = 0
 
-        # If we have inline data, fill the buffer
-        if self.result and self.result.data:
-            self._rows_buffer = self.result.data
+        # Initialize queue for result data
+        if self.result:
+            self.results = ResultSetQueueFactory.build_queue(
+                sea_result_data=self.result,
+                description=cast(Optional[List[List[Any]]], self.description)
+            )
+            self._has_more_rows = True if self.result.data else False
+        else:
+            self.results = JsonQueue([])
+            self._has_more_rows = False
 
     @property
     def is_staging_operation(self) -> bool:
@@ -153,28 +164,120 @@ def _extract_description_from_manifest(
         return description
 
     def _fill_results_buffer(self) -> None:
-        """Fill the results buffer from the backend."""
-        raise NotImplementedError("Not implemented yet")
+        """Fill the results buffer from the backend for INLINE disposition."""
+        if not self.result or not self.result.data:
+            self._has_more_rows = False
+            return
+
+        # For INLINE disposition, we already have all the data
+        # No need to fetch more data from the backend
+        self._has_more_rows = False  # No more rows to fetch for INLINE
+
+    def _convert_rows_to_arrow_table(self, rows):
+        """Convert rows to Arrow table."""
+        if not self.description:
+            return pyarrow.Table.from_pylist([])
+
+        # Create dict of column data
+        column_data = {}
+        column_names = [col[0] for col in self.description]
+
+        for i, name in enumerate(column_names):
+            column_data[name] = [row[i] for row in rows]
+
+        return pyarrow.Table.from_pydict(column_data)
+
+    def _create_empty_arrow_table(self):
+        """Create an empty Arrow table with the correct schema."""
+        if not self.description:
+            return pyarrow.Table.from_pylist([])
+
+        column_names = [col[0] for col in self.description]
+        return pyarrow.Table.from_pydict({name: [] for name in column_names})
 
     def fetchone(self) -> Optional[Row]:
         """Fetch the next row of a query result set."""
-        raise NotImplementedError("Not implemented yet")
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.next_n_rows(1)
+            if not rows:
+                return None
+
+            row = rows[0]
+
+            # Convert to Row object
+            if self.description:
+                column_names = [col[0] for col in self.description]
+                ResultRow = Row(*column_names)
+                return ResultRow(*row)
+            return row
+        else:
+            # This should not happen with current implementation
+            # but added for future compatibility
+            raise NotImplementedError("Unsupported queue type")
 
-    def fetchmany(self, size: int) -> List[Row]:
+    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
         """Fetch the next set of rows of a query result."""
-        raise NotImplementedError("Not implemented yet")
+        if size is None:
+            size = self.arraysize
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.next_n_rows(size)
+
+            # Convert to Row objects
+            if self.description:
+                column_names = [col[0] for col in self.description]
+                ResultRow = Row(*column_names)
+                return [ResultRow(*row) for row in rows]
+            return rows
+        else:
+            # This should not happen with current implementation
+            # but added for future compatibility
+            raise NotImplementedError("Unsupported queue type")
 
     def fetchall(self) -> List[Row]:
         """Fetch all remaining rows of a query result."""
-        raise NotImplementedError("Not implemented yet")
+        if isinstance(self.results, JsonQueue):
+            rows = self.results.remaining_rows()
+
+            # Convert to Row objects
+            if self.description:
+                column_names = [col[0] for col in self.description]
+                ResultRow = Row(*column_names)
+                return [ResultRow(*row) for row in rows]
+            return rows
+        else:
+            # This should not happen with current implementation
+            # but added for future compatibility
+            raise NotImplementedError("Unsupported queue type")
 
     def fetchmany_arrow(self, size: int) -> Any:
         """Fetch the next set of rows as an Arrow table."""
-        raise NotImplementedError("Not implemented yet")
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow support")
+
+        rows = self.fetchmany(size)
+        if not rows:
+            # Return empty Arrow table with schema
+            return self._create_empty_arrow_table()
+
+        # Convert rows to Arrow table
+        return self._convert_rows_to_arrow_table(rows)
 
     def fetchall_arrow(self) -> Any:
         """Fetch all remaining rows as an Arrow table."""
-        raise NotImplementedError("Not implemented yet")
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow support")
+
+        rows = self.fetchall()
+        if not rows:
+            # Return empty Arrow table with schema
+            return self._create_empty_arrow_table()
+
+        # Convert rows to Arrow table
+        return self._convert_rows_to_arrow_table(rows)
 
     def close(self) -> None:
         """Close the result set and release any resources."""
@@ -185,4 +288,4 @@ def close(self) -> None:
                 CommandId.from_sea_statement_id(self.statement_id)
             )
         except Exception as e:
-            logger.warning(f"Error closing SEA statement: {e}")
+            logger.warning(f"Error closing SEA statement: {e}")
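For reference, a standalone sketch of the row-to-column pivot that the new _convert_rows_to_arrow_table helper performs; the description tuples and row values here are made-up illustrations, not data from the commit.

    # Illustrative input: a two-column description and three JSON_ARRAY rows.
    import pyarrow

    description = [
        ("id", "int", None, None, None, None, None),
        ("value", "int", None, None, None, None, None),
    ]
    rows = [[1, 101], [2, 102], [3, 103]]

    # Pivot row-oriented data into a dict of columns keyed by column name,
    # mirroring the loop in _convert_rows_to_arrow_table.
    column_names = [col[0] for col in description]
    column_data = {name: [row[i] for row in rows] for i, name in enumerate(column_names)}

    table = pyarrow.Table.from_pydict(column_data)
    print(table.num_rows)      # 3
    print(table.column_names)  # ['id', 'value']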

src/databricks/sql/utils.py

Lines changed: 59 additions & 10 deletions
@@ -48,41 +48,90 @@ def remaining_rows(self):
         pass
 
 
+class JsonQueue(ResultSetQueue):
+    """Queue implementation for JSON_ARRAY format data."""
+
+    def __init__(self, data_array):
+        """Initialize with JSON array data."""
+        self.data_array = data_array
+        self.cur_row_index = 0
+        self.n_valid_rows = len(data_array)
+
+    def next_n_rows(self, num_rows):
+        """Get the next n rows from the data array."""
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        slice = self.data_array[self.cur_row_index:self.cur_row_index + length]
+        self.cur_row_index += length
+        return slice
+
+    def remaining_rows(self):
+        """Get all remaining rows from the data array."""
+        slice = self.data_array[self.cur_row_index:]
+        self.cur_row_index += len(slice)
+        return slice
+
+
 class ResultSetQueueFactory(ABC):
     @staticmethod
     def build_queue(
-        row_set_type: TSparkRowSetType,
-        t_row_set: TRowSet,
-        arrow_schema_bytes: bytes,
-        max_download_threads: int,
-        ssl_options: SSLOptions,
+        row_set_type: Optional[TSparkRowSetType] = None,
+        t_row_set: Optional[TRowSet] = None,
+        arrow_schema_bytes: Optional[bytes] = None,
+        max_download_threads: Optional[int] = None,
+        ssl_options: Optional[SSLOptions] = None,
         lz4_compressed: bool = True,
         description: Optional[List[List[Any]]] = None,
+        # SEA specific parameters
+        sea_result_data: Optional[Any] = None,
     ) -> ResultSetQueue:
         """
         Factory method to build a result set queue.
-
+
+        This method is extended to handle both Thrift and SEA result formats.
+        For SEA, the sea_result_data parameter is used instead of row_set_type and t_row_set.
+
         Args:
+            # Thrift parameters
             row_set_type (enum): Row set type (Arrow, Column, or URL).
             t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
+
+            # Common parameters
             arrow_schema_bytes (bytes): Bytes representing the arrow schema.
             lz4_compressed (bool): Whether result data has been lz4 compressed.
             description (List[List[Any]]): Hive table schema description.
             max_download_threads (int): Maximum number of downloader thread pool threads.
             ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
-
+
+            # SEA parameters
+            sea_result_data (ResultData): Result data from SEA response
+
         Returns:
             ResultSetQueue
         """
-        if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
+        # Handle SEA result data
+        if sea_result_data is not None:
+            if sea_result_data.data:
+                # INLINE disposition with JSON_ARRAY format
+                return JsonQueue(sea_result_data.data)
+            elif sea_result_data.external_links:
+                # EXTERNAL_LINKS disposition (not implemented yet)
+                raise NotImplementedError(
+                    "EXTERNAL_LINKS disposition is not supported yet"
+                )
+            else:
+                # Empty result set
+                return JsonQueue([])
+
+        # Handle Thrift result data (existing implementation)
+        if row_set_type == TSparkRowSetType.ARROW_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
             )
             converted_arrow_table = convert_decimals_in_arrow_table(
                 arrow_table, description
             )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
-        elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
+        elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET and t_row_set is not None:
             column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns, description
             )
@@ -92,7 +141,7 @@ def build_queue(
             )
 
             return ColumnQueue(ColumnTable(converted_column_table, column_names))
-        elif row_set_type == TSparkRowSetType.URL_BASED_SET:
+        elif row_set_type == TSparkRowSetType.URL_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None and max_download_threads is not None and ssl_options is not None:
             return CloudFetchQueue(
                 schema_bytes=arrow_schema_bytes,
                 start_row_offset=t_row_set.startRowOffset,
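A quick sketch of the JsonQueue slicing behaviour added above, assuming a connector build that includes this commit so the class is importable from databricks.sql.utils:

    from databricks.sql.utils import JsonQueue

    # Five single-column rows, as SEA would return them inline as a JSON array.
    queue = JsonQueue([[1], [2], [3], [4], [5]])

    print(queue.next_n_rows(2))    # [[1], [2]]
    print(queue.next_n_rows(10))   # [[3], [4], [5]] -- request is clamped to what is left
    print(queue.remaining_rows())  # [] -- the row index is already at the end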
