Commit 7257168

Merge branch 'sea-hybrid' into sea-decouple-link-fetch
2 parents 4abb3ad + 671dbca commit 7257168

File tree: 3 files changed (+293 −30 lines)

src/databricks/sql/backend/sea/queue.py

Lines changed: 55 additions & 23 deletions
@@ -6,7 +6,7 @@
 
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
 
-import lz4.frame
+from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler
 
 try:
     import pyarrow
@@ -37,25 +37,6 @@
 logger = logging.getLogger(__name__)
 
 
-def decompress_multi_frame_lz4(attachment: bytes) -> bytes:
-    try:
-        decompressor = lz4.frame.LZ4FrameDecompressor()
-        arrow_file = decompressor.decompress(attachment)
-
-        # the attachment may be a concatenation of multiple LZ4 frames
-        while decompressor.unused_data:
-            remaining_data = decompressor.unused_data
-            arrow_file += decompressor.decompress(remaining_data)
-
-        logger.debug(f"LZ4 decompressed {len(arrow_file)} bytes from attachment")
-
-    except Exception as e:
-        logger.error(f"LZ4 decompression failed: {e}")
-        raise e
-
-    return arrow_file
-
-
 class SeaResultSetQueueFactory(ABC):
     @staticmethod
     def build_queue(
@@ -90,7 +71,7 @@ def build_queue(
         elif manifest.format == ResultFormat.ARROW_STREAM.value:
            if result_data.attachment is not None:
                arrow_file = (
-                    decompress_multi_frame_lz4(result_data.attachment)
+                    ResultSetDownloadHandler._decompress_data(result_data.attachment)
                    if lz4_compressed
                    else result_data.attachment
                )
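These two hunks replace the module-local decompress_multi_frame_lz4 with the downloader's shared ResultSetDownloadHandler._decompress_data, so inline attachments and cloud-fetch downloads now decompress through one code path. For reference, a standalone equivalent of the multi-frame loop the removed helper performed; this sketch uses a fresh LZ4FrameDecompressor per frame (a slight variation on the original) and assumes the lz4 package is installed:

import lz4.frame


def decompress_multi_frame_lz4(attachment: bytes) -> bytes:
    """Decompress a payload that may concatenate several LZ4 frames."""
    out = b""
    remaining = attachment
    while remaining:
        # A fresh decompressor per frame; unused_data then holds any
        # bytes belonging to the frames still to be decoded.
        decompressor = lz4.frame.LZ4FrameDecompressor()
        out += decompressor.decompress(remaining)
        remaining = decompressor.unused_data
    return out


# Round-trip two independently compressed frames:
payload = lz4.frame.compress(b"hello ") + lz4.frame.compress(b"world")
assert decompress_multi_frame_lz4(payload) == b"hello world"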
@@ -300,10 +281,57 @@ def __init__(
         self.link_fetcher.start()
 
         # Initialize table and position
-        self.table = self._create_next_table()
+        self.table = self._create_table_from_link(self._current_chunk_link)
+
+    def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
+        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
+        # Parse the ISO format expiration time
+        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
+        return TSparkArrowResultLink(
+            fileLink=link.external_link,
+            expiryTime=expiry_time,
+            rowCount=link.row_count,
+            bytesNum=link.byte_count,
+            startRowOffset=link.row_offset,
+            httpHeaders=link.http_headers or {},
+        )
+
+    def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
+        if chunk_index not in self._chunk_index_to_link:
+            links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
+            self._chunk_index_to_link.update({link.chunk_index: link for link in links})
+        return self._chunk_index_to_link.get(chunk_index, None)
+
+    def _progress_chunk_link(self):
+        """Progress to the next chunk link."""
+        if not self._current_chunk_link:
+            return None
+
+        next_chunk_index = self._current_chunk_link.next_chunk_index
+
+        if next_chunk_index is None:
+            self._current_chunk_link = None
+            return None
+
+        self._current_chunk_link = self._get_chunk_link(next_chunk_index)
+        if not self._current_chunk_link:
+            logger.error(
+                "SeaCloudFetchQueue: unable to retrieve link for chunk {}".format(
+                    next_chunk_index
+                )
+            )
+            return None
+
+        logger.debug(
+            f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
+        )
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
+        if not self._current_chunk_link:
+            logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
+            return None
+
         if not self.download_manager:
             logger.debug("SeaCloudFetchQueue: No download manager, returning")
             return None
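_convert_to_thrift_link bridges one representation detail worth spelling out: SEA reports a link's expiration as an ISO-8601 string, while TSparkArrowResultLink.expiryTime expects epoch seconds. A quick illustration of the dateutil conversion used above (the timestamp is an example value):

import dateutil.parser

expiration = "2025-01-01T12:00:00Z"  # example SEA-style ISO-8601 value
expiry_time = int(dateutil.parser.parse(expiration).timestamp())
print(expiry_time)  # 1735732800, seconds since the Unix epoch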
@@ -317,4 +345,8 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
 
         self.current_chunk_index += 1
 
-        return arrow_table
+        if not self._current_chunk_link:
+            logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
+            return None
+
+        return self._create_table_from_link(self._current_chunk_link)
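Taken together, the new methods walk SEA's result chunks as a lazily fetched linked list: _get_chunk_link calls the server only on a cache miss (a single get_chunk_links response may fill several cache entries), and _progress_chunk_link follows each link's next_chunk_index until it is None. A minimal standalone sketch of that traversal pattern, with ChunkLink and the fetch callback as hypothetical stand-ins for the SEA client types:

from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Optional


@dataclass
class ChunkLink:
    chunk_index: int
    external_link: str
    next_chunk_index: Optional[int]  # None marks the final chunk


class ChunkWalker:
    """Lazy, cached traversal of a linked list of result-chunk links."""

    def __init__(self, fetch_links: Callable[[int], List[ChunkLink]]):
        self._fetch_links = fetch_links  # stand-in for the SEA client call
        self._cache: Dict[int, ChunkLink] = {}

    def get_link(self, chunk_index: int) -> Optional[ChunkLink]:
        # Hit the server only on a cache miss; one response may carry
        # several links, so later lookups are satisfied locally.
        if chunk_index not in self._cache:
            for link in self._fetch_links(chunk_index):
                self._cache[link.chunk_index] = link
        return self._cache.get(chunk_index)

    def walk(self) -> Iterator[ChunkLink]:
        index: Optional[int] = 0
        while index is not None:
            link = self.get_link(index)
            if link is None:
                break  # the link could not be retrieved
            yield link
            index = link.next_chunk_index


# Example with a fake two-chunk server response:
links = [
    ChunkLink(0, "https://example.com/chunk0", 1),
    ChunkLink(1, "https://example.com/chunk1", None),
]
walker = ChunkWalker(lambda i: [l for l in links if l.chunk_index >= i])
assert [l.chunk_index for l in walker.walk()] == [0, 1]

The cache matters because the server may return links for several chunks per request; keeping them avoids one round trip per chunk as the queue advances.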

tests/e2e/common/large_queries_mixin.py

Lines changed: 29 additions & 6 deletions
@@ -2,6 +2,8 @@
 import math
 import time
 
+import pytest
+
 log = logging.getLogger(__name__)
 
 
@@ -42,7 +44,14 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
             + "assuming 10K fetch size."
         )
 
-    def test_query_with_large_wide_result_set(self):
+    @pytest.mark.parametrize(
+        "extra_params",
+        [
+            {},
+            {"use_sea": True},
+        ],
+    )
+    def test_query_with_large_wide_result_set(self, extra_params):
         resultSize = 300 * 1000 * 1000  # 300 MB
         width = 8192  # B
         rows = resultSize // width
@@ -52,7 +61,7 @@ def test_query_with_large_wide_result_set(self):
         fetchmany_size = 10 * 1024 * 1024 // width
         # This is used by PyHive tests to determine the buffer size
         self.arraysize = 1000
-        with self.cursor() as cursor:
+        with self.cursor(extra_params) as cursor:
             for lz4_compression in [False, True]:
                 cursor.connection.lz4_compression = lz4_compression
                 uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
@@ -68,7 +77,14 @@
                 assert row[0] == row_id  # Verify no rows are dropped in the middle.
                 assert len(row[1]) == 36
 
-    def test_query_with_large_narrow_result_set(self):
+    @pytest.mark.parametrize(
+        "extra_params",
+        [
+            {},
+            {"use_sea": True},
+        ],
+    )
+    def test_query_with_large_narrow_result_set(self, extra_params):
         resultSize = 300 * 1000 * 1000  # 300 MB
         width = 8  # sizeof(long)
         rows = resultSize / width
@@ -77,12 +93,19 @@ def test_query_with_large_narrow_result_set(self):
         fetchmany_size = 10 * 1024 * 1024 // width
         # This is used by PyHive tests to determine the buffer size
         self.arraysize = 10000000
-        with self.cursor() as cursor:
+        with self.cursor(extra_params) as cursor:
             cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
             for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                 assert row[0] == row_id
 
-    def test_long_running_query(self):
+    @pytest.mark.parametrize(
+        "extra_params",
+        [
+            {},
+            {"use_sea": True},
+        ],
+    )
+    def test_long_running_query(self, extra_params):
         """Incrementally increase query size until it takes at least 3 minutes,
         and asserts that the query completes successfully.
         """
@@ -92,7 +115,7 @@ def test_long_running_query(self):
         duration = -1
         scale0 = 10000
         scale_factor = 1
-        with self.cursor() as cursor:
+        with self.cursor(extra_params) as cursor:
             while duration < min_duration:
                 assert scale_factor < 1024, "Detected infinite loop"
                 start = time.time()
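Each of the three tests is now parametrized over extra_params, so it runs once against the default Thrift path ({}) and once with use_sea enabled; this assumes the mixin's cursor() helper forwards the dict into the connection settings. A tiny self-contained illustration of the same pattern (open_connection is a made-up stand-in, not the suite's helper):

import pytest


def open_connection(extra_params):
    # Hypothetical stand-in: merge per-test overrides such as
    # {"use_sea": True} into the default connection settings.
    defaults = {"use_sea": False}
    return {**defaults, **extra_params}


@pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}])
def test_connection_mode(extra_params):
    conn = open_connection(extra_params)
    # First case exercises the default path, second the SEA path.
    assert conn["use_sea"] == bool(extra_params)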
