Skip to content

Commit e9040cb

Browse files
committed
test fix
1 parent 7c7b121 commit e9040cb

File tree

1 file changed

+58
-52
lines changed

1 file changed

+58
-52
lines changed

tests/unit/test_downloader.py

Lines changed: 58 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from contextlib import contextmanager
12
import unittest
23
from unittest.mock import Mock, patch, MagicMock
34

45
import requests
56

67
import databricks.sql.cloudfetch.downloader as downloader
8+
from databricks.sql.common.http import DatabricksHttpClient
79
from databricks.sql.exc import Error
810
from databricks.sql.types import SSLOptions
911

@@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response:
1214
result = requests.Response()
1315
for k, v in kwargs.items():
1416
setattr(result, k, v)
17+
result.close = Mock()
1518
return result
1619

1720

@@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time):
5255

5356
mock_time.assert_called_once()
5457

55-
@patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
5658
@patch("time.time", return_value=1000)
57-
def test_run_get_response_not_ok(self, mock_time, mock_session):
58-
mock_session.return_value.get.return_value = create_response(status_code=404)
59-
59+
def test_run_get_response_not_ok(self, mock_time):
60+
http_client = DatabricksHttpClient.get_instance()
6061
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
6162
settings.download_timeout = 0
6263
settings.use_proxy = False
6364
result_link = Mock(expiryTime=1001)
6465

65-
d = downloader.ResultSetDownloadHandler(
66-
settings, result_link, ssl_options=SSLOptions()
67-
)
68-
with self.assertRaises(requests.exceptions.HTTPError) as context:
69-
d.run()
70-
self.assertTrue("404" in str(context.exception))
66+
with patch.object(
67+
http_client,
68+
"execute",
69+
return_value=create_response(status_code=404, _content=b"1234"),
70+
):
71+
d = downloader.ResultSetDownloadHandler(
72+
settings, result_link, ssl_options=SSLOptions()
73+
)
74+
with self.assertRaises(requests.exceptions.HTTPError) as context:
75+
d.run()
76+
self.assertTrue("404" in str(context.exception))
7177

72-
@patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None)))
7378
@patch("time.time", return_value=1000)
74-
def test_run_uncompressed_successful(self, mock_time, mock_session):
79+
def test_run_uncompressed_successful(self, mock_time):
80+
http_client = DatabricksHttpClient.get_instance()
7581
file_bytes = b"1234567890" * 10
76-
mock_session.return_value.get.return_value = create_response(
77-
status_code=200, _content=file_bytes
78-
)
79-
8082
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
8183
settings.is_lz4_compressed = False
8284
result_link = Mock(bytesNum=100, expiryTime=1001)
8385

84-
d = downloader.ResultSetDownloadHandler(
85-
settings, result_link, ssl_options=SSLOptions()
86-
)
87-
file = d.run()
86+
with patch.object(
87+
http_client,
88+
"execute",
89+
return_value=create_response(status_code=200, _content=file_bytes),
90+
):
91+
d = downloader.ResultSetDownloadHandler(
92+
settings, result_link, ssl_options=SSLOptions()
93+
)
94+
file = d.run()
8895

89-
assert file.file_bytes == b"1234567890" * 10
96+
assert file.file_bytes == b"1234567890" * 10
9097

91-
@patch(
92-
"requests.Session",
93-
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))),
94-
)
9598
@patch("time.time", return_value=1000)
96-
def test_run_compressed_successful(self, mock_time, mock_session):
99+
def test_run_compressed_successful(self, mock_time):
100+
http_client = DatabricksHttpClient.get_instance()
97101
file_bytes = b"1234567890" * 10
98102
compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
99-
mock_session.return_value.get.return_value = create_response(
100-
status_code=200, _content=compressed_bytes
101-
)
102103

103104
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
104105
settings.is_lz4_compressed = True
105106
result_link = Mock(bytesNum=100, expiryTime=1001)
107+
with patch.object(
108+
http_client,
109+
"execute",
110+
return_value=create_response(status_code=200, _content=compressed_bytes),
111+
):
112+
d = downloader.ResultSetDownloadHandler(
113+
settings, result_link, ssl_options=SSLOptions()
114+
)
115+
file = d.run()
116+
117+
assert file.file_bytes == b"1234567890" * 10
106118

107-
d = downloader.ResultSetDownloadHandler(
108-
settings, result_link, ssl_options=SSLOptions()
109-
)
110-
file = d.run()
111-
112-
assert file.file_bytes == b"1234567890" * 10
113-
114-
@patch("requests.Session.get", side_effect=ConnectionError("foo"))
115119
@patch("time.time", return_value=1000)
116-
def test_download_connection_error(self, mock_time, mock_session):
120+
def test_download_connection_error(self, mock_time):
121+
122+
http_client = DatabricksHttpClient.get_instance()
117123
settings = Mock(
118124
link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
119125
)
120126
result_link = Mock(bytesNum=100, expiryTime=1001)
121-
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
122127

123-
d = downloader.ResultSetDownloadHandler(
124-
settings, result_link, ssl_options=SSLOptions()
125-
)
126-
with self.assertRaises(ConnectionError):
127-
d.run()
128+
with patch.object(http_client, "execute", side_effect=ConnectionError("foo")):
129+
d = downloader.ResultSetDownloadHandler(
130+
settings, result_link, ssl_options=SSLOptions()
131+
)
132+
with self.assertRaises(ConnectionError):
133+
d.run()
128134

129-
@patch("requests.Session.get", side_effect=TimeoutError("foo"))
130135
@patch("time.time", return_value=1000)
131-
def test_download_timeout(self, mock_time, mock_session):
136+
def test_download_timeout(self, mock_time):
137+
http_client = DatabricksHttpClient.get_instance()
132138
settings = Mock(
133139
link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
134140
)
135141
result_link = Mock(bytesNum=100, expiryTime=1001)
136-
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
137142

138-
d = downloader.ResultSetDownloadHandler(
139-
settings, result_link, ssl_options=SSLOptions()
140-
)
141-
with self.assertRaises(TimeoutError):
142-
d.run()
143+
with patch.object(http_client, "execute", side_effect=TimeoutError("foo")):
144+
d = downloader.ResultSetDownloadHandler(
145+
settings, result_link, ssl_options=SSLOptions()
146+
)
147+
with self.assertRaises(TimeoutError):
148+
d.run()

0 commit comments

Comments
 (0)