Skip to content

Commit 92fe5fd

Browse files
introduce separate link fetcher
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 2202057 commit 92fe5fd

File tree

1 file changed: +117 −58 lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 117 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC
4-
from typing import List, Optional, Tuple, Union
4+
import threading
5+
from typing import Dict, List, Optional, Tuple, Union
56

67
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
78

@@ -136,6 +137,105 @@ def remaining_rows(self) -> List[List[str]]:
136137
return slice
137138

138139

140+
class LinkFetcher:
    """Incrementally fetch external result links for a SEA statement.

    A background worker thread walks the chunk chain, asking the backend for
    the links of the next unseen chunk and registering each link with the
    download manager. Consumers call ``get_chunk_link()`` and block until the
    link for their chunk index has been fetched, an error is recorded, or the
    fetcher is stopped.
    """

    def __init__(
        self,
        download_manager: ResultFileDownloadManager,
        backend: "SeaDatabricksClient",
        statement_id: str,
        initial_links: List["ExternalLink"],
        total_chunk_count: int,
    ):
        self.download_manager = download_manager
        self.backend = backend
        self._statement_id = statement_id

        # Set to request the worker loop (and any waiting consumers) to stop.
        self._shutdown_event = threading.Event()

        # Guards chunk_index_to_link and _error; consumers wait on it until
        # the link they need has been fetched.
        self._condition = threading.Condition()
        self._error: Optional[Exception] = None
        self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}
        for link in initial_links:
            self.chunk_index_to_link[link.chunk_index] = link
            self.download_manager.add_link(self._convert_to_thrift_link(link))
        self.total_chunk_count = total_chunk_count

        # Created by start(); kept Optional so stop() is safe without start().
        self._worker_thread: Optional[threading.Thread] = None

    def _get_next_chunk_index(self) -> Optional[int]:
        """Return the next chunk index to fetch, or None when all are known."""
        with self._condition:
            max_chunk_index = max(self.chunk_index_to_link.keys(), default=None)
            if max_chunk_index is None:
                # Nothing fetched yet — start from the first chunk.
                return 0
            max_link = self.chunk_index_to_link[max_chunk_index]
            return max_link.next_chunk_index

    def _trigger_next_batch_download(self) -> bool:
        """Fetch links for the next chunk.

        Returns False when there is nothing left to fetch or the backend call
        failed (the error is stored for consumers to re-raise); True otherwise.
        """
        next_chunk_index = self._get_next_chunk_index()
        if next_chunk_index is None:
            return False

        try:
            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
            with self._condition:
                self.chunk_index_to_link.update(
                    {link.chunk_index: link for link in links}
                )
                # Wake consumers waiting for any of the newly fetched links.
                self._condition.notify_all()
            for link in links:
                self.download_manager.add_link(self._convert_to_thrift_link(link))
        except Exception as e:
            logger.error(
                f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}"
            )
            with self._condition:
                # Record the failure so blocked consumers re-raise it.
                self._error = e
                self._condition.notify_all()
            return False

        return True

    def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
        """Block until the link for ``chunk_index`` is available and return it.

        Returns None when ``chunk_index`` is past the end of the result set or
        the fetcher is stopped before the link arrives; re-raises any error
        recorded by the worker thread.
        """
        if chunk_index >= self.total_chunk_count:
            return None

        with self._condition:
            if self._error:
                raise self._error

            while chunk_index not in self.chunk_index_to_link:
                if self._error:
                    raise self._error
                # Don't wait forever once stop() has been called — without
                # this check a consumer could block indefinitely after
                # shutdown, since the worker will never fetch more links.
                if self._shutdown_event.is_set():
                    return None
                self._condition.wait()

            return self.chunk_index_to_link.get(chunk_index, None)

    def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
        # Parse the ISO format expiration time
        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
        return TSparkArrowResultLink(
            fileLink=link.external_link,
            expiryTime=expiry_time,
            rowCount=link.row_count,
            bytesNum=link.byte_count,
            startRowOffset=link.row_offset,
            httpHeaders=link.http_headers or {},
        )

    def _worker_loop(self):
        """Keep fetching link batches until complete, failed, or stopped."""
        while not self._shutdown_event.is_set():
            links_downloaded = self._trigger_next_batch_download()
            if not links_downloaded:
                break

    def start(self):
        """Start the background link-fetching worker thread."""
        self._worker_thread = threading.Thread(target=self._worker_loop)
        self._worker_thread.start()

    def stop(self):
        """Signal shutdown, wake blocked consumers, and join the worker."""
        self._shutdown_event.set()
        with self._condition:
            # Wake any consumer blocked in get_chunk_link() so it observes
            # the shutdown instead of waiting on the condition forever.
            self._condition.notify_all()
        # Guard: stop() may be called even if start() never was.
        if self._worker_thread is not None:
            self._worker_thread.join()
237+
238+
139239
class SeaCloudFetchQueue(CloudFetchQueue):
140240
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
141241

@@ -182,84 +282,43 @@ def __init__(
182282
)
183283
)
184284

185-
self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}
186-
187-
initial_link = self._chunk_index_to_link.get(0, None)
188-
if not initial_link:
285+
if total_chunk_count < 1:
189286
return
190287

288+
self.current_chunk_index = 0
289+
191290
self.download_manager = ResultFileDownloadManager(
192291
links=[],
193292
max_download_threads=max_download_threads,
194293
lz4_compressed=lz4_compressed,
195294
ssl_options=ssl_options,
196295
)
197296

198-
# Track the current chunk we're processing
199-
self._current_chunk_link: Optional["ExternalLink"] = initial_link
297+
self.link_fetcher = LinkFetcher(
298+
download_manager=self.download_manager,
299+
backend=self._sea_client,
300+
statement_id=self._statement_id,
301+
initial_links=initial_links,
302+
total_chunk_count=total_chunk_count,
303+
)
304+
self.link_fetcher.start()
200305

201306
# Initialize table and position
202307
self.table = self._create_next_table()
203308

204-
def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
205-
"""Convert SEA external links to Thrift format for compatibility with existing download manager."""
206-
# Parse the ISO format expiration time
207-
expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
208-
return TSparkArrowResultLink(
209-
fileLink=link.external_link,
210-
expiryTime=expiry_time,
211-
rowCount=link.row_count,
212-
bytesNum=link.byte_count,
213-
startRowOffset=link.row_offset,
214-
httpHeaders=link.http_headers or {},
215-
)
216-
217-
def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
218-
if chunk_index not in self._chunk_index_to_link:
219-
links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
220-
self._chunk_index_to_link.update({link.chunk_index: link for link in links})
221-
return self._chunk_index_to_link.get(chunk_index, None)
222-
223-
def _progress_chunk_link(self):
224-
"""Progress to the next chunk link."""
225-
if not self._current_chunk_link:
226-
return None
227-
228-
next_chunk_index = self._current_chunk_link.next_chunk_index
229-
230-
if next_chunk_index is None:
231-
self._current_chunk_link = None
232-
return None
233-
234-
self._current_chunk_link = self._get_chunk_link(next_chunk_index)
235-
if not self._current_chunk_link:
236-
logger.error(
237-
"SeaCloudFetchQueue: unable to retrieve link for chunk {}".format(
238-
next_chunk_index
239-
)
240-
)
241-
return None
242-
243-
logger.debug(
244-
f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
245-
)
246-
247309
def _create_next_table(self) -> Union["pyarrow.Table", None]:
    """Create next table by retrieving the logical next downloaded file.

    Blocks (via the link fetcher) until the link for the current chunk is
    available, builds the Arrow table at that chunk's row offset, and
    advances the current chunk index.

    Returns None when there is no download manager or no further chunk.
    """
    if not self.download_manager:
        logger.debug("SeaCloudFetchQueue: No download manager, returning")
        return None

    # May block until the background LinkFetcher has fetched this chunk's
    # link; returns None past the end of the result set or on shutdown.
    chunk_link = self.link_fetcher.get_chunk_link(self.current_chunk_index)
    if not chunk_link:
        return None

    row_offset = chunk_link.row_offset
    arrow_table = self._create_table_at_offset(row_offset)

    # Only advance once the table for this chunk has been produced.
    self.current_chunk_index += 1

    return arrow_table

0 commit comments

Comments (0)