55
66import databricks .sql .cloudfetch .downloader as downloader
77from databricks .sql .exc import Error
8+ from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
89from databricks .sql .types import SSLOptions
910
1011
@@ -20,15 +21,26 @@ class DownloaderTests(unittest.TestCase):
2021 Unit tests for checking downloader logic.
2122 """
2223
24+ def create_download_handler (
25+ self , settings : Mock , result_link : Mock
26+ ) -> downloader .ResultSetDownloadHandler :
27+ def expiry_callback (link : TSparkArrowResultLink ):
28+ raise Error ("Cloudfetch link has expired" )
29+
30+ return downloader .ResultSetDownloadHandler (
31+ settings ,
32+ result_link ,
33+ ssl_options = SSLOptions (),
34+ expiry_callback = expiry_callback ,
35+ )
36+
2337 @patch ("time.time" , return_value = 1000 )
2438 def test_run_link_expired (self , mock_time ):
2539 settings = Mock ()
2640 result_link = Mock ()
2741 # Already expired
2842 result_link .expiryTime = 999
29- d = downloader .ResultSetDownloadHandler (
30- settings , result_link , ssl_options = SSLOptions ()
31- )
43+ d = self .create_download_handler (settings , result_link )
3244
3345 with self .assertRaises (Error ) as context :
3446 d .run ()
@@ -42,9 +54,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
4254 result_link = Mock ()
4355 # Within the expiry buffer time
4456 result_link .expiryTime = 1004
45- d = downloader .ResultSetDownloadHandler (
46- settings , result_link , ssl_options = SSLOptions ()
47- )
57+ d = self .create_download_handler (settings , result_link )
4858
4959 with self .assertRaises (Error ) as context :
5060 d .run ()
@@ -62,9 +72,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
6272 settings .use_proxy = False
6373 result_link = Mock (expiryTime = 1001 )
6474
65- d = downloader .ResultSetDownloadHandler (
66- settings , result_link , ssl_options = SSLOptions ()
67- )
75+ d = self .create_download_handler (settings , result_link )
6876 with self .assertRaises (requests .exceptions .HTTPError ) as context :
6977 d .run ()
7078 self .assertTrue ("404" in str (context .exception ))
@@ -81,9 +89,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
8189 settings .is_lz4_compressed = False
8290 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
8391
84- d = downloader .ResultSetDownloadHandler (
85- settings , result_link , ssl_options = SSLOptions ()
86- )
92+ d = self .create_download_handler (settings , result_link )
8793 file = d .run ()
8894
8995 assert file .file_bytes == b"1234567890" * 10
@@ -104,9 +110,7 @@ def test_run_compressed_successful(self, mock_time, mock_session):
104110 settings .is_lz4_compressed = True
105111 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
106112
107- d = downloader .ResultSetDownloadHandler (
108- settings , result_link , ssl_options = SSLOptions ()
109- )
113+ d = self .create_download_handler (settings , result_link )
110114 file = d .run ()
111115
112116 assert file .file_bytes == b"1234567890" * 10
@@ -120,9 +124,7 @@ def test_download_connection_error(self, mock_time, mock_session):
120124 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
121125 mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
122126
123- d = downloader .ResultSetDownloadHandler (
124- settings , result_link , ssl_options = SSLOptions ()
125- )
127+ d = self .create_download_handler (settings , result_link )
126128 with self .assertRaises (ConnectionError ):
127129 d .run ()
128130
@@ -135,8 +137,6 @@ def test_download_timeout(self, mock_time, mock_session):
135137 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
136138 mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
137139
138- d = downloader .ResultSetDownloadHandler (
139- settings , result_link , ssl_options = SSLOptions ()
140- )
140+ d = self .create_download_handler (settings , result_link )
141141 with self .assertRaises (TimeoutError ):
142142 d .run ()
0 commit comments