Skip to content

Commit a18be78

Browse files
create strongly typed Future for download tasks
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 0829e67 commit a18be78

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4-
from typing import Callable, List, Optional, Union
4+
from typing import Callable, List, Optional, Union, Generic, TypeVar
55

66
from databricks.sql.cloudfetch.downloader import (
77
ResultSetDownloadHandler,
@@ -14,6 +14,28 @@
1414

1515
logger = logging.getLogger(__name__)
1616

17+
T = TypeVar('T')
18+
19+
20+
class TaskWithMetadata(Generic[T]):
21+
"""
22+
Wrapper around Future that stores additional metadata (the link).
23+
Provides type-safe access to both the Future result and the associated link.
24+
"""
25+
26+
def __init__(self, future: Future[T], link: TSparkArrowResultLink):
27+
self.future = future
28+
self.link = link
29+
30+
def result(self, timeout: Optional[float] = None) -> T:
31+
"""Get the result of the Future, blocking if necessary."""
32+
return self.future.result(timeout)
33+
34+
def cancel(self) -> bool:
35+
"""Cancel the Future if possible."""
36+
return self.future.cancel()
37+
38+
1739

1840
class ResultFileDownloadManager:
1941
def __init__(
@@ -22,7 +44,7 @@ def __init__(
2244
max_download_threads: int,
2345
lz4_compressed: bool,
2446
ssl_options: SSLOptions,
25-
expiry_callback: Callable[[TSparkArrowResultLink], None],
47+
expiry_callback: Optional[Callable[[TSparkArrowResultLink], None]] = None,
2648
):
2749
self._pending_links: List[TSparkArrowResultLink] = []
2850
for link in links:
@@ -35,7 +57,7 @@ def __init__(
3557
)
3658
self._pending_links.append(link)
3759

38-
self._download_tasks: List[Future[DownloadedFile]] = []
60+
self._download_tasks: List[TaskWithMetadata[DownloadedFile]] = []
3961
self._max_download_threads: int = max_download_threads
4062
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
4163

@@ -53,7 +75,7 @@ def get_next_downloaded_file(
5375
in relation to the full result. File downloads are scheduled if not already, and once the correct
5476
download handler is located, the function waits for the download status and returns the resulting file.
5577
If there are no more downloads, a download was not successful, or the correct file could not be located,
56-
this function shuts down the thread pool and returns None.
78+
this function returns None.
5779
5880
Args:
5981
next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -86,6 +108,9 @@ def cancel_tasks_from_offset(self, start_row_offset: int):
86108
"""
87109
Cancel all download tasks starting from a specific row offset.
88110
This is used when links expire and we need to restart from a certain point.
111+
112+
Args:
113+
start_row_offset (int): Row offset from which to cancel tasks
89114
"""
90115

91116
def to_cancel(link: TSparkArrowResultLink) -> bool:
@@ -132,21 +157,19 @@ def _schedule_downloads(self):
132157
ssl_options=self._ssl_options,
133158
expiry_callback=self._expiry_callback,
134159
)
135-
task = self._thread_pool.submit(handler.run)
136-
task.link = link
160+
future = self._thread_pool.submit(handler.run)
161+
task = TaskWithMetadata(future, link)
137162
self._download_tasks.append(task)
138163

139164
def add_link(self, link: TSparkArrowResultLink):
140165
"""
141166
Add more links to the download manager.
142167
143168
Args:
144-
link: Link to add
169+
link (TSparkArrowResultLink): The link to add to the download manager.
145170
"""
146-
147171
if link.rowCount <= 0:
148172
return
149-
150173
logger.debug(
151174
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
152175
link.startRowOffset, link.rowCount

0 commit comments

Comments
 (0)