Skip to content

Commit 62ed2a2

Browse files
description, partial results (small fixes)
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent aab8ce5 commit 62ed2a2

File tree

9 files changed

+78
-45
lines changed

9 files changed

+78
-45
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,14 @@ def __init__(
364364
# Initialize table and position
365365
self.table = self._create_next_table()
366366

367-
def _create_next_table(self) -> Union["pyarrow.Table", None]:
367+
def _create_next_table(self) -> "pyarrow.Table":
368368
"""Create next table by retrieving the logical next downloaded file."""
369369
if self.link_fetcher is None:
370-
return None
370+
return self._create_empty_table()
371371

372372
chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
373373
if chunk_link is None:
374-
return None
374+
return self._create_empty_table()
375375

376376
row_offset = chunk_link.row_offset
377377
# NOTE: link has already been submitted to download manager at this point

src/databricks/sql/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
275275

276276
logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
277277
results = self.table.slice(0, 0)
278+
partial_result_chunks = [results]
278279
while num_rows > 0 and self.table.num_rows > 0:
279280
# Replace current table with the next table if we are at the end of the current table
280281
if self.table_row_index == self.table.num_rows:
@@ -300,6 +301,7 @@ def remaining_rows(self) -> "pyarrow.Table":
300301
"""
301302

302303
results = self.table.slice(0, 0)
304+
partial_result_chunks = [results]
303305
while self.table.num_rows > 0:
304306
table_slice = self.table.slice(
305307
self.table_row_index, self.table.num_rows - self.table_row_index
@@ -386,6 +388,8 @@ def __init__(
386388
chunk_id=chunk_id,
387389
)
388390

391+
self.num_links_downloaded = 0
392+
389393
self.start_row_index = start_row_offset
390394
self.result_links = result_links or []
391395
self.session_id_hex = session_id_hex

tests/unit/test_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def new(cls):
4646
is_staging_operation=False,
4747
command_id=None,
4848
has_been_closed_server_side=True,
49-
is_direct_results=True,
49+
has_more_rows=True,
5050
lz4_compressed=True,
5151
arrow_schema_bytes=b"schema",
5252
)
@@ -266,9 +266,7 @@ def test_negative_fetch_throws_exception(self):
266266
mock_backend = Mock()
267267
mock_backend.fetch_results.return_value = (Mock(), False, 0)
268268

269-
result_set = ThriftResultSet(
270-
Mock(), Mock(), mock_backend
271-
)
269+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
272270

273271
with self.assertRaises(ValueError) as e:
274272
result_set.fetchmany(-1)

tests/unit/test_cloud_fetch_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def test_initializer_adds_links(self, mock_create_next_table):
6969
session_id_hex=Mock(),
7070
statement_id=Mock(),
7171
chunk_id=0,
72+
description=description,
7273
)
7374

7475
assert (

tests/unit/test_downloader.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,26 @@ def test_run_get_response_not_ok(self, mock_time):
7373
settings.use_proxy = False
7474
result_link = Mock(expiryTime=1001)
7575

76-
d = downloader.ResultSetDownloadHandler(
77-
settings,
78-
result_link,
79-
ssl_options=SSLOptions(),
80-
chunk_id=0,
81-
session_id_hex=Mock(),
82-
statement_id=Mock(),
76+
# Create a mock response with 404 status
77+
mock_response = create_response(status_code=404, _content=b"Not Found")
78+
mock_response.raise_for_status = Mock(
79+
side_effect=requests.exceptions.HTTPError("404")
8380
)
84-
with self.assertRaises(requests.exceptions.HTTPError) as context:
85-
d.run()
86-
self.assertTrue("404" in str(context.exception))
81+
82+
with patch.object(http_client, "execute") as mock_execute:
83+
mock_execute.return_value.__enter__.return_value = mock_response
84+
85+
d = downloader.ResultSetDownloadHandler(
86+
settings,
87+
result_link,
88+
ssl_options=SSLOptions(),
89+
chunk_id=0,
90+
session_id_hex=Mock(),
91+
statement_id=Mock(),
92+
)
93+
with self.assertRaises(requests.exceptions.HTTPError) as context:
94+
d.run()
95+
self.assertTrue("404" in str(context.exception))
8796

8897
@patch("time.time", return_value=1000)
8998
def test_run_uncompressed_successful(self, mock_time):

tests/unit/test_fetches_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table):
3636
execute_response=ExecuteResponse(
3737
status=None,
3838
has_been_closed_server_side=True,
39-
is_direct_results=False,
39+
has_more_rows=False,
4040
description=Mock(),
4141
command_id=None,
4242
arrow_schema_bytes=arrow_table.schema,

tests/unit/test_sea_queue.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
import threading
2828
import time
2929

30+
try:
31+
import pyarrow as pa
32+
except ImportError:
33+
pa = None
34+
3035

3136
class TestJsonQueue:
3237
"""Test suite for the JsonQueue class."""

tests/unit/test_telemetry_retry.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
77
from databricks.sql.auth.retry import DatabricksRetryPolicy
88

9-
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+
PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011

1112
def create_mock_conn(responses):
1213
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617
mock_http_response = MagicMock()
1718
mock_http_response.status = resp.get("status")
1819
mock_http_response.headers = resp.get("headers", {})
19-
body = resp.get("body", b'{}')
20+
body = resp.get("body", b"{}")
2021
mock_http_response.fp = io.BytesIO(body)
22+
2123
def release():
2224
mock_http_response.fp.close()
25+
2326
mock_http_response.release_conn = release
2427
mock_http_responses.append(mock_http_response)
2528
mock_conn.getresponse.side_effect = mock_http_responses
2629
return mock_conn
2730

31+
2832
class TestTelemetryClientRetries:
2933
@pytest.fixture(autouse=True)
3034
def setup_and_teardown(self):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953
host_url="test.databricks.com",
5054
)
5155
client = TelemetryClientFactory.get_telemetry_client(session_id)
52-
56+
5357
retry_policy = DatabricksRetryPolicy(
5458
delay_min=0.01,
5559
delay_max=0.02,
5660
stop_after_attempts_duration=2.0,
57-
stop_after_attempts_count=num_retries,
61+
stop_after_attempts_count=num_retries,
5862
delay_default=0.1,
5963
force_dangerous_codes=[],
60-
urllib3_kwargs={'total': num_retries}
64+
urllib3_kwargs={"total": num_retries},
6165
)
6266
adapter = client._http_client.session.adapters.get("https://")
6367
adapter.max_retries = retry_policy
6468
return client
6569

6670
@pytest.mark.parametrize(
67-
"status_code, description",
68-
[
69-
(401, "Unauthorized"),
70-
(403, "Forbidden"),
71-
(501, "Not Implemented"),
72-
(200, "Success"),
73-
],
71+
"status_code, description",
72+
[
73+
(401, "Unauthorized"),
74+
(403, "Forbidden"),
75+
(501, "Not Implemented"),
76+
(200, "Success"),
77+
],
7478
)
7579
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
7680
"""
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084
client = self.get_client(f"session-{status_code}")
8185
mock_responses = [{"status": status_code}]
8286

83-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
87+
with patch(
88+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
89+
) as mock_get_conn:
8490
client.export_failure_log("TestError", "Test message")
8591
TelemetryClientFactory.close(client._session_id_hex)
8692

@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399
"""
94100
num_retries = 3
95-
expected_total_calls = num_retries + 1
101+
expected_total_calls = num_retries + 1
96102
retry_after = 1
97103
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98-
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99-
100-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
104+
mock_responses = [
105+
{"status": 503, "headers": {"Retry-After": str(retry_after)}},
106+
{"status": 429},
107+
{"status": 502},
108+
{"status": 503},
109+
]
110+
111+
with patch(
112+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
113+
) as mock_get_conn:
101114
start_time = time.time()
102115
client.export_failure_log("TestError", "Test message")
103116
TelemetryClientFactory.close(client._session_id_hex)
104117
end_time = time.time()
105-
106-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107-
assert end_time - start_time > retry_after
118+
119+
assert (
120+
mock_get_conn.return_value.getresponse.call_count
121+
== expected_total_calls
122+
)
123+
assert end_time - start_time > retry_after

tests/unit/test_thrift_backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
10161016
def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10171017
self, tcli_service_class, build_queue
10181018
):
1019-
for is_direct_results, resp_type in itertools.product(
1019+
for has_more_rows, resp_type in itertools.product(
10201020
[True, False], self.execute_response_types
10211021
):
1022-
with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
1022+
with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
10231023
tcli_service_instance = tcli_service_class.return_value
10241024
results_mock = Mock()
10251025
results_mock.startRowOffset = 0
@@ -1031,7 +1031,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10311031
resultSetMetadata=self.metadata_resp,
10321032
resultSet=ttypes.TFetchResultsResp(
10331033
status=self.okay_status,
1034-
hasMoreRows=is_direct_results,
1034+
hasMoreRows=has_more_rows,
10351035
results=results_mock,
10361036
),
10371037
closeOperation=Mock(),
@@ -1062,10 +1062,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10621062
def test_handle_execute_response_reads_has_more_rows_in_result_response(
10631063
self, tcli_service_class, build_queue
10641064
):
1065-
for is_direct_results, resp_type in itertools.product(
1065+
for has_more_rows, resp_type in itertools.product(
10661066
[True, False], self.execute_response_types
10671067
):
1068-
with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
1068+
with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
10691069
tcli_service_instance = tcli_service_class.return_value
10701070
results_mock = MagicMock()
10711071
results_mock.startRowOffset = 0
@@ -1078,7 +1078,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
10781078

10791079
fetch_results_resp = ttypes.TFetchResultsResp(
10801080
status=self.okay_status,
1081-
hasMoreRows=is_direct_results,
1081+
hasMoreRows=has_more_rows,
10821082
results=results_mock,
10831083
resultSetMetadata=ttypes.TGetResultSetMetadataResp(
10841084
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET
@@ -1112,7 +1112,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
11121112
chunk_id=0,
11131113
)
11141114

1115-
self.assertEqual(is_direct_results, has_more_rows_resp)
1115+
self.assertEqual(has_more_rows, has_more_rows_resp)
11161116

11171117
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
11181118
def test_arrow_batches_row_count_are_respected(self, tcli_service_class):

0 commit comments

Comments
 (0)