Skip to content

Commit 9ce0803

Browse files
decouple link fetching
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 811205e commit 9ce0803

File tree

1 file changed

+91
-61
lines changed

1 file changed

+91
-61
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 91 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC
4-
from typing import List, Optional, Tuple, Union
4+
import threading
5+
from typing import Dict, List, Optional, Tuple, Union
56

67
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
78

@@ -111,6 +112,86 @@ def remaining_rows(self) -> List[List[str]]:
111112
return slice
112113

113114

115+
class LinkFetcher:
    """Fetch result-chunk links on a background thread.

    Starting from an optional initial link, the worker repeatedly:

    1. hands a Thrift-format copy of the current link to the download manager,
    2. records the link in ``chunk_index_to_link`` under its ``chunk_index``,
    3. asks the backend for the link named by ``next_chunk_index``.

    The worker ends when there is no next chunk, when the backend call
    raises (the error is logged), or when ``stop()`` is called.

    Consumers look links up with ``get_chunk_link``; while the worker is
    running that call waits for the requested chunk instead of reporting a
    link that simply has not been fetched yet as missing.
    """

    def __init__(
        self,
        download_manager: "ResultFileDownloadManager",
        backend: "SeaDatabricksClient",
        statement_id: str,
        current_chunk_link: Optional["ExternalLink"] = None,
    ):
        """
        Args:
            download_manager: Receives each converted link via ``add_link``.
            backend: Queried with ``get_chunk_link(statement_id, chunk_index)``
                for every link after the initial one.
            statement_id: Statement whose chunk links are fetched.
            current_chunk_link: Link to start from; ``None`` means there is
                nothing to fetch and the worker exits immediately.
        """
        self.download_manager = download_manager
        self.backend = backend
        self._statement_id = statement_id
        self._current_chunk_link = current_chunk_link

        self._shutdown_event = threading.Event()
        # Created by start(); kept as None until then so stop() is always safe.
        self._worker_thread: Optional[threading.Thread] = None
        # True from start() until the worker loop exits (guarded by _map_lock).
        self._worker_active = False

        self._map_lock = threading.Lock()
        # Shares _map_lock, so map updates and waiter wake-ups are atomic.
        self._link_available = threading.Condition(self._map_lock)
        self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}

    def _set_current_chunk_link(self, link: "ExternalLink"):
        """Record ``link`` under its chunk index and wake any waiting readers."""
        with self._link_available:
            self.chunk_index_to_link[link.chunk_index] = link
            self._link_available.notify_all()

    def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
        """Return the link recorded for ``chunk_index``.

        While the worker is running and has not been asked to stop, this
        blocks until the requested link has been recorded or the worker has
        finished. If the worker is not running (never started, finished,
        failed, or stopped) the lookup returns immediately.

        Returns:
            The recorded link, or ``None`` if no link for ``chunk_index``
            was (or will be) recorded.
        """
        with self._link_available:
            while (
                chunk_index not in self.chunk_index_to_link
                and self._worker_active
                and not self._shutdown_event.is_set()
            ):
                self._link_available.wait()
            return self.chunk_index_to_link.get(chunk_index)

    def _convert_to_thrift_link(self, link: "ExternalLink") -> "TSparkArrowResultLink":
        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
        # Parse the ISO format expiration time
        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
        return TSparkArrowResultLink(
            fileLink=link.external_link,
            expiryTime=expiry_time,
            rowCount=link.row_count,
            bytesNum=link.byte_count,
            startRowOffset=link.row_offset,
            httpHeaders=link.http_headers or {},
        )

    def _progress_chunk_link(self):
        """Advance ``_current_chunk_link`` to the next chunk's link.

        Sets ``_current_chunk_link`` to ``None`` when there is no next chunk
        or when the backend call fails; a failure is logged, not raised, so
        the worker simply stops fetching.
        """
        if not self._current_chunk_link:
            return None

        next_chunk_index = self._current_chunk_link.next_chunk_index

        if next_chunk_index is None:
            self._current_chunk_link = None
            return None

        try:
            self._current_chunk_link = self.backend.get_chunk_link(
                self._statement_id, next_chunk_index
            )
        except Exception as e:
            logger.error(
                "LinkFetcher: Error fetching link for chunk {}: {}".format(
                    next_chunk_index, e
                )
            )
            self._current_chunk_link = None
        return None

    def _worker_loop(self):
        """Publish links one by one until exhausted, failed, or stopped."""
        try:
            while not self._shutdown_event.is_set():
                link = self._current_chunk_link
                if link is None:
                    break

                # Register with the download manager BEFORE publishing the
                # link, so a reader that sees the link can rely on the
                # download manager already knowing about it.
                self.download_manager.add_link(self._convert_to_thrift_link(link))
                self._set_current_chunk_link(link)

                self._progress_chunk_link()
        finally:
            # Always release readers blocked in get_chunk_link, including
            # when the loop is left through an unexpected exception.
            with self._link_available:
                self._worker_active = False
                self._link_available.notify_all()

    def start(self):
        """Start the background worker thread."""
        with self._link_available:
            self._worker_active = True
        self._worker_thread = threading.Thread(target=self._worker_loop)
        self._worker_thread.start()

    def stop(self):
        """Ask the worker to stop and wait for it to exit.

        Safe to call when ``start()`` was never called.
        """
        self._shutdown_event.set()
        # Wake readers so they observe the shutdown without waiting for the
        # worker to finish an in-flight backend call.
        with self._link_available:
            self._link_available.notify_all()
        worker = self._worker_thread
        if worker is not None:
            worker.join()
193+
194+
114195
class SeaCloudFetchQueue(CloudFetchQueue):
115196
"""Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
116197

@@ -160,6 +241,7 @@ def __init__(
160241
initial_link = next((l for l in initial_links if l.chunk_index == 0), None)
161242
if not initial_link:
162243
return
244+
self.current_chunk_index = initial_link.chunk_index
163245

164246
self.download_manager = ResultFileDownloadManager(
165247
links=[],
@@ -168,75 +250,23 @@ def __init__(
168250
ssl_options=ssl_options,
169251
)
170252

171-
# Track the current chunk we're processing
172-
self._current_chunk_link: Optional["ExternalLink"] = initial_link
173-
self._download_current_link()
253+
self.link_fetcher = LinkFetcher(
254+
self.download_manager, self._sea_client, statement_id, initial_link
255+
)
256+
self.link_fetcher.start()
174257

175258
# Initialize table and position
176259
self.table = self._create_next_table()
177260

178-
def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
179-
"""Convert SEA external links to Thrift format for compatibility with existing download manager."""
180-
# Parse the ISO format expiration time
181-
expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
182-
return TSparkArrowResultLink(
183-
fileLink=link.external_link,
184-
expiryTime=expiry_time,
185-
rowCount=link.row_count,
186-
bytesNum=link.byte_count,
187-
startRowOffset=link.row_offset,
188-
httpHeaders=link.http_headers or {},
189-
)
190-
191-
def _download_current_link(self):
192-
"""Download the current chunk link."""
193-
if not self._current_chunk_link:
194-
return None
195-
196-
if not self.download_manager:
197-
logger.debug("SeaCloudFetchQueue: No download manager, returning")
198-
return None
199-
200-
thrift_link = self._convert_to_thrift_link(self._current_chunk_link)
201-
self.download_manager.add_link(thrift_link)
202-
203-
def _progress_chunk_link(self):
204-
"""Progress to the next chunk link."""
205-
if not self._current_chunk_link:
206-
return None
207-
208-
next_chunk_index = self._current_chunk_link.next_chunk_index
209-
210-
if next_chunk_index is None:
211-
self._current_chunk_link = None
212-
return None
213-
214-
try:
215-
self._current_chunk_link = self._sea_client.get_chunk_link(
216-
self._statement_id, next_chunk_index
217-
)
218-
except Exception as e:
219-
logger.error(
220-
"SeaCloudFetchQueue: Error fetching link for chunk {}: {}".format(
221-
next_chunk_index, e
222-
)
223-
)
224-
return None
225-
226-
logger.debug(
227-
f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
228-
)
229-
self._download_current_link()
230-
231261
def _create_next_table(self) -> Union["pyarrow.Table", None]:
232262
"""Create next table by retrieving the logical next downloaded file."""
233-
if not self._current_chunk_link:
234-
logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
263+
current_chunk_link = self.link_fetcher.get_chunk_link(self.current_chunk_index)
264+
if not current_chunk_link:
235265
return None
236266

237-
row_offset = self._current_chunk_link.row_offset
267+
row_offset = current_chunk_link.row_offset
238268
arrow_table = self._create_table_at_offset(row_offset)
239269

240-
self._progress_chunk_link()
270+
self.current_chunk_index = current_chunk_link.next_chunk_index
241271

242272
return arrow_table

0 commit comments

Comments
 (0)