Skip to content

Commit ba91138

Browse files
Merge branch 'exec-resp-norm' into sea-res-set
2 parents 65e7c6b + c04d583 commit ba91138

File tree

10 files changed: 237 additions (+237) and 190 deletions (-190).

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
@dataclass
66
class CreateSessionRequest:
7-
"""Request to create a new session."""
7+
"""Representation of a request to create a new session."""
88

99
warehouse_id: str
1010
session_confs: Optional[Dict[str, str]] = None
@@ -29,7 +29,7 @@ def to_dict(self) -> Dict[str, Any]:
2929

3030
@dataclass
3131
class DeleteSessionRequest:
32-
"""Request to delete a session."""
32+
"""Representation of a request to delete a session."""
3333

3434
warehouse_id: str
3535
session_id: str

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
@dataclass
66
class CreateSessionResponse:
7-
"""Response from creating a new session."""
7+
"""Representation of the response from creating a new session."""
88

99
session_id: str
1010

src/databricks/sql/backend/thrift_backend.py

Lines changed: 72 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -3,22 +3,20 @@
33
import logging
44
import math
55
import time
6-
import uuid
76
import threading
87
from typing import List, Union, Any, TYPE_CHECKING
98

109
if TYPE_CHECKING:
1110
from databricks.sql.client import Cursor
1211

13-
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1412
from databricks.sql.backend.types import (
1513
CommandState,
1614
SessionId,
1715
CommandId,
18-
BackendType,
19-
guid_to_hex_id,
2016
ExecuteResponse,
2117
)
18+
from databricks.sql.backend.utils import guid_to_hex_id
19+
2220

2321
try:
2422
import pyarrow
@@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759757
)
760758
direct_results = resp.directResults
761759
has_been_closed_server_side = direct_results and direct_results.closeOperation
762-
has_more_rows = (
760+
761+
is_direct_results = (
763762
(not direct_results)
764763
or (not direct_results.resultSet)
765764
or direct_results.resultSet.hasMoreRows
766765
)
766+
767767
description = self._hive_schema_to_description(
768768
t_result_set_metadata_resp.schema
769769
)
@@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779779
schema_bytes = None
780780

781781
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
782-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
783-
if direct_results and direct_results.resultSet:
784-
assert direct_results.resultSet.results.startRowOffset == 0
785-
assert direct_results.resultSetMetadata
786-
787-
arrow_queue_opt = ResultSetQueueFactory.build_queue(
788-
row_set_type=t_result_set_metadata_resp.resultFormat,
789-
t_row_set=direct_results.resultSet.results,
790-
arrow_schema_bytes=schema_bytes,
791-
max_download_threads=self.max_download_threads,
792-
lz4_compressed=lz4_compressed,
793-
description=description,
794-
ssl_options=self._ssl_options,
795-
)
796-
else:
797-
arrow_queue_opt = None
798-
799782
command_id = CommandId.from_thrift_handle(resp.operationHandle)
800783

801784
status = CommandState.from_thrift_state(operation_state)
802785
if status is None:
803786
raise ValueError(f"Unknown command state: {operation_state}")
804787

805-
return (
806-
ExecuteResponse(
807-
command_id=command_id,
808-
status=status,
809-
description=description,
810-
has_more_rows=has_more_rows,
811-
results_queue=arrow_queue_opt,
812-
has_been_closed_server_side=has_been_closed_server_side,
813-
lz4_compressed=lz4_compressed,
814-
is_staging_operation=is_staging_operation,
815-
),
816-
schema_bytes,
788+
execute_response = ExecuteResponse(
789+
command_id=command_id,
790+
status=status,
791+
description=description,
792+
has_been_closed_server_side=has_been_closed_server_side,
793+
lz4_compressed=lz4_compressed,
794+
is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
795+
arrow_schema_bytes=schema_bytes,
796+
result_format=t_result_set_metadata_resp.resultFormat,
817797
)
818798

799+
return execute_response, is_direct_results
800+
819801
def get_execution_result(
820802
self, command_id: CommandId, cursor: "Cursor"
821803
) -> "ResultSet":
@@ -840,9 +822,6 @@ def get_execution_result(
840822

841823
t_result_set_metadata_resp = resp.resultSetMetadata
842824

843-
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
844-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
845-
has_more_rows = resp.hasMoreRows
846825
description = self._hive_schema_to_description(
847826
t_result_set_metadata_resp.schema
848827
)
@@ -857,27 +836,21 @@ def get_execution_result(
857836
else:
858837
schema_bytes = None
859838

860-
queue = ResultSetQueueFactory.build_queue(
861-
row_set_type=resp.resultSetMetadata.resultFormat,
862-
t_row_set=resp.results,
863-
arrow_schema_bytes=schema_bytes,
864-
max_download_threads=self.max_download_threads,
865-
lz4_compressed=lz4_compressed,
866-
description=description,
867-
ssl_options=self._ssl_options,
868-
)
839+
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
840+
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
841+
is_direct_results = resp.hasMoreRows
869842

870843
status = self.get_query_state(command_id)
871844

872845
execute_response = ExecuteResponse(
873846
command_id=command_id,
874847
status=status,
875848
description=description,
876-
has_more_rows=has_more_rows,
877-
results_queue=queue,
878849
has_been_closed_server_side=False,
879850
lz4_compressed=lz4_compressed,
880851
is_staging_operation=is_staging_operation,
852+
arrow_schema_bytes=schema_bytes,
853+
result_format=t_result_set_metadata_resp.resultFormat,
881854
)
882855

883856
return ThriftResultSet(
@@ -887,7 +860,10 @@ def get_execution_result(
887860
buffer_size_bytes=cursor.buffer_size_bytes,
888861
arraysize=cursor.arraysize,
889862
use_cloud_fetch=cursor.connection.use_cloud_fetch,
890-
arrow_schema_bytes=schema_bytes,
863+
t_row_set=resp.results,
864+
max_download_threads=self.max_download_threads,
865+
ssl_options=self._ssl_options,
866+
is_direct_results=is_direct_results,
891867
)
892868

893869
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
918894
self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
919895
state = CommandState.from_thrift_state(operation_state)
920896
if state is None:
921-
raise ValueError(f"Invalid operation state: {operation_state}")
897+
raise ValueError(f"Unknown command state: {operation_state}")
922898
return state
923899

924900
@staticmethod
@@ -1000,18 +976,25 @@ def execute_command(
1000976
self._handle_execute_response_async(resp, cursor)
1001977
return None
1002978
else:
1003-
execute_response, arrow_schema_bytes = self._handle_execute_response(
979+
execute_response, is_direct_results = self._handle_execute_response(
1004980
resp, cursor
1005981
)
1006982

983+
t_row_set = None
984+
if resp.directResults and resp.directResults.resultSet:
985+
t_row_set = resp.directResults.resultSet.results
986+
1007987
return ThriftResultSet(
1008988
connection=cursor.connection,
1009989
execute_response=execute_response,
1010990
thrift_client=self,
1011991
buffer_size_bytes=max_bytes,
1012992
arraysize=max_rows,
1013993
use_cloud_fetch=use_cloud_fetch,
1014-
arrow_schema_bytes=arrow_schema_bytes,
994+
t_row_set=t_row_set,
995+
max_download_threads=self.max_download_threads,
996+
ssl_options=self._ssl_options,
997+
is_direct_results=is_direct_results,
1015998
)
1016999

10171000
def get_catalogs(
@@ -1033,18 +1016,25 @@ def get_catalogs(
10331016
)
10341017
resp = self.make_request(self._client.GetCatalogs, req)
10351018

1036-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1019+
execute_response, is_direct_results = self._handle_execute_response(
10371020
resp, cursor
10381021
)
10391022

1023+
t_row_set = None
1024+
if resp.directResults and resp.directResults.resultSet:
1025+
t_row_set = resp.directResults.resultSet.results
1026+
10401027
return ThriftResultSet(
10411028
connection=cursor.connection,
10421029
execute_response=execute_response,
10431030
thrift_client=self,
10441031
buffer_size_bytes=max_bytes,
10451032
arraysize=max_rows,
10461033
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1047-
arrow_schema_bytes=arrow_schema_bytes,
1034+
t_row_set=t_row_set,
1035+
max_download_threads=self.max_download_threads,
1036+
ssl_options=self._ssl_options,
1037+
is_direct_results=is_direct_results,
10481038
)
10491039

10501040
def get_schemas(
@@ -1070,18 +1060,25 @@ def get_schemas(
10701060
)
10711061
resp = self.make_request(self._client.GetSchemas, req)
10721062

1073-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1063+
execute_response, is_direct_results = self._handle_execute_response(
10741064
resp, cursor
10751065
)
10761066

1067+
t_row_set = None
1068+
if resp.directResults and resp.directResults.resultSet:
1069+
t_row_set = resp.directResults.resultSet.results
1070+
10771071
return ThriftResultSet(
10781072
connection=cursor.connection,
10791073
execute_response=execute_response,
10801074
thrift_client=self,
10811075
buffer_size_bytes=max_bytes,
10821076
arraysize=max_rows,
10831077
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1084-
arrow_schema_bytes=arrow_schema_bytes,
1078+
t_row_set=t_row_set,
1079+
max_download_threads=self.max_download_threads,
1080+
ssl_options=self._ssl_options,
1081+
is_direct_results=is_direct_results,
10851082
)
10861083

10871084
def get_tables(
@@ -1111,18 +1108,25 @@ def get_tables(
11111108
)
11121109
resp = self.make_request(self._client.GetTables, req)
11131110

1114-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1111+
execute_response, is_direct_results = self._handle_execute_response(
11151112
resp, cursor
11161113
)
11171114

1115+
t_row_set = None
1116+
if resp.directResults and resp.directResults.resultSet:
1117+
t_row_set = resp.directResults.resultSet.results
1118+
11181119
return ThriftResultSet(
11191120
connection=cursor.connection,
11201121
execute_response=execute_response,
11211122
thrift_client=self,
11221123
buffer_size_bytes=max_bytes,
11231124
arraysize=max_rows,
11241125
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1125-
arrow_schema_bytes=arrow_schema_bytes,
1126+
t_row_set=t_row_set,
1127+
max_download_threads=self.max_download_threads,
1128+
ssl_options=self._ssl_options,
1129+
is_direct_results=is_direct_results,
11261130
)
11271131

11281132
def get_columns(
@@ -1152,18 +1156,25 @@ def get_columns(
11521156
)
11531157
resp = self.make_request(self._client.GetColumns, req)
11541158

1155-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1159+
execute_response, is_direct_results = self._handle_execute_response(
11561160
resp, cursor
11571161
)
11581162

1163+
t_row_set = None
1164+
if resp.directResults and resp.directResults.resultSet:
1165+
t_row_set = resp.directResults.resultSet.results
1166+
11591167
return ThriftResultSet(
11601168
connection=cursor.connection,
11611169
execute_response=execute_response,
11621170
thrift_client=self,
11631171
buffer_size_bytes=max_bytes,
11641172
arraysize=max_rows,
11651173
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1166-
arrow_schema_bytes=arrow_schema_bytes,
1174+
t_row_set=t_row_set,
1175+
max_download_threads=self.max_download_threads,
1176+
ssl_options=self._ssl_options,
1177+
is_direct_results=is_direct_results,
11671178
)
11681179

11691180
def _handle_execute_response(self, resp, cursor):
@@ -1177,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor):
11771188
resp.directResults and resp.directResults.operationStatus,
11781189
)
11791190

1180-
(
1181-
execute_response,
1182-
arrow_schema_bytes,
1183-
) = self._results_message_to_execute_response(resp, final_operation_state)
1184-
return execute_response, arrow_schema_bytes
1191+
return self._results_message_to_execute_response(resp, final_operation_state)
11851192

11861193
def _handle_execute_response_async(self, resp, cursor):
11871194
command_id = CommandId.from_thrift_handle(resp.operationHandle)

src/databricks/sql/backend/types.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -423,11 +423,9 @@ class ExecuteResponse:
423423

424424
command_id: CommandId
425425
status: CommandState
426-
description: Optional[
427-
List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
428-
] = None
429-
has_more_rows: bool = False
430-
results_queue: Optional[Any] = None
426+
description: Optional[List[Tuple]] = None
431427
has_been_closed_server_side: bool = False
432428
lz4_compressed: bool = True
433429
is_staging_operation: bool = False
430+
arrow_schema_bytes: Optional[bytes] = None
431+
result_format: Optional[Any] = None

0 commit comments

Comments
 (0)