@@ -24,6 +24,22 @@ class DownloaderTests(unittest.TestCase):
2424 Unit tests for checking downloader logic.
2525 """
2626
27+ def create_download_handler (
28+ self , settings : Mock , result_link : Mock
29+ ) -> downloader .ResultSetDownloadHandler :
30+ def expiry_callback (link : TSparkArrowResultLink ):
31+ raise Error ("Cloudfetch link has expired" )
32+
33+ return downloader .ResultSetDownloadHandler (
34+ settings ,
35+ result_link ,
36+ ssl_options = SSLOptions (),
37+ expiry_callback = expiry_callback ,
38+ chunk_id = 0 ,
39+ session_id_hex = Mock (),
40+ statement_id = Mock (),
41+ )
42+
2743 def _setup_time_mock_for_download (self , mock_time , end_time ):
2844 """Helper to setup time mock that handles logging system calls."""
2945 call_count = [0 ]
@@ -82,14 +98,7 @@ def test_run_get_response_not_ok(self, mock_time):
8298 with patch .object (http_client , "execute" ) as mock_execute :
8399 mock_execute .return_value .__enter__ .return_value = mock_response
84100
85- d = downloader .ResultSetDownloadHandler (
86- settings ,
87- result_link ,
88- ssl_options = SSLOptions (),
89- chunk_id = 0 ,
90- session_id_hex = Mock (),
91- statement_id = Mock (),
92- )
101+ d = self .create_download_handler (settings , result_link )
93102 with self .assertRaises (requests .exceptions .HTTPError ) as context :
94103 d .run ()
95104 self .assertTrue ("404" in str (context .exception ))
@@ -111,14 +120,7 @@ def test_run_uncompressed_successful(self, mock_time):
111120 "execute" ,
112121 return_value = create_response (status_code = 200 , _content = file_bytes ),
113122 ):
114- d = downloader .ResultSetDownloadHandler (
115- settings ,
116- result_link ,
117- ssl_options = SSLOptions (),
118- chunk_id = 0 ,
119- session_id_hex = Mock (),
120- statement_id = Mock (),
121- )
123+ d = self .create_download_handler (settings , result_link )
122124 file = d .run ()
123125
124126 assert file .file_bytes == b"1234567890" * 10
@@ -141,14 +143,7 @@ def test_run_compressed_successful(self, mock_time):
141143 "execute" ,
142144 return_value = create_response (status_code = 200 , _content = compressed_bytes ),
143145 ):
144- d = downloader .ResultSetDownloadHandler (
145- settings ,
146- result_link ,
147- ssl_options = SSLOptions (),
148- chunk_id = 0 ,
149- session_id_hex = Mock (),
150- statement_id = Mock (),
151- )
146+ d = self .create_download_handler (settings , result_link )
152147 file = d .run ()
153148
154149 assert file .file_bytes == b"1234567890" * 10
@@ -163,14 +158,7 @@ def test_download_connection_error(self, mock_time):
163158 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
164159
165160 with patch .object (http_client , "execute" , side_effect = ConnectionError ("foo" )):
166- d = downloader .ResultSetDownloadHandler (
167- settings ,
168- result_link ,
169- ssl_options = SSLOptions (),
170- chunk_id = 0 ,
171- session_id_hex = Mock (),
172- statement_id = Mock (),
173- )
161+ d = self .create_download_handler (settings , result_link )
174162 with self .assertRaises (ConnectionError ):
175163 d .run ()
176164
@@ -183,13 +171,6 @@ def test_download_timeout(self, mock_time):
183171 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
184172
185173 with patch .object (http_client , "execute" , side_effect = TimeoutError ("foo" )):
186- d = downloader .ResultSetDownloadHandler (
187- settings ,
188- result_link ,
189- ssl_options = SSLOptions (),
190- chunk_id = 0 ,
191- session_id_hex = Mock (),
192- statement_id = Mock (),
193- )
174+ d = self .create_download_handler (settings , result_link )
194175 with self .assertRaises (TimeoutError ):
195176 d .run ()
0 commit comments