Skip to content

Commit f374f5f

Browse files
Initialize SEA cloud-fetch link retry functionality
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent d038d84 commit f374f5f

File tree

5 files changed

+75
-12
lines changed

5 files changed

+75
-12
lines changed

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import os
55
import sys
66
import logging
7+
import time
78
from databricks.sql.client import Connection
89

9-
logging.basicConfig(level=logging.INFO)
10+
logging.basicConfig(level=logging.DEBUG)
1011
logger = logging.getLogger(__name__)
1112

1213

@@ -51,20 +52,19 @@ def test_sea_sync_query_with_cloud_fetch():
5152
)
5253

5354
# Execute a query that generates large rows to force multiple chunks
54-
requested_row_count = 10000
55+
requested_row_count = 100000000
5556
cursor = connection.cursor()
5657
query = f"""
57-
SELECT
58-
id,
59-
concat('value_', repeat('a', 10000)) as test_value
60-
FROM range(1, {requested_row_count} + 1) AS t(id)
58+
SELECT * FROM samples.tpch.lineitem LIMIT {requested_row_count}
6159
"""
6260

6361
logger.info(
6462
f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows"
6563
)
6664
cursor.execute(query)
6765
results = [cursor.fetchone()]
66+
logger.info("SLEEPING FOR 1000 SECONDS TO EXPIRE LINKS")
67+
time.sleep(1000)
6868
results.extend(cursor.fetchmany(10))
6969
results.extend(cursor.fetchall())
7070
actual_row_count = len(results)

src/databricks/sql/backend/sea/queue.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,24 @@ def _worker_loop(self):
204204
if not links_downloaded:
205205
break
206206

207+
def _restart_from_expired_link(self, link: TSparkArrowResultLink):
208+
self.stop()
209+
210+
with self._link_data_update:
211+
self.download_manager.cancel_tasks_from_offset(link.startRowOffset)
212+
213+
chunks_to_restart = []
214+
for chunk_index, l in self.chunk_index_to_link.items():
215+
if l.row_offset < link.startRowOffset:
216+
continue
217+
chunks_to_restart.append(chunk_index)
218+
for chunk_index in chunks_to_restart:
219+
self.chunk_index_to_link.pop(chunk_index)
220+
221+
self.start()
222+
207223
def start(self):
224+
self._shutdown_event.clear()
208225
self._worker_thread = threading.Thread(target=self._worker_loop)
209226
self._worker_thread.start()
210227

@@ -269,6 +286,7 @@ def __init__(
269286
max_download_threads=max_download_threads,
270287
lz4_compressed=lz4_compressed,
271288
ssl_options=ssl_options,
289+
expiry_callback=self._expiry_callback,
272290
)
273291

274292
self.link_fetcher = LinkFetcher(
@@ -283,6 +301,12 @@ def __init__(
283301
# Initialize table and position
284302
self.table = self._create_next_table()
285303

304+
def _expiry_callback(self, link: TSparkArrowResultLink):
305+
logger.info(
306+
f"SeaCloudFetchQueue: Link expired, restarting from offset {link.startRowOffset}"
307+
)
308+
self.link_fetcher._restart_from_expired_link(link)
309+
286310
def _create_next_table(self) -> Union["pyarrow.Table", None]:
287311
"""Create next table by retrieving the logical next downloaded file."""
288312
if not self.download_manager:

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4-
from typing import List, Union
4+
from typing import Callable, List, Optional, Union
55

66
from databricks.sql.cloudfetch.downloader import (
77
ResultSetDownloadHandler,
@@ -22,6 +22,7 @@ def __init__(
2222
max_download_threads: int,
2323
lz4_compressed: bool,
2424
ssl_options: SSLOptions,
25+
expiry_callback: Callable[[TSparkArrowResultLink], None],
2526
):
2627
self._pending_links: List[TSparkArrowResultLink] = []
2728
for link in links:
@@ -40,6 +41,7 @@ def __init__(
4041

4142
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
4243
self._ssl_options = ssl_options
44+
self._expiry_callback = expiry_callback
4345

4446
def get_next_downloaded_file(
4547
self, next_row_offset: int
@@ -62,7 +64,6 @@ def get_next_downloaded_file(
6264

6365
# No more files to download from this batch of links
6466
if len(self._download_tasks) == 0:
65-
self._shutdown_manager()
6667
return None
6768

6869
task = self._download_tasks.pop(0)
@@ -81,6 +82,34 @@ def get_next_downloaded_file(
8182

8283
return file
8384

85+
def cancel_tasks_from_offset(self, start_row_offset: int):
86+
"""
87+
Cancel all download tasks starting from a specific row offset.
88+
This is used when links expire and we need to restart from a certain point.
89+
"""
90+
91+
def to_cancel(link: TSparkArrowResultLink) -> bool:
92+
return link.startRowOffset < start_row_offset
93+
94+
tasks_to_cancel = [task for task in self._download_tasks if to_cancel(task.link)]
95+
for task in tasks_to_cancel:
96+
task.cancel()
97+
logger.info(
98+
f"ResultFileDownloadManager: cancelled {len(tasks_to_cancel)} tasks from offset {start_row_offset}"
99+
)
100+
101+
# Remove cancelled tasks from the download queue
102+
tasks_to_keep = [task for task in self._download_tasks if not to_cancel(task.link)]
103+
self._download_tasks = tasks_to_keep
104+
105+
pending_links_to_keep = [
106+
link for link in self._pending_links if not to_cancel(link)
107+
]
108+
self._pending_links = pending_links_to_keep
109+
logger.info(
110+
f"ResultFileDownloadManager: removed {len(self._pending_links) - len(pending_links_to_keep)} links from pending links"
111+
)
112+
84113
def _schedule_downloads(self):
85114
"""
86115
While download queue has a capacity, peek pending links and submit them to thread pool.
@@ -97,8 +126,10 @@ def _schedule_downloads(self):
97126
settings=self._downloadable_result_settings,
98127
link=link,
99128
ssl_options=self._ssl_options,
129+
expiry_callback=self._expiry_callback,
100130
)
101131
task = self._thread_pool.submit(handler.run)
132+
task.link = link
102133
self._download_tasks.append(task)
103134

104135
def add_link(self, link: TSparkArrowResultLink):

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3+
from typing import Callable
34

45
import requests
56
from requests.adapters import HTTPAdapter, Retry
@@ -66,10 +67,12 @@ def __init__(
6667
settings: DownloadableResultSettings,
6768
link: TSparkArrowResultLink,
6869
ssl_options: SSLOptions,
70+
expiry_callback: Callable[[TSparkArrowResultLink], None],
6971
):
7072
self.settings = settings
7173
self.link = link
7274
self._ssl_options = ssl_options
75+
self._expiry_callback = expiry_callback
7376

7477
def run(self) -> DownloadedFile:
7578
"""
@@ -86,7 +89,7 @@ def run(self) -> DownloadedFile:
8689
)
8790

8891
# Check if link is already expired or is expiring
89-
ResultSetDownloadHandler._validate_link(
92+
self._validate_link(
9093
self.link, self.settings.link_expiry_buffer_secs
9194
)
9295

@@ -136,8 +139,7 @@ def run(self) -> DownloadedFile:
136139
if session:
137140
session.close()
138141

139-
@staticmethod
140-
def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
142+
def _validate_link(self, link: TSparkArrowResultLink, expiry_buffer_secs: int):
141143
"""
142144
Check if a link has expired or will expire.
143145
@@ -149,7 +151,7 @@ def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
149151
link.expiryTime <= current_time
150152
or link.expiryTime - current_time <= expiry_buffer_secs
151153
):
152-
raise Error("CloudFetch link has expired")
154+
self._expiry_callback(link)
153155

154156
@staticmethod
155157
def _decompress_data(compressed_data: bytes) -> bytes:

src/databricks/sql/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import lz4.frame
1616

17+
from databricks.sql.exc import Error
18+
1719
try:
1820
import pyarrow
1921
except ImportError:
@@ -374,12 +376,16 @@ def __init__(
374376
)
375377
)
376378

379+
def expiry_callback(link: TSparkArrowResultLink):
380+
raise Error("Cloudfetch link has expired")
381+
377382
# Initialize download manager
378383
self.download_manager = ResultFileDownloadManager(
379384
links=self.result_links,
380385
max_download_threads=self.max_download_threads,
381386
lz4_compressed=self.lz4_compressed,
382387
ssl_options=self._ssl_options,
388+
expiry_callback=expiry_callback,
383389
)
384390

385391
# Initialize table and position

0 commit comments

Comments (0)