Skip to content

Commit 12a5ff8

Browse files
fix merge artifacts
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent ea8ae9f commit 12a5ff8

File tree

7 files changed

+79
-41
lines changed

7 files changed

+79
-41
lines changed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,43 @@ def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
7070
result_set: Original result set to copy manifest from
7171
new_row_count: New total row count for filtered data
7272
73-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
73+
Returns:
74+
Updated manifest copy
75+
"""
76+
filtered_manifest = deepcopy(result_set.manifest)
77+
filtered_manifest.total_row_count = new_row_count
78+
return filtered_manifest
79+
80+
@staticmethod
81+
def _create_filtered_result_set(
82+
result_set: SeaResultSet,
83+
result_data: ResultData,
84+
row_count: int,
85+
) -> "SeaResultSet":
86+
"""
87+
Create a new filtered SeaResultSet with the provided data.
88+
89+
Args:
90+
result_set: Original result set to copy parameters from
91+
result_data: New result data for the filtered set
92+
row_count: Number of rows in the filtered data
93+
94+
Returns:
95+
New filtered SeaResultSet
96+
"""
7497
from databricks.sql.backend.sea.result_set import SeaResultSet
7598

76-
# Create a new SeaResultSet with the filtered data
77-
manifest = result_set.manifest
78-
manifest.total_row_count = len(filtered_rows)
99+
execute_response = ResultSetFilter._create_execute_response(result_set)
100+
filtered_manifest = ResultSetFilter._create_filtered_manifest(
101+
result_set, row_count
102+
)
79103

80-
filtered_result_set = SeaResultSet(
104+
return SeaResultSet(
81105
connection=result_set.connection,
82106
execute_response=execute_response,
83107
sea_client=cast(SeaDatabricksClient, result_set.backend),
84108
result_data=result_data,
85-
manifest=manifest,
109+
manifest=filtered_manifest,
86110
buffer_size_bytes=result_set.buffer_size_bytes,
87111
arraysize=result_set.arraysize,
88112
)

src/databricks/sql/backend/thrift_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,6 @@ def get_execution_result(
895895
connection=cursor.connection,
896896
execute_response=execute_response,
897897
thrift_client=self,
898-
session_id_hex=self._session_id_hex,
899898
buffer_size_bytes=cursor.buffer_size_bytes,
900899
arraysize=cursor.arraysize,
901900
use_cloud_fetch=cursor.connection.use_cloud_fetch,

tests/unit/test_client.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
127127
connection=connection,
128128
execute_response=mock_execute_response,
129129
thrift_client=mock_backend,
130-
session_id_hex=Mock(),
131130
)
132131

133132
# Mock execute_command to return our real result set
@@ -187,7 +186,6 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
187186
connection=mock_connection,
188187
execute_response=Mock(),
189188
thrift_client=mock_backend,
190-
session_id_hex=Mock(),
191189
)
192190
result_set.results = mock_results
193191

@@ -217,7 +215,6 @@ def test_closing_result_set_hard_closes_commands(self):
217215
mock_connection,
218216
mock_results_response,
219217
mock_thrift_backend,
220-
session_id_hex=Mock(),
221218
)
222219
result_set.results = mock_results
223220

@@ -265,9 +262,7 @@ def test_negative_fetch_throws_exception(self):
265262
mock_backend = Mock()
266263
mock_backend.fetch_results.return_value = (Mock(), False, 0)
267264

268-
result_set = ThriftResultSet(
269-
Mock(), Mock(), mock_backend
270-
)
265+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
271266

272267
with self.assertRaises(ValueError) as e:
273268
result_set.fetchmany(-1)

tests/unit/test_downloader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ class DownloaderTests(unittest.TestCase):
2626
def _setup_time_mock_for_download(self, mock_time, end_time):
2727
"""Helper to setup time mock that handles logging system calls."""
2828
call_count = [0]
29+
2930
def time_side_effect():
3031
call_count[0] += 1
3132
if call_count[0] <= 2: # First two calls (validation, start_time)
3233
return 1000
3334
else: # All subsequent calls (logging, duration calculation)
3435
return end_time
36+
3537
mock_time.side_effect = time_side_effect
3638

3739
@patch("time.time", return_value=1000)
@@ -104,7 +106,7 @@ def test_run_get_response_not_ok(self, mock_time):
104106
@patch("time.time")
105107
def test_run_uncompressed_successful(self, mock_time):
106108
self._setup_time_mock_for_download(mock_time, 1000.5)
107-
109+
108110
http_client = DatabricksHttpClient.get_instance()
109111
file_bytes = b"1234567890" * 10
110112
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
@@ -133,7 +135,7 @@ def test_run_uncompressed_successful(self, mock_time):
133135
@patch("time.time")
134136
def test_run_compressed_successful(self, mock_time):
135137
self._setup_time_mock_for_download(mock_time, 1000.2)
136-
138+
137139
http_client = DatabricksHttpClient.get_instance()
138140
file_bytes = b"1234567890" * 10
139141
compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

tests/unit/test_fetches.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def make_dummy_result_set_from_initial_results(initial_results):
6363
),
6464
thrift_client=mock_thrift_backend,
6565
t_row_set=None,
66-
session_id_hex=Mock(),
6766
)
6867
return rs
6968

@@ -108,7 +107,6 @@ def fetch_results(
108107
is_staging_operation=False,
109108
),
110109
thrift_client=mock_thrift_backend,
111-
session_id_hex=Mock(),
112110
)
113111
return rs
114112

tests/unit/test_filters.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,21 +132,25 @@ def test_filter_tables_by_type(self):
132132
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types)
133133
args, kwargs = mock_filter.call_args
134134
self.assertEqual(args[0], self.mock_sea_result_set)
135-
self.assertEqual(args[1], 5) # Table type column index
136-
self.assertEqual(args[2], table_types)
135+
self.assertEqual(kwargs.get("column_index"), 5) # Table type column index
136+
self.assertEqual(kwargs.get("allowed_values"), table_types)
137137
self.assertEqual(kwargs.get("case_sensitive"), True)
138138

139139
# Case 2: Default table types (None or empty list)
140140
with patch.object(ResultSetFilter, "_filter_json_result_set") as mock_filter:
141141
# Test with None
142142
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
143143
args, kwargs = mock_filter.call_args
144-
self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
144+
self.assertEqual(
145+
kwargs.get("allowed_values"), ["TABLE", "VIEW", "SYSTEM TABLE"]
146+
)
145147

146148
# Test with empty list
147149
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, [])
148150
args, kwargs = mock_filter.call_args
149-
self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
151+
self.assertEqual(
152+
kwargs.get("allowed_values"), ["TABLE", "VIEW", "SYSTEM TABLE"]
153+
)
150154

151155

152156
if __name__ == "__main__":

tests/unit/test_telemetry_retry.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
77
from databricks.sql.auth.retry import DatabricksRetryPolicy
88

9-
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+
PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011

1112
def create_mock_conn(responses):
1213
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617
mock_http_response = MagicMock()
1718
mock_http_response.status = resp.get("status")
1819
mock_http_response.headers = resp.get("headers", {})
19-
body = resp.get("body", b'{}')
20+
body = resp.get("body", b"{}")
2021
mock_http_response.fp = io.BytesIO(body)
22+
2123
def release():
2224
mock_http_response.fp.close()
25+
2326
mock_http_response.release_conn = release
2427
mock_http_responses.append(mock_http_response)
2528
mock_conn.getresponse.side_effect = mock_http_responses
2629
return mock_conn
2730

31+
2832
class TestTelemetryClientRetries:
2933
@pytest.fixture(autouse=True)
3034
def setup_and_teardown(self):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953
host_url="test.databricks.com",
5054
)
5155
client = TelemetryClientFactory.get_telemetry_client(session_id)
52-
56+
5357
retry_policy = DatabricksRetryPolicy(
5458
delay_min=0.01,
5559
delay_max=0.02,
5660
stop_after_attempts_duration=2.0,
57-
stop_after_attempts_count=num_retries,
61+
stop_after_attempts_count=num_retries,
5862
delay_default=0.1,
5963
force_dangerous_codes=[],
60-
urllib3_kwargs={'total': num_retries}
64+
urllib3_kwargs={"total": num_retries},
6165
)
6266
adapter = client._http_client.session.adapters.get("https://")
6367
adapter.max_retries = retry_policy
6468
return client
6569

6670
@pytest.mark.parametrize(
67-
"status_code, description",
68-
[
69-
(401, "Unauthorized"),
70-
(403, "Forbidden"),
71-
(501, "Not Implemented"),
72-
(200, "Success"),
73-
],
71+
"status_code, description",
72+
[
73+
(401, "Unauthorized"),
74+
(403, "Forbidden"),
75+
(501, "Not Implemented"),
76+
(200, "Success"),
77+
],
7478
)
7579
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
7680
"""
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084
client = self.get_client(f"session-{status_code}")
8185
mock_responses = [{"status": status_code}]
8286

83-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
87+
with patch(
88+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
89+
) as mock_get_conn:
8490
client.export_failure_log("TestError", "Test message")
8591
TelemetryClientFactory.close(client._session_id_hex)
8692

@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399
"""
94100
num_retries = 3
95-
expected_total_calls = num_retries + 1
101+
expected_total_calls = num_retries + 1
96102
retry_after = 1
97103
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98-
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99-
100-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
104+
mock_responses = [
105+
{"status": 503, "headers": {"Retry-After": str(retry_after)}},
106+
{"status": 429},
107+
{"status": 502},
108+
{"status": 503},
109+
]
110+
111+
with patch(
112+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
113+
) as mock_get_conn:
101114
start_time = time.time()
102115
client.export_failure_log("TestError", "Test message")
103116
TelemetryClientFactory.close(client._session_id_hex)
104117
end_time = time.time()
105-
106-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107-
assert end_time - start_time > retry_after
118+
119+
assert (
120+
mock_get_conn.return_value.getresponse.call_count
121+
== expected_total_calls
122+
)
123+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)