11import logging
22
33from concurrent .futures import ThreadPoolExecutor , Future
4- from typing import Callable , List , Optional , Union
4+ from typing import Callable , List , Optional , Union , Generic , TypeVar
55
66from databricks .sql .cloudfetch .downloader import (
77 ResultSetDownloadHandler ,
1414
1515logger = logging .getLogger (__name__ )
1616
17+ T = TypeVar ('T' )
18+
19+
20+ class TaskWithMetadata (Generic [T ]):
21+ """
22+ Wrapper around Future that stores additional metadata (the link).
23+ Provides type-safe access to both the Future result and the associated link.
24+ """
25+
26+ def __init__ (self , future : Future [T ], link : TSparkArrowResultLink ):
27+ self .future = future
28+ self .link = link
29+
30+ def result (self , timeout : Optional [float ] = None ) -> T :
31+ """Get the result of the Future, blocking if necessary."""
32+ return self .future .result (timeout )
33+
34+ def cancel (self ) -> bool :
35+ """Cancel the Future if possible."""
36+ return self .future .cancel ()
37+
38+
1739
1840class ResultFileDownloadManager :
1941 def __init__ (
@@ -22,7 +44,7 @@ def __init__(
2244 max_download_threads : int ,
2345 lz4_compressed : bool ,
2446 ssl_options : SSLOptions ,
25- expiry_callback : Callable [[TSparkArrowResultLink ], None ],
47+ expiry_callback : Optional [ Callable [[TSparkArrowResultLink ], None ]] = None ,
2648 ):
2749 self ._pending_links : List [TSparkArrowResultLink ] = []
2850 for link in links :
@@ -35,7 +57,7 @@ def __init__(
3557 )
3658 self ._pending_links .append (link )
3759
38- self ._download_tasks : List [Future [DownloadedFile ]] = []
60+ self ._download_tasks : List [TaskWithMetadata [DownloadedFile ]] = []
3961 self ._max_download_threads : int = max_download_threads
4062 self ._thread_pool = ThreadPoolExecutor (max_workers = self ._max_download_threads )
4163
@@ -53,7 +75,7 @@ def get_next_downloaded_file(
5375 in relation to the full result. File downloads are scheduled if not already, and once the correct
5476 download handler is located, the function waits for the download status and returns the resulting file.
5577 If there are no more downloads, a download was not successful, or the correct file could not be located,
56- this function shuts down the thread pool and returns None.
78+ this function returns None.
5779
5880 Args:
5981 next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -86,6 +108,9 @@ def cancel_tasks_from_offset(self, start_row_offset: int):
86108 """
87109 Cancel all download tasks starting from a specific row offset.
88110 This is used when links expire and we need to restart from a certain point.
111+
112+ Args:
113+ start_row_offset (int): Row offset from which to cancel tasks
89114 """
90115
91116 def to_cancel (link : TSparkArrowResultLink ) -> bool :
@@ -132,21 +157,19 @@ def _schedule_downloads(self):
132157 ssl_options = self ._ssl_options ,
133158 expiry_callback = self ._expiry_callback ,
134159 )
135- task = self ._thread_pool .submit (handler .run )
136- task . link = link
160+ future = self ._thread_pool .submit (handler .run )
161+ task = TaskWithMetadata ( future , link )
137162 self ._download_tasks .append (task )
138163
139164 def add_link (self , link : TSparkArrowResultLink ):
140165 """
141166 Add more links to the download manager.
142167
143168 Args:
144- link: Link to add
169+ link (TSparkArrowResultLink): The link to add to the download manager.
145170 """
146-
147171 if link .rowCount <= 0 :
148172 return
149-
150173 logger .debug (
151174 "ResultFileDownloadManager: adding file link, start offset {}, row count: {}" .format (
152175 link .startRowOffset , link .rowCount
0 commit comments