33import logging
44import math
55import time
6- import uuid
76import threading
87from typing import List , Union , Any , TYPE_CHECKING
98
109if TYPE_CHECKING :
1110 from databricks .sql .client import Cursor
1211
13- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
1412from databricks .sql .backend .types import (
1513 CommandState ,
1614 SessionId ,
1715 CommandId ,
18- BackendType ,
19- guid_to_hex_id ,
2016 ExecuteResponse ,
2117)
18+ from databricks .sql .backend .utils import guid_to_hex_id
19+
2220
2321try :
2422 import pyarrow
@@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759757 )
760758 direct_results = resp .directResults
761759 has_been_closed_server_side = direct_results and direct_results .closeOperation
762- has_more_rows = (
760+
761+ is_direct_results = (
763762 (not direct_results )
764763 or (not direct_results .resultSet )
765764 or direct_results .resultSet .hasMoreRows
766765 )
766+
767767 description = self ._hive_schema_to_description (
768768 t_result_set_metadata_resp .schema
769769 )
@@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779779 schema_bytes = None
780780
781781 lz4_compressed = t_result_set_metadata_resp .lz4Compressed
782- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
783- if direct_results and direct_results .resultSet :
784- assert direct_results .resultSet .results .startRowOffset == 0
785- assert direct_results .resultSetMetadata
786-
787- arrow_queue_opt = ResultSetQueueFactory .build_queue (
788- row_set_type = t_result_set_metadata_resp .resultFormat ,
789- t_row_set = direct_results .resultSet .results ,
790- arrow_schema_bytes = schema_bytes ,
791- max_download_threads = self .max_download_threads ,
792- lz4_compressed = lz4_compressed ,
793- description = description ,
794- ssl_options = self ._ssl_options ,
795- )
796- else :
797- arrow_queue_opt = None
798-
799782 command_id = CommandId .from_thrift_handle (resp .operationHandle )
800783
801784 status = CommandState .from_thrift_state (operation_state )
802785 if status is None :
803786 raise ValueError (f"Unknown command state: { operation_state } " )
804787
805- return (
806- ExecuteResponse (
807- command_id = command_id ,
808- status = status ,
809- description = description ,
810- has_more_rows = has_more_rows ,
811- results_queue = arrow_queue_opt ,
812- has_been_closed_server_side = has_been_closed_server_side ,
813- lz4_compressed = lz4_compressed ,
814- is_staging_operation = is_staging_operation ,
815- ),
816- schema_bytes ,
788+ execute_response = ExecuteResponse (
789+ command_id = command_id ,
790+ status = status ,
791+ description = description ,
792+ has_been_closed_server_side = has_been_closed_server_side ,
793+ lz4_compressed = lz4_compressed ,
794+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
795+ arrow_schema_bytes = schema_bytes ,
796+ result_format = t_result_set_metadata_resp .resultFormat ,
817797 )
818798
799+ return execute_response , is_direct_results
800+
819801 def get_execution_result (
820802 self , command_id : CommandId , cursor : "Cursor"
821803 ) -> "ResultSet" :
@@ -840,9 +822,6 @@ def get_execution_result(
840822
841823 t_result_set_metadata_resp = resp .resultSetMetadata
842824
843- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
844- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
845- has_more_rows = resp .hasMoreRows
846825 description = self ._hive_schema_to_description (
847826 t_result_set_metadata_resp .schema
848827 )
@@ -857,27 +836,21 @@ def get_execution_result(
857836 else :
858837 schema_bytes = None
859838
860- queue = ResultSetQueueFactory .build_queue (
861- row_set_type = resp .resultSetMetadata .resultFormat ,
862- t_row_set = resp .results ,
863- arrow_schema_bytes = schema_bytes ,
864- max_download_threads = self .max_download_threads ,
865- lz4_compressed = lz4_compressed ,
866- description = description ,
867- ssl_options = self ._ssl_options ,
868- )
839+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841+ is_direct_results = resp .hasMoreRows
869842
870843 status = self .get_query_state (command_id )
871844
872845 execute_response = ExecuteResponse (
873846 command_id = command_id ,
874847 status = status ,
875848 description = description ,
876- has_more_rows = has_more_rows ,
877- results_queue = queue ,
878849 has_been_closed_server_side = False ,
879850 lz4_compressed = lz4_compressed ,
880851 is_staging_operation = is_staging_operation ,
852+ arrow_schema_bytes = schema_bytes ,
853+ result_format = t_result_set_metadata_resp .resultFormat ,
881854 )
882855
883856 return ThriftResultSet (
@@ -887,7 +860,10 @@ def get_execution_result(
887860 buffer_size_bytes = cursor .buffer_size_bytes ,
888861 arraysize = cursor .arraysize ,
889862 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
890- arrow_schema_bytes = schema_bytes ,
863+ t_row_set = resp .results ,
864+ max_download_threads = self .max_download_threads ,
865+ ssl_options = self ._ssl_options ,
866+ is_direct_results = is_direct_results ,
891867 )
892868
893869 def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
918894 self ._check_command_not_in_error_or_closed_state (thrift_handle , poll_resp )
919895 state = CommandState .from_thrift_state (operation_state )
920896 if state is None :
921- raise ValueError (f"Invalid operation state: { operation_state } " )
897+ raise ValueError (f"Unknown command state: { operation_state } " )
922898 return state
923899
924900 @staticmethod
@@ -1000,18 +976,25 @@ def execute_command(
1000976 self ._handle_execute_response_async (resp , cursor )
1001977 return None
1002978 else :
1003- execute_response , arrow_schema_bytes = self ._handle_execute_response (
979+ execute_response , is_direct_results = self ._handle_execute_response (
1004980 resp , cursor
1005981 )
1006982
983+ t_row_set = None
984+ if resp .directResults and resp .directResults .resultSet :
985+ t_row_set = resp .directResults .resultSet .results
986+
1007987 return ThriftResultSet (
1008988 connection = cursor .connection ,
1009989 execute_response = execute_response ,
1010990 thrift_client = self ,
1011991 buffer_size_bytes = max_bytes ,
1012992 arraysize = max_rows ,
1013993 use_cloud_fetch = use_cloud_fetch ,
1014- arrow_schema_bytes = arrow_schema_bytes ,
994+ t_row_set = t_row_set ,
995+ max_download_threads = self .max_download_threads ,
996+ ssl_options = self ._ssl_options ,
997+ is_direct_results = is_direct_results ,
1015998 )
1016999
10171000 def get_catalogs (
@@ -1033,18 +1016,25 @@ def get_catalogs(
10331016 )
10341017 resp = self .make_request (self ._client .GetCatalogs , req )
10351018
1036- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1019+ execute_response , is_direct_results = self ._handle_execute_response (
10371020 resp , cursor
10381021 )
10391022
1023+ t_row_set = None
1024+ if resp .directResults and resp .directResults .resultSet :
1025+ t_row_set = resp .directResults .resultSet .results
1026+
10401027 return ThriftResultSet (
10411028 connection = cursor .connection ,
10421029 execute_response = execute_response ,
10431030 thrift_client = self ,
10441031 buffer_size_bytes = max_bytes ,
10451032 arraysize = max_rows ,
10461033 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1047- arrow_schema_bytes = arrow_schema_bytes ,
1034+ t_row_set = t_row_set ,
1035+ max_download_threads = self .max_download_threads ,
1036+ ssl_options = self ._ssl_options ,
1037+ is_direct_results = is_direct_results ,
10481038 )
10491039
10501040 def get_schemas (
@@ -1070,18 +1060,25 @@ def get_schemas(
10701060 )
10711061 resp = self .make_request (self ._client .GetSchemas , req )
10721062
1073- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1063+ execute_response , is_direct_results = self ._handle_execute_response (
10741064 resp , cursor
10751065 )
10761066
1067+ t_row_set = None
1068+ if resp .directResults and resp .directResults .resultSet :
1069+ t_row_set = resp .directResults .resultSet .results
1070+
10771071 return ThriftResultSet (
10781072 connection = cursor .connection ,
10791073 execute_response = execute_response ,
10801074 thrift_client = self ,
10811075 buffer_size_bytes = max_bytes ,
10821076 arraysize = max_rows ,
10831077 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1084- arrow_schema_bytes = arrow_schema_bytes ,
1078+ t_row_set = t_row_set ,
1079+ max_download_threads = self .max_download_threads ,
1080+ ssl_options = self ._ssl_options ,
1081+ is_direct_results = is_direct_results ,
10851082 )
10861083
10871084 def get_tables (
@@ -1111,18 +1108,25 @@ def get_tables(
11111108 )
11121109 resp = self .make_request (self ._client .GetTables , req )
11131110
1114- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1111+ execute_response , is_direct_results = self ._handle_execute_response (
11151112 resp , cursor
11161113 )
11171114
1115+ t_row_set = None
1116+ if resp .directResults and resp .directResults .resultSet :
1117+ t_row_set = resp .directResults .resultSet .results
1118+
11181119 return ThriftResultSet (
11191120 connection = cursor .connection ,
11201121 execute_response = execute_response ,
11211122 thrift_client = self ,
11221123 buffer_size_bytes = max_bytes ,
11231124 arraysize = max_rows ,
11241125 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1125- arrow_schema_bytes = arrow_schema_bytes ,
1126+ t_row_set = t_row_set ,
1127+ max_download_threads = self .max_download_threads ,
1128+ ssl_options = self ._ssl_options ,
1129+ is_direct_results = is_direct_results ,
11261130 )
11271131
11281132 def get_columns (
@@ -1152,18 +1156,25 @@ def get_columns(
11521156 )
11531157 resp = self .make_request (self ._client .GetColumns , req )
11541158
1155- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1159+ execute_response , is_direct_results = self ._handle_execute_response (
11561160 resp , cursor
11571161 )
11581162
1163+ t_row_set = None
1164+ if resp .directResults and resp .directResults .resultSet :
1165+ t_row_set = resp .directResults .resultSet .results
1166+
11591167 return ThriftResultSet (
11601168 connection = cursor .connection ,
11611169 execute_response = execute_response ,
11621170 thrift_client = self ,
11631171 buffer_size_bytes = max_bytes ,
11641172 arraysize = max_rows ,
11651173 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1166- arrow_schema_bytes = arrow_schema_bytes ,
1174+ t_row_set = t_row_set ,
1175+ max_download_threads = self .max_download_threads ,
1176+ ssl_options = self ._ssl_options ,
1177+ is_direct_results = is_direct_results ,
11671178 )
11681179
11691180 def _handle_execute_response (self , resp , cursor ):
@@ -1177,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor):
11771188 resp .directResults and resp .directResults .operationStatus ,
11781189 )
11791190
1180- (
1181- execute_response ,
1182- arrow_schema_bytes ,
1183- ) = self ._results_message_to_execute_response (resp , final_operation_state )
1184- return execute_response , arrow_schema_bytes
1191+ return self ._results_message_to_execute_response (resp , final_operation_state )
11851192
11861193 def _handle_execute_response_async (self , resp , cursor ):
11871194 command_id = CommandId .from_thrift_handle (resp .operationHandle )
0 commit comments