Skip to content

Commit a8004a0

Browse files
fix tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 8cf118f commit a8004a0

File tree

12 files changed

+126
-102
lines changed

12 files changed

+126
-102
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,15 @@ def _filter_sea_result_set(
6969
command_id=command_id,
7070
status=result_set.status,
7171
description=result_set.description,
72-
has_more_rows=result_set.has_more_rows,
7372
has_been_closed_server_side=result_set.has_been_closed_server_side,
74-
lz4_compressed=False,
73+
lz4_compressed=result_set.lz4_compressed,
74+
arrow_schema_bytes=result_set.arrow_schema_bytes,
7575
is_staging_operation=False,
7676
)
7777

7878
# Create a new ResultData object with filtered data
7979
from databricks.sql.backend.sea.models.base import ResultData
80-
80+
8181
result_data = ResultData(data=filtered_rows, external_links=None)
8282

8383
# Create a new SeaResultSet with the filtered data
@@ -89,7 +89,7 @@ def _filter_sea_result_set(
8989
arraysize=result_set.arraysize,
9090
result_data=result_data,
9191
)
92-
92+
9393
return filtered_result_set
9494

9595
@staticmethod

src/databricks/sql/backend/sea/backend.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import re
55
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
66

7-
from databricks.sql.backend.sea.utils.constants import ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
7+
from databricks.sql.backend.sea.utils.constants import (
8+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
9+
)
810

911
if TYPE_CHECKING:
1012
from databricks.sql.client import Cursor
@@ -72,6 +74,8 @@ def _filter_session_configuration(
7274
)
7375

7476
return filtered_session_configuration
77+
78+
7579
class SeaDatabricksClient(DatabricksClient):
7680
"""
7781
Statement Execution API (SEA) implementation of the DatabricksClient interface.
@@ -89,7 +93,7 @@ class SeaDatabricksClient(DatabricksClient):
8993
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
9094
CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks"
9195
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
92-
96+
9397
def __init__(
9498
self,
9599
server_hostname: str,
@@ -502,7 +506,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
502506
# Initialize result_data_obj and manifest_obj
503507
result_data_obj = None
504508
manifest_obj = None
505-
509+
506510
result_data = sea_response.get("result", {})
507511
if result_data:
508512
# Convert external links
@@ -554,7 +558,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
554558
arrow_schema_bytes=schema_bytes,
555559
result_format=manifest_data.get("format"),
556560
)
557-
561+
558562
return execute_response, result_data_obj, manifest_obj
559563

560564
def execute_command(
@@ -776,9 +780,11 @@ def get_execution_result(
776780
from databricks.sql.result_set import SeaResultSet
777781

778782
# Convert the response to an ExecuteResponse and extract result data
779-
execute_response, result_data, manifest = self._results_message_to_execute_response(
780-
response_data, command_id
781-
)
783+
(
784+
execute_response,
785+
result_data,
786+
manifest,
787+
) = self._results_message_to_execute_response(response_data, command_id)
782788

783789
return SeaResultSet(
784790
connection=cursor.connection,

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
4040
)
4141

4242
state = CommandState.from_sea_state(status_data.get("state", ""))
43-
if state is None:
43+
if state is None:
4444
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
4545

4646
status = StatementStatus(
@@ -124,7 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
124124
)
125125

126126
state = CommandState.from_sea_state(status_data.get("state", ""))
127-
if state is None:
127+
if state is None:
128128
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
129129

130130
status = StatementStatus(

src/databricks/sql/backend/thrift_backend.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,12 +1179,7 @@ def _handle_execute_response(self, resp, cursor):
11791179
resp.directResults and resp.directResults.operationStatus,
11801180
)
11811181

1182-
(
1183-
execute_response,
1184-
arrow_schema_bytes,
1185-
) = self._results_message_to_execute_response(resp, final_operation_state)
1186-
execute_response.command_id = command_id
1187-
return execute_response, arrow_schema_bytes
1182+
return self._results_message_to_execute_response(resp, final_operation_state)
11881183

11891184
def _handle_execute_response_async(self, resp, cursor):
11901185
command_id = CommandId.from_thrift_handle(resp.operationHandle)

src/databricks/sql/result_set.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import pandas
77

88
from databricks.sql.backend.sea.backend import SeaDatabricksClient
9-
from databricks.sql.backend.sea.models.base import ExternalLink, ResultData, ResultManifest
9+
from databricks.sql.backend.sea.models.base import (
10+
ExternalLink,
11+
ResultData,
12+
ResultManifest,
13+
)
1014
from databricks.sql.cloud_fetch_queue import SeaCloudFetchQueue
1115
from databricks.sql.utils import SeaResultSetQueueFactory
1216

@@ -44,10 +48,11 @@ def __init__(
4448
command_id: CommandId,
4549
status: CommandState,
4650
has_been_closed_server_side: bool = False,
47-
has_more_rows: bool = False,
4851
results_queue=None,
4952
description=None,
5053
is_staging_operation: bool = False,
54+
lz4_compressed: bool = False,
55+
arrow_schema_bytes: bytes = b"",
5156
):
5257
"""
5358
A ResultSet manages the results of a single command.
@@ -75,9 +80,10 @@ def __init__(
7580
self.command_id = command_id
7681
self.status = status
7782
self.has_been_closed_server_side = has_been_closed_server_side
78-
self.has_more_rows = has_more_rows
7983
self.results = results_queue
8084
self._is_staging_operation = is_staging_operation
85+
self.lz4_compressed = lz4_compressed
86+
self._arrow_schema_bytes = arrow_schema_bytes
8187

8288
def __iter__(self):
8389
while True:
@@ -181,9 +187,8 @@ def __init__(
181187
has_more_rows: Whether there are more rows to fetch
182188
"""
183189
# Initialize ThriftResultSet-specific attributes
184-
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
185190
self._use_cloud_fetch = use_cloud_fetch
186-
self.lz4_compressed = execute_response.lz4_compressed
191+
self.has_more_rows = has_more_rows
187192

188193
# Build the results queue if t_row_set is provided
189194
results_queue = None
@@ -210,10 +215,11 @@ def __init__(
210215
command_id=execute_response.command_id,
211216
status=execute_response.status,
212217
has_been_closed_server_side=execute_response.has_been_closed_server_side,
213-
has_more_rows=has_more_rows,
214218
results_queue=results_queue,
215219
description=execute_response.description,
216220
is_staging_operation=execute_response.is_staging_operation,
221+
lz4_compressed=execute_response.lz4_compressed,
222+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
217223
)
218224

219225
# Initialize results queue if not provided
@@ -442,6 +448,7 @@ def map_col_type(type_):
442448
for column in table_schema_message.columns
443449
]
444450

451+
445452
class SeaResultSet(ResultSet):
446453
"""ResultSet implementation for SEA backend."""
447454

@@ -473,24 +480,26 @@ def __init__(
473480
if execute_response.command_id
474481
else None
475482
)
476-
483+
477484
# Build the results queue
478485
results_queue = None
479-
486+
480487
if result_data:
481488
from typing import cast, List
482-
489+
483490
# Convert description to the expected format
484491
desc = None
485492
if execute_response.description:
486493
desc = cast(List[Tuple[Any, ...]], execute_response.description)
487-
494+
488495
results_queue = SeaResultSetQueueFactory.build_queue(
489496
result_data,
490497
manifest,
491498
str(self.statement_id),
492499
description=desc,
493-
schema_bytes=execute_response.arrow_schema_bytes if execute_response.arrow_schema_bytes else None,
500+
schema_bytes=execute_response.arrow_schema_bytes
501+
if execute_response.arrow_schema_bytes
502+
else None,
494503
max_download_threads=sea_client.max_download_threads,
495504
ssl_options=sea_client.ssl_options,
496505
sea_client=sea_client,
@@ -506,9 +515,10 @@ def __init__(
506515
command_id=execute_response.command_id,
507516
status=execute_response.status,
508517
has_been_closed_server_side=execute_response.has_been_closed_server_side,
509-
has_more_rows=execute_response.has_more_rows,
510518
description=execute_response.description,
511519
is_staging_operation=execute_response.is_staging_operation,
520+
lz4_compressed=execute_response.lz4_compressed,
521+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
512522
)
513523

514524
# Initialize queue for result data if not provided

src/databricks/sql/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,7 @@ def build_queue(
177177
"SEA client is required for EXTERNAL_LINKS disposition"
178178
)
179179
if not manifest:
180-
raise ValueError(
181-
"Manifest is required for EXTERNAL_LINKS disposition"
182-
)
180+
raise ValueError("Manifest is required for EXTERNAL_LINKS disposition")
183181

184182
return SeaCloudFetchQueue(
185183
initial_links=sea_result_data.external_links,

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class):
566566
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
567567
def test_staging_operation_response_is_handled(
568568
self,
569-
mock_thrift_client_class,
569+
mock_client_class,
570570
mock_handle_staging_operation,
571571
mock_execute_response,
572572
):

tests/unit/test_fetches_bench.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ def make_dummy_result_set_from_initial_results(arrow_table):
3636
execute_response=ExecuteResponse(
3737
status=None,
3838
has_been_closed_server_side=True,
39-
has_more_rows=False,
4039
description=Mock(),
4140
command_id=None,
42-
arrow_queue=arrow_queue,
43-
arrow_schema=arrow_table.schema,
41+
arrow_schema_bytes=arrow_table.schema,
4442
),
4543
)
4644
rs.description = [

tests/unit/test_parameters.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ class TestSessionHandleChecks(object):
6464
],
6565
)
6666
def test_get_protocol_version_fallback_behavior(self, test_input, expected):
67-
properties = (
68-
{"serverProtocolVersion": test_input.serverProtocolVersion}
69-
if test_input.serverProtocolVersion
70-
else {}
71-
)
72-
session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties)
73-
assert Connection.get_protocol_version(session_id) == expected
67+
assert Connection.get_protocol_version(test_input) == expected
7468

7569
@pytest.mark.parametrize(
7670
"test_input,expected",

0 commit comments

Comments
 (0)