|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from abc import ABC |
4 | | -from typing import List, Optional, Tuple, Union |
| 4 | +import threading |
| 5 | +from typing import Dict, List, Optional, Tuple, Union |
5 | 6 |
|
6 | 7 | from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager |
7 | 8 |
|
@@ -136,6 +137,105 @@ def remaining_rows(self) -> List[List[str]]: |
136 | 137 | return slice |
137 | 138 |
|
138 | 139 |
|
class LinkFetcher:
    """Fetch external result links on a background thread.

    The fetcher keeps a map from chunk index to link, seeded with
    ``initial_links``.  A worker thread (see :meth:`start`) repeatedly asks
    ``backend.get_chunk_links`` for the batch following the highest known
    chunk, registers every link with ``download_manager`` and then publishes
    it to consumers blocked in :meth:`get_chunk_link`.

    Args:
        download_manager: Receives each link (converted by
            ``_convert_to_thrift_link``) through ``add_link``.
        backend: Object exposing ``get_chunk_links(statement_id, chunk_index)``.
        statement_id: Statement whose links are fetched.
        initial_links: Links already known when the fetcher is created.
        total_chunk_count: Number of chunks in the result; requests for an
            index at or beyond this value return ``None`` immediately.
    """

    def __init__(
        self,
        download_manager: ResultFileDownloadManager,
        backend: "SeaDatabricksClient",
        statement_id: str,
        initial_links: List["ExternalLink"],
        total_chunk_count: int,
    ):
        self.download_manager = download_manager
        self.backend = backend
        self._statement_id = statement_id

        # Set when stop() is requested or when the worker loop has exited.
        self._shutdown_event = threading.Event()

        # Guards chunk_index_to_link and _error; waiters block on it.
        self._condition = threading.Condition()
        self._error: Optional[Exception] = None
        # Created by start(); stays None if the fetcher is never started.
        self._worker_thread: Optional[threading.Thread] = None

        self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}
        for link in initial_links:
            self.chunk_index_to_link[link.chunk_index] = link
            self.download_manager.add_link(self._convert_to_thrift_link(link))
        self.total_chunk_count = total_chunk_count

    def _get_next_chunk_index(self) -> Optional[int]:
        """Return the chunk index to request next.

        Returns 0 when no link is known yet, otherwise the
        ``next_chunk_index`` of the link with the highest chunk index (which
        is ``None`` once the chain of links has ended).
        """
        with self._condition:
            max_chunk_index = max(self.chunk_index_to_link, default=None)
            if max_chunk_index is None:
                return 0
            return self.chunk_index_to_link[max_chunk_index].next_chunk_index

    def _trigger_next_batch_download(self) -> bool:
        """Fetch and publish the next batch of links.

        Returns:
            ``True`` when a batch was fetched and published; ``False`` when
            there is nothing left to fetch or the fetch failed.  On failure
            the exception is stored so that :meth:`get_chunk_link` re-raises
            it in the consumer thread.
        """
        next_chunk_index = self._get_next_chunk_index()
        if next_chunk_index is None:
            return False

        try:
            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
            if not links:
                # An empty batch would leave the next index unchanged and make
                # the worker request the same chunk forever.
                raise RuntimeError(
                    f"no links returned for chunk {next_chunk_index}"
                )

            # Register with the download manager BEFORE publishing, so that a
            # consumer woken below never asks the manager for an unknown link.
            for link in links:
                self.download_manager.add_link(self._convert_to_thrift_link(link))

            with self._condition:
                self.chunk_index_to_link.update(
                    {link.chunk_index: link for link in links}
                )
                self._condition.notify_all()
        except Exception as e:
            # Thread boundary: record the failure for the consumer instead of
            # letting it die silently with the worker thread.
            logger.error(
                f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}"
            )
            with self._condition:
                self._error = e
                self._condition.notify_all()
            return False

        return True

    def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
        """Return the link for ``chunk_index``, waiting for it if necessary.

        Returns:
            The link, or ``None`` when ``chunk_index`` is at or beyond
            ``total_chunk_count``, or when fetching has stopped (worker
            finished or :meth:`stop` was called) without producing the link.

        Raises:
            Exception: The error recorded by the worker thread, if any.
        """
        if chunk_index >= self.total_chunk_count:
            return None

        with self._condition:
            if self._error:
                raise self._error

            while chunk_index not in self.chunk_index_to_link:
                if self._error:
                    raise self._error
                if self._shutdown_event.is_set():
                    # No more links will ever arrive; do not block forever.
                    return None
                self._condition.wait()

            return self.chunk_index_to_link[chunk_index]

    def _convert_to_thrift_link(self, link: "ExternalLink") -> "TSparkArrowResultLink":
        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
        # Parse the ISO format expiration time
        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
        return TSparkArrowResultLink(
            fileLink=link.external_link,
            expiryTime=expiry_time,
            rowCount=link.row_count,
            bytesNum=link.byte_count,
            startRowOffset=link.row_offset,
            httpHeaders=link.http_headers or {},
        )

    def _worker_loop(self):
        """Fetch batches until shutdown is requested or no batch is produced."""
        try:
            while not self._shutdown_event.is_set():
                if not self._trigger_next_batch_download():
                    break
        finally:
            # Mark fetching as finished and wake every waiter so none of them
            # stays blocked on a link that will never be published.
            self._shutdown_event.set()
            with self._condition:
                self._condition.notify_all()

    def start(self):
        """Start the background worker thread."""
        self._worker_thread = threading.Thread(target=self._worker_loop)
        self._worker_thread.start()

    def stop(self):
        """Request shutdown, release blocked waiters and join the worker.

        Safe to call even if :meth:`start` was never called.
        """
        self._shutdown_event.set()
        with self._condition:
            self._condition.notify_all()
        worker = self._worker_thread
        if worker is not None:
            worker.join()
| 238 | + |
139 | 239 | class SeaCloudFetchQueue(CloudFetchQueue): |
140 | 240 | """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend.""" |
141 | 241 |
|
@@ -182,84 +282,43 @@ def __init__( |
182 | 282 | ) |
183 | 283 | ) |
184 | 284 |
|
185 | | - self._chunk_index_to_link = {link.chunk_index: link for link in initial_links} |
186 | | - |
187 | | - initial_link = self._chunk_index_to_link.get(0, None) |
188 | | - if not initial_link: |
| 285 | + if total_chunk_count < 1: |
189 | 286 | return |
190 | 287 |
|
| 288 | + self.current_chunk_index = 0 |
| 289 | + |
191 | 290 | self.download_manager = ResultFileDownloadManager( |
192 | 291 | links=[], |
193 | 292 | max_download_threads=max_download_threads, |
194 | 293 | lz4_compressed=lz4_compressed, |
195 | 294 | ssl_options=ssl_options, |
196 | 295 | ) |
197 | 296 |
|
198 | | - # Track the current chunk we're processing |
199 | | - self._current_chunk_link: Optional["ExternalLink"] = initial_link |
| 297 | + self.link_fetcher = LinkFetcher( |
| 298 | + download_manager=self.download_manager, |
| 299 | + backend=self._sea_client, |
| 300 | + statement_id=self._statement_id, |
| 301 | + initial_links=initial_links, |
| 302 | + total_chunk_count=total_chunk_count, |
| 303 | + ) |
| 304 | + self.link_fetcher.start() |
200 | 305 |
|
201 | 306 | # Initialize table and position |
202 | 307 | self.table = self._create_next_table() |
203 | 308 |
|
204 | | - def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink: |
205 | | - """Convert SEA external links to Thrift format for compatibility with existing download manager.""" |
206 | | - # Parse the ISO format expiration time |
207 | | - expiry_time = int(dateutil.parser.parse(link.expiration).timestamp()) |
208 | | - return TSparkArrowResultLink( |
209 | | - fileLink=link.external_link, |
210 | | - expiryTime=expiry_time, |
211 | | - rowCount=link.row_count, |
212 | | - bytesNum=link.byte_count, |
213 | | - startRowOffset=link.row_offset, |
214 | | - httpHeaders=link.http_headers or {}, |
215 | | - ) |
216 | | - |
217 | | - def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]: |
218 | | - if chunk_index not in self._chunk_index_to_link: |
219 | | - links = self._sea_client.get_chunk_links(self._statement_id, chunk_index) |
220 | | - self._chunk_index_to_link.update({link.chunk_index: link for link in links}) |
221 | | - return self._chunk_index_to_link.get(chunk_index, None) |
222 | | - |
223 | | - def _progress_chunk_link(self): |
224 | | - """Progress to the next chunk link.""" |
225 | | - if not self._current_chunk_link: |
226 | | - return None |
227 | | - |
228 | | - next_chunk_index = self._current_chunk_link.next_chunk_index |
229 | | - |
230 | | - if next_chunk_index is None: |
231 | | - self._current_chunk_link = None |
232 | | - return None |
233 | | - |
234 | | - self._current_chunk_link = self._get_chunk_link(next_chunk_index) |
235 | | - if not self._current_chunk_link: |
236 | | - logger.error( |
237 | | - "SeaCloudFetchQueue: unable to retrieve link for chunk {}".format( |
238 | | - next_chunk_index |
239 | | - ) |
240 | | - ) |
241 | | - return None |
242 | | - |
243 | | - logger.debug( |
244 | | - f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}" |
245 | | - ) |
246 | | - |
247 | 309 | def _create_next_table(self) -> Union["pyarrow.Table", None]: |
248 | 310 | """Create next table by retrieving the logical next downloaded file.""" |
249 | | - if not self._current_chunk_link: |
250 | | - logger.debug("SeaCloudFetchQueue: No current chunk link, returning") |
251 | | - return None |
252 | | - |
253 | 311 | if not self.download_manager: |
254 | 312 | logger.debug("SeaCloudFetchQueue: No download manager, returning") |
255 | 313 | return None |
256 | 314 |
|
257 | | - thrift_link = self._convert_to_thrift_link(self._current_chunk_link) |
258 | | - self.download_manager.add_link(thrift_link) |
| 315 | + chunk_link = self.link_fetcher.get_chunk_link(self.current_chunk_index) |
| 316 | + if not chunk_link: |
| 317 | + return None |
259 | 318 |
|
260 | | - row_offset = self._current_chunk_link.row_offset |
| 319 | + row_offset = chunk_link.row_offset |
261 | 320 | arrow_table = self._create_table_at_offset(row_offset) |
262 | 321 |
|
263 | | - self._progress_chunk_link() |
| 322 | + self.current_chunk_index += 1 |
264 | 323 |
|
265 | 324 | return arrow_table |
0 commit comments