Skip to content

Commit e3ee4e4

Browse files
Move arrow_schema_bytes back into ExecuteResponse
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 139e246 commit e3ee4e4

File tree

4 files changed

+116
-53
lines changed

4 files changed

+116
-53
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 20 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -801,18 +801,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
801801
if status is None:
802802
raise ValueError(f"Unknown command state: {operation_state}")
803803

804-
return (
805-
ExecuteResponse(
806-
command_id=command_id,
807-
status=status,
808-
description=description,
809-
has_more_rows=has_more_rows,
810-
results_queue=arrow_queue_opt,
811-
has_been_closed_server_side=has_been_closed_server_side,
812-
lz4_compressed=lz4_compressed,
813-
is_staging_operation=is_staging_operation,
814-
),
815-
schema_bytes,
804+
return ExecuteResponse(
805+
command_id=command_id,
806+
status=status,
807+
description=description,
808+
has_more_rows=has_more_rows,
809+
results_queue=arrow_queue_opt,
810+
has_been_closed_server_side=has_been_closed_server_side,
811+
lz4_compressed=lz4_compressed,
812+
is_staging_operation=is_staging_operation,
813+
arrow_schema_bytes=schema_bytes,
816814
)
817815

818816
def get_execution_result(
@@ -877,6 +875,7 @@ def get_execution_result(
877875
has_been_closed_server_side=False,
878876
lz4_compressed=lz4_compressed,
879877
is_staging_operation=is_staging_operation,
878+
arrow_schema_bytes=schema_bytes,
880879
)
881880

882881
return ThriftResultSet(
@@ -886,7 +885,6 @@ def get_execution_result(
886885
buffer_size_bytes=cursor.buffer_size_bytes,
887886
arraysize=cursor.arraysize,
888887
use_cloud_fetch=cursor.connection.use_cloud_fetch,
889-
arrow_schema_bytes=schema_bytes,
890888
)
891889

892890
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -999,9 +997,7 @@ def execute_command(
999997
self._handle_execute_response_async(resp, cursor)
1000998
return None
1001999
else:
1002-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1003-
resp, cursor
1004-
)
1000+
execute_response = self._handle_execute_response(resp, cursor)
10051001

10061002
return ThriftResultSet(
10071003
connection=cursor.connection,
@@ -1010,7 +1006,6 @@ def execute_command(
10101006
buffer_size_bytes=max_bytes,
10111007
arraysize=max_rows,
10121008
use_cloud_fetch=use_cloud_fetch,
1013-
arrow_schema_bytes=arrow_schema_bytes,
10141009
)
10151010

10161011
def get_catalogs(
@@ -1032,9 +1027,7 @@ def get_catalogs(
10321027
)
10331028
resp = self.make_request(self._client.GetCatalogs, req)
10341029

1035-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1036-
resp, cursor
1037-
)
1030+
execute_response = self._handle_execute_response(resp, cursor)
10381031

10391032
return ThriftResultSet(
10401033
connection=cursor.connection,
@@ -1043,7 +1036,6 @@ def get_catalogs(
10431036
buffer_size_bytes=max_bytes,
10441037
arraysize=max_rows,
10451038
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1046-
arrow_schema_bytes=arrow_schema_bytes,
10471039
)
10481040

10491041
def get_schemas(
@@ -1069,9 +1061,7 @@ def get_schemas(
10691061
)
10701062
resp = self.make_request(self._client.GetSchemas, req)
10711063

1072-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1073-
resp, cursor
1074-
)
1064+
execute_response = self._handle_execute_response(resp, cursor)
10751065

10761066
return ThriftResultSet(
10771067
connection=cursor.connection,
@@ -1080,7 +1070,6 @@ def get_schemas(
10801070
buffer_size_bytes=max_bytes,
10811071
arraysize=max_rows,
10821072
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1083-
arrow_schema_bytes=arrow_schema_bytes,
10841073
)
10851074

10861075
def get_tables(
@@ -1110,9 +1099,7 @@ def get_tables(
11101099
)
11111100
resp = self.make_request(self._client.GetTables, req)
11121101

1113-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1114-
resp, cursor
1115-
)
1102+
execute_response = self._handle_execute_response(resp, cursor)
11161103

11171104
return ThriftResultSet(
11181105
connection=cursor.connection,
@@ -1121,7 +1108,6 @@ def get_tables(
11211108
buffer_size_bytes=max_bytes,
11221109
arraysize=max_rows,
11231110
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1124-
arrow_schema_bytes=arrow_schema_bytes,
11251111
)
11261112

11271113
def get_columns(
@@ -1151,9 +1137,7 @@ def get_columns(
11511137
)
11521138
resp = self.make_request(self._client.GetColumns, req)
11531139

1154-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1155-
resp, cursor
1156-
)
1140+
execute_response = self._handle_execute_response(resp, cursor)
11571141

11581142
return ThriftResultSet(
11591143
connection=cursor.connection,
@@ -1162,7 +1146,6 @@ def get_columns(
11621146
buffer_size_bytes=max_bytes,
11631147
arraysize=max_rows,
11641148
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1165-
arrow_schema_bytes=arrow_schema_bytes,
11661149
)
11671150

11681151
def _handle_execute_response(self, resp, cursor):
@@ -1176,11 +1159,10 @@ def _handle_execute_response(self, resp, cursor):
11761159
resp.directResults and resp.directResults.operationStatus,
11771160
)
11781161

1179-
(
1180-
execute_response,
1181-
arrow_schema_bytes,
1182-
) = self._results_message_to_execute_response(resp, final_operation_state)
1183-
return execute_response, arrow_schema_bytes
1162+
execute_response = self._results_message_to_execute_response(
1163+
resp, final_operation_state
1164+
)
1165+
return execute_response
11841166

11851167
def _handle_execute_response_async(self, resp, cursor):
11861168
command_id = CommandId.from_thrift_handle(resp.operationHandle)

src/databricks/sql/backend/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -431,3 +431,4 @@ class ExecuteResponse:
431431
has_been_closed_server_side: bool = False
432432
lz4_compressed: bool = True
433433
is_staging_operation: bool = False
434+
arrow_schema_bytes: Optional[bytes] = None

src/databricks/sql/result_set.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,6 @@ def __init__(
157157
buffer_size_bytes: int = 104857600,
158158
arraysize: int = 10000,
159159
use_cloud_fetch: bool = True,
160-
arrow_schema_bytes: Optional[bytes] = None,
161160
):
162161
"""
163162
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -169,10 +168,9 @@ def __init__(
169168
buffer_size_bytes: Buffer size for fetching results
170169
arraysize: Default number of rows to fetch
171170
use_cloud_fetch: Whether to use cloud fetch for retrieving results
172-
arrow_schema_bytes: Arrow schema bytes for the result set
173171
"""
174172
# Initialize ThriftResultSet-specific attributes
175-
self._arrow_schema_bytes = arrow_schema_bytes
173+
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
176174
self._use_cloud_fetch = use_cloud_fetch
177175
self.lz4_compressed = execute_response.lz4_compressed
178176

tests/unit/test_thrift_backend.py

Lines changed: 94 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,13 @@
1919
from databricks.sql.auth.authenticators import AuthProvider
2020
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
2121
from databricks.sql.result_set import ResultSet, ThriftResultSet
22-
from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType
22+
from databricks.sql.backend.types import (
23+
CommandId,
24+
CommandState,
25+
SessionId,
26+
BackendType,
27+
ExecuteResponse,
28+
)
2329

2430

2531
def retry_policy_factory():
@@ -651,7 +657,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(
651657
ssl_options=SSLOptions(),
652658
)
653659

654-
execute_response, _ = thrift_backend._handle_execute_response(
660+
execute_response = thrift_backend._handle_execute_response(
655661
t_execute_resp, Mock()
656662
)
657663
self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
@@ -885,7 +891,7 @@ def test_handle_execute_response_can_handle_without_direct_results(
885891
auth_provider=AuthProvider(),
886892
ssl_options=SSLOptions(),
887893
)
888-
execute_response, _ = thrift_backend._handle_execute_response(
894+
execute_response = thrift_backend._handle_execute_response(
889895
execute_resp, Mock()
890896
)
891897

@@ -963,11 +969,11 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
963969
t_get_result_set_metadata_resp
964970
)
965971
thrift_backend = self._make_fake_thrift_backend()
966-
execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response(
972+
execute_response = thrift_backend._handle_execute_response(
967973
t_execute_resp, Mock()
968974
)
969975

970-
self.assertEqual(arrow_schema_bytes, arrow_schema_mock)
976+
self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock)
971977

972978
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
973979
def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
@@ -1040,7 +1046,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10401046
)
10411047
thrift_backend = self._make_fake_thrift_backend()
10421048

1043-
execute_response, _ = thrift_backend._handle_execute_response(
1049+
execute_response = thrift_backend._handle_execute_response(
10441050
execute_resp, Mock()
10451051
)
10461052

@@ -1172,7 +1178,20 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11721178
auth_provider=AuthProvider(),
11731179
ssl_options=SSLOptions(),
11741180
)
1175-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1181+
thrift_backend._handle_execute_response = Mock(
1182+
return_value=Mock(
1183+
spec=ExecuteResponse,
1184+
command_id=Mock(),
1185+
status=Mock(),
1186+
description=Mock(),
1187+
has_more_rows=Mock(),
1188+
results_queue=Mock(),
1189+
has_been_closed_server_side=Mock(),
1190+
lz4_compressed=Mock(),
1191+
is_staging_operation=Mock(),
1192+
arrow_schema_bytes=Mock(),
1193+
)
1194+
)
11761195
cursor_mock = Mock()
11771196

11781197
result = thrift_backend.execute_command(
@@ -1206,7 +1225,20 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
12061225
auth_provider=AuthProvider(),
12071226
ssl_options=SSLOptions(),
12081227
)
1209-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1228+
thrift_backend._handle_execute_response = Mock(
1229+
return_value=Mock(
1230+
spec=ExecuteResponse,
1231+
command_id=Mock(),
1232+
status=Mock(),
1233+
description=Mock(),
1234+
has_more_rows=Mock(),
1235+
results_queue=Mock(),
1236+
has_been_closed_server_side=Mock(),
1237+
lz4_compressed=Mock(),
1238+
is_staging_operation=Mock(),
1239+
arrow_schema_bytes=Mock(),
1240+
)
1241+
)
12101242
cursor_mock = Mock()
12111243

12121244
result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
@@ -1237,7 +1269,20 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12371269
auth_provider=AuthProvider(),
12381270
ssl_options=SSLOptions(),
12391271
)
1240-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1272+
thrift_backend._handle_execute_response = Mock(
1273+
return_value=Mock(
1274+
spec=ExecuteResponse,
1275+
command_id=Mock(),
1276+
status=Mock(),
1277+
description=Mock(),
1278+
has_more_rows=Mock(),
1279+
results_queue=Mock(),
1280+
has_been_closed_server_side=Mock(),
1281+
lz4_compressed=Mock(),
1282+
is_staging_operation=Mock(),
1283+
arrow_schema_bytes=Mock(),
1284+
)
1285+
)
12411286
cursor_mock = Mock()
12421287

12431288
result = thrift_backend.get_schemas(
@@ -1277,7 +1322,20 @@ def test_get_tables_calls_client_and_handle_execute_response(
12771322
auth_provider=AuthProvider(),
12781323
ssl_options=SSLOptions(),
12791324
)
1280-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1325+
thrift_backend._handle_execute_response = Mock(
1326+
return_value=Mock(
1327+
spec=ExecuteResponse,
1328+
command_id=Mock(),
1329+
status=Mock(),
1330+
description=Mock(),
1331+
has_more_rows=Mock(),
1332+
results_queue=Mock(),
1333+
has_been_closed_server_side=Mock(),
1334+
lz4_compressed=Mock(),
1335+
is_staging_operation=Mock(),
1336+
arrow_schema_bytes=Mock(),
1337+
)
1338+
)
12811339
cursor_mock = Mock()
12821340

12831341
result = thrift_backend.get_tables(
@@ -1321,7 +1379,20 @@ def test_get_columns_calls_client_and_handle_execute_response(
13211379
auth_provider=AuthProvider(),
13221380
ssl_options=SSLOptions(),
13231381
)
1324-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1382+
thrift_backend._handle_execute_response = Mock(
1383+
return_value=Mock(
1384+
spec=ExecuteResponse,
1385+
command_id=Mock(),
1386+
status=Mock(),
1387+
description=Mock(),
1388+
has_more_rows=Mock(),
1389+
results_queue=Mock(),
1390+
has_been_closed_server_side=Mock(),
1391+
lz4_compressed=Mock(),
1392+
is_staging_operation=Mock(),
1393+
arrow_schema_bytes=Mock(),
1394+
)
1395+
)
13251396
cursor_mock = Mock()
13261397

13271398
result = thrift_backend.get_columns(
@@ -2229,7 +2300,18 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
22292300
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
22302301
@patch(
22312302
"databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response",
2232-
return_value=(Mock(), Mock()),
2303+
return_value=Mock(
2304+
spec=ExecuteResponse,
2305+
command_id=Mock(),
2306+
status=Mock(),
2307+
description=Mock(),
2308+
has_more_rows=Mock(),
2309+
results_queue=Mock(),
2310+
has_been_closed_server_side=Mock(),
2311+
lz4_compressed=Mock(),
2312+
is_staging_operation=Mock(),
2313+
arrow_schema_bytes=Mock(),
2314+
),
22332315
)
22342316
def test_execute_command_sets_complex_type_fields_correctly(
22352317
self, mock_handle_execute_response, tcli_service_class

0 commit comments

Comments
 (0)