Skip to content

Commit 0829e67

Browse files
correct unit tests for download management
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent f374f5f commit 0829e67

File tree

4 files changed: 32 additions (+32) and 26 deletions (-26).

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,15 +91,19 @@ def cancel_tasks_from_offset(self, start_row_offset: int):
9191
def to_cancel(link: TSparkArrowResultLink) -> bool:
9292
return link.startRowOffset < start_row_offset
9393

94-
tasks_to_cancel = [task for task in self._download_tasks if to_cancel(task.link)]
94+
tasks_to_cancel = [
95+
task for task in self._download_tasks if to_cancel(task.link)
96+
]
9597
for task in tasks_to_cancel:
9698
task.cancel()
9799
logger.info(
98100
f"ResultFileDownloadManager: cancelled {len(tasks_to_cancel)} tasks from offset {start_row_offset}"
99101
)
100102

101103
# Remove cancelled tasks from the download queue
102-
tasks_to_keep = [task for task in self._download_tasks if not to_cancel(task.link)]
104+
tasks_to_keep = [
105+
task for task in self._download_tasks if not to_cancel(task.link)
106+
]
103107
self._download_tasks = tasks_to_keep
104108

105109
pending_links_to_keep = [

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -89,9 +89,7 @@ def run(self) -> DownloadedFile:
8989
)
9090

9191
# Check if link is already expired or is expiring
92-
self._validate_link(
93-
self.link, self.settings.link_expiry_buffer_secs
94-
)
92+
self._validate_link(self.link, self.settings.link_expiry_buffer_secs)
9593

9694
session = requests.Session()
9795
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))

tests/unit/test_download_manager.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,15 @@ class DownloadManagerTests(unittest.TestCase):
1414
def create_download_manager(
1515
self, links, max_download_threads=10, lz4_compressed=True
1616
):
17+
def expiry_callback(link: TSparkArrowResultLink):
18+
return None
19+
1720
return download_manager.ResultFileDownloadManager(
1821
links,
1922
max_download_threads,
2023
lz4_compressed,
2124
ssl_options=SSLOptions(),
25+
expiry_callback=expiry_callback,
2226
)
2327

2428
def create_result_link(

tests/unit/test_downloader.py

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import databricks.sql.cloudfetch.downloader as downloader
77
from databricks.sql.exc import Error
8+
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
89
from databricks.sql.types import SSLOptions
910

1011

@@ -20,15 +21,26 @@ class DownloaderTests(unittest.TestCase):
2021
Unit tests for checking downloader logic.
2122
"""
2223

24+
def create_download_handler(
25+
self, settings: Mock, result_link: Mock
26+
) -> downloader.ResultSetDownloadHandler:
27+
def expiry_callback(link: TSparkArrowResultLink):
28+
raise Error("Cloudfetch link has expired")
29+
30+
return downloader.ResultSetDownloadHandler(
31+
settings,
32+
result_link,
33+
ssl_options=SSLOptions(),
34+
expiry_callback=expiry_callback,
35+
)
36+
2337
@patch("time.time", return_value=1000)
2438
def test_run_link_expired(self, mock_time):
2539
settings = Mock()
2640
result_link = Mock()
2741
# Already expired
2842
result_link.expiryTime = 999
29-
d = downloader.ResultSetDownloadHandler(
30-
settings, result_link, ssl_options=SSLOptions()
31-
)
43+
d = self.create_download_handler(settings, result_link)
3244

3345
with self.assertRaises(Error) as context:
3446
d.run()
@@ -42,9 +54,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
4254
result_link = Mock()
4355
# Within the expiry buffer time
4456
result_link.expiryTime = 1004
45-
d = downloader.ResultSetDownloadHandler(
46-
settings, result_link, ssl_options=SSLOptions()
47-
)
57+
d = self.create_download_handler(settings, result_link)
4858

4959
with self.assertRaises(Error) as context:
5060
d.run()
@@ -62,9 +72,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
6272
settings.use_proxy = False
6373
result_link = Mock(expiryTime=1001)
6474

65-
d = downloader.ResultSetDownloadHandler(
66-
settings, result_link, ssl_options=SSLOptions()
67-
)
75+
d = self.create_download_handler(settings, result_link)
6876
with self.assertRaises(requests.exceptions.HTTPError) as context:
6977
d.run()
7078
self.assertTrue("404" in str(context.exception))
@@ -81,9 +89,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
8189
settings.is_lz4_compressed = False
8290
result_link = Mock(bytesNum=100, expiryTime=1001)
8391

84-
d = downloader.ResultSetDownloadHandler(
85-
settings, result_link, ssl_options=SSLOptions()
86-
)
92+
d = self.create_download_handler(settings, result_link)
8793
file = d.run()
8894

8995
assert file.file_bytes == b"1234567890" * 10
@@ -104,9 +110,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
104110
settings.is_lz4_compressed = True
105111
result_link = Mock(bytesNum=100, expiryTime=1001)
106112

107-
d = downloader.ResultSetDownloadHandler(
108-
settings, result_link, ssl_options=SSLOptions()
109-
)
113+
d = self.create_download_handler(settings, result_link)
110114
file = d.run()
111115

112116
assert file.file_bytes == b"1234567890" * 10
@@ -120,9 +124,7 @@ def test_download_connection_error(self, mock_time, mock_session):
120124
result_link = Mock(bytesNum=100, expiryTime=1001)
121125
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
122126

123-
d = downloader.ResultSetDownloadHandler(
124-
settings, result_link, ssl_options=SSLOptions()
125-
)
127+
d = self.create_download_handler(settings, result_link)
126128
with self.assertRaises(ConnectionError):
127129
d.run()
128130

@@ -135,8 +137,6 @@ def test_download_timeout(self, mock_time, mock_session):
135137
result_link = Mock(bytesNum=100, expiryTime=1001)
136138
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
137139

138-
d = downloader.ResultSetDownloadHandler(
139-
settings, result_link, ssl_options=SSLOptions()
140-
)
140+
d = self.create_download_handler(settings, result_link)
141141
with self.assertRaises(TimeoutError):
142142
d.run()

Comments (0): this commit has no comments.