11from __future__ import annotations
22
33from abc import ABC
4- from typing import List , Optional , Tuple , Union
4+ import threading
5+ from typing import Dict , List , Optional , Tuple , Union
56
67from databricks .sql .cloudfetch .download_manager import ResultFileDownloadManager
78
@@ -111,6 +112,86 @@ def remaining_rows(self) -> List[List[str]]:
111112 return slice
112113
113114
115+ class LinkFetcher :
116+ def __init__ (
117+ self ,
118+ download_manager : ResultFileDownloadManager ,
119+ backend : "SeaDatabricksClient" ,
120+ statement_id : str ,
121+ current_chunk_link : Optional ["ExternalLink" ] = None ,
122+ ):
123+ self .download_manager = download_manager
124+ self .backend = backend
125+ self ._statement_id = statement_id
126+ self ._current_chunk_link = current_chunk_link
127+
128+ self ._shutdown_event = threading .Event ()
129+
130+ self ._map_lock = threading .Lock ()
131+ self .chunk_index_to_link : Dict [int , "ExternalLink" ] = {}
132+
133+ def _set_current_chunk_link (self , link : "ExternalLink" ):
134+ with self ._map_lock :
135+ self .chunk_index_to_link [link .chunk_index ] = link
136+
137+ def get_chunk_link (self , chunk_index : int ) -> Optional ["ExternalLink" ]:
138+ with self ._map_lock :
139+ return self .chunk_index_to_link .get (chunk_index , None )
140+
141+ def _convert_to_thrift_link (self , link : "ExternalLink" ) -> TSparkArrowResultLink :
142+ """Convert SEA external links to Thrift format for compatibility with existing download manager."""
143+ # Parse the ISO format expiration time
144+ expiry_time = int (dateutil .parser .parse (link .expiration ).timestamp ())
145+ return TSparkArrowResultLink (
146+ fileLink = link .external_link ,
147+ expiryTime = expiry_time ,
148+ rowCount = link .row_count ,
149+ bytesNum = link .byte_count ,
150+ startRowOffset = link .row_offset ,
151+ httpHeaders = link .http_headers or {},
152+ )
153+
154+ def _progress_chunk_link (self ):
155+ """Progress to the next chunk link."""
156+ if not self ._current_chunk_link :
157+ return None
158+
159+ next_chunk_index = self ._current_chunk_link .next_chunk_index
160+
161+ if next_chunk_index is None :
162+ self ._current_chunk_link = None
163+ return None
164+
165+ try :
166+ self ._current_chunk_link = self .backend .get_chunk_link (
167+ self ._statement_id , next_chunk_index
168+ )
169+ except Exception as e :
170+ logger .error (
171+ "LinkFetcher: Error fetching link for chunk {}: {}" .format (
172+ next_chunk_index , e
173+ )
174+ )
175+ self ._current_chunk_link = None
176+
177+ def _worker_loop (self ):
178+ while not (self ._shutdown_event .is_set () or self ._current_chunk_link is None ):
179+ self ._set_current_chunk_link (self ._current_chunk_link )
180+ self .download_manager .add_link (
181+ self ._convert_to_thrift_link (self ._current_chunk_link )
182+ )
183+
184+ self ._progress_chunk_link ()
185+
186+ def start (self ):
187+ self ._worker_thread = threading .Thread (target = self ._worker_loop )
188+ self ._worker_thread .start ()
189+
190+ def stop (self ):
191+ self ._shutdown_event .set ()
192+ self ._worker_thread .join ()
193+
194+
114195class SeaCloudFetchQueue (CloudFetchQueue ):
115196 """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
116197
@@ -160,6 +241,7 @@ def __init__(
160241 initial_link = next ((l for l in initial_links if l .chunk_index == 0 ), None )
161242 if not initial_link :
162243 return
244+ self .current_chunk_index = initial_link .chunk_index
163245
164246 self .download_manager = ResultFileDownloadManager (
165247 links = [],
@@ -168,75 +250,23 @@ def __init__(
168250 ssl_options = ssl_options ,
169251 )
170252
171- # Track the current chunk we're processing
172- self ._current_chunk_link : Optional ["ExternalLink" ] = initial_link
173- self ._download_current_link ()
253+ self .link_fetcher = LinkFetcher (
254+ self .download_manager , self ._sea_client , statement_id , initial_link
255+ )
256+ self .link_fetcher .start ()
174257
175258 # Initialize table and position
176259 self .table = self ._create_next_table ()
177260
178- def _convert_to_thrift_link (self , link : "ExternalLink" ) -> TSparkArrowResultLink :
179- """Convert SEA external links to Thrift format for compatibility with existing download manager."""
180- # Parse the ISO format expiration time
181- expiry_time = int (dateutil .parser .parse (link .expiration ).timestamp ())
182- return TSparkArrowResultLink (
183- fileLink = link .external_link ,
184- expiryTime = expiry_time ,
185- rowCount = link .row_count ,
186- bytesNum = link .byte_count ,
187- startRowOffset = link .row_offset ,
188- httpHeaders = link .http_headers or {},
189- )
190-
191- def _download_current_link (self ):
192- """Download the current chunk link."""
193- if not self ._current_chunk_link :
194- return None
195-
196- if not self .download_manager :
197- logger .debug ("SeaCloudFetchQueue: No download manager, returning" )
198- return None
199-
200- thrift_link = self ._convert_to_thrift_link (self ._current_chunk_link )
201- self .download_manager .add_link (thrift_link )
202-
203- def _progress_chunk_link (self ):
204- """Progress to the next chunk link."""
205- if not self ._current_chunk_link :
206- return None
207-
208- next_chunk_index = self ._current_chunk_link .next_chunk_index
209-
210- if next_chunk_index is None :
211- self ._current_chunk_link = None
212- return None
213-
214- try :
215- self ._current_chunk_link = self ._sea_client .get_chunk_link (
216- self ._statement_id , next_chunk_index
217- )
218- except Exception as e :
219- logger .error (
220- "SeaCloudFetchQueue: Error fetching link for chunk {}: {}" .format (
221- next_chunk_index , e
222- )
223- )
224- return None
225-
226- logger .debug (
227- f"SeaCloudFetchQueue: Progressed to link for chunk { next_chunk_index } : { self ._current_chunk_link } "
228- )
229- self ._download_current_link ()
230-
231261 def _create_next_table (self ) -> Union ["pyarrow.Table" , None ]:
232262 """Create next table by retrieving the logical next downloaded file."""
233- if not self ._current_chunk_link :
234- logger . debug ( "SeaCloudFetchQueue: No current chunk link, returning" )
263+ current_chunk_link = self .link_fetcher . get_chunk_link ( self . current_chunk_index )
264+ if not current_chunk_link :
235265 return None
236266
237- row_offset = self . _current_chunk_link .row_offset
267+ row_offset = current_chunk_link .row_offset
238268 arrow_table = self ._create_table_at_offset (row_offset )
239269
240- self ._progress_chunk_link ()
270+ self .current_chunk_index = current_chunk_link . next_chunk_index
241271
242272 return arrow_table
0 commit comments