Skip to content

Commit 925b2a3

Browse files
committed
Missed adding some files in previous commit
1 parent a174370 commit 925b2a3

File tree

3 files changed

+134
-11
lines changed

3 files changed

+134
-11
lines changed

src/databricks/sql/client.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def execute(
733733
self,
734734
operation: str,
735735
parameters: Optional[TParameterCollection] = None,
736-
perform_async = True
736+
perform_async = False
737737
) -> "Cursor":
738738
"""
739739
Execute a query and wait for execution to complete.
@@ -814,10 +814,37 @@ def execute(
814814

815815
return self
816816

817-
def executeAsync(self,
817+
def execute_async(self,
818818
operation: str,
819819
parameters: Optional[TParameterCollection] = None,):
820-
return execute(operation, parameters, True)
820+
return self.execute(operation, parameters, True)
821+
822+
def get_query_status(self):
823+
self._check_not_closed()
824+
return self.thrift_backend.get_query_status(self.active_op_handle)
825+
826+
def get_execution_result(self):
827+
self._check_not_closed()
828+
829+
operation_state = self.get_query_status()
830+
if operation_state.statusCode == ttypes.TStatusCode.SUCCESS_STATUS or operation_state.statusCode == ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS:
831+
execute_response=self.thrift_backend.get_execution_result(self.active_op_handle)
832+
self.active_result_set = ResultSet(
833+
self.connection,
834+
execute_response,
835+
self.thrift_backend,
836+
self.buffer_size_bytes,
837+
self.arraysize,
838+
)
839+
840+
if execute_response.is_staging_operation:
841+
self._handle_staging_operation(
842+
staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
843+
)
844+
845+
return self
846+
else:
847+
raise Error(f"get_execution_result failed with status code {operation_state.statusCode}")
821848

822849
def executemany(self, operation, seq_of_parameters):
823850
"""
@@ -1126,7 +1153,7 @@ def __init__(
11261153
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
11271154
self._next_row_index = 0
11281155

1129-
if execute_response.arrow_queue:
1156+
if execute_response.arrow_queue or True:
11301157
# In this case the server has taken the fast path and returned an initial batch of
11311158
# results
11321159
self.results = execute_response.arrow_queue

src/databricks/sql/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from databricks.sql.thrift_api.TCLIService import ttypes
2+
3+
class QueryExecutionStatus:
4+
INITIALIZED_STATE=ttypes.TOperationState.INITIALIZED_STATE
5+
RUNNING_STATE = ttypes.TOperationState.RUNNING_STATE
6+
FINISHED_STATE = ttypes.TOperationState.FINISHED_STATE
7+
CANCELED_STATE = ttypes.TOperationState.CANCELED_STATE
8+
CLOSED_STATE = ttypes.TOperationState.CLOSED_STATE
9+
ERROR_STATE = ttypes.TOperationState.ERROR_STATE
10+
UKNOWN_STATE = ttypes.TOperationState.UKNOWN_STATE
11+
PENDING_STATE = ttypes.TOperationState.PENDING_STATE
12+
TIMEDOUT_STATE = ttypes.TOperationState.TIMEDOUT_STATE

src/databricks/sql/thrift_backend.py

Lines changed: 91 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,66 @@ def _results_message_to_execute_response(self, resp, operation_state):
769769
arrow_schema_bytes=schema_bytes,
770770
)
771771

772+
def get_execution_result(self, op_handle):
773+
774+
assert op_handle is not None
775+
776+
req = ttypes.TFetchResultsReq(
777+
operationHandle=ttypes.TOperationHandle(
778+
op_handle.operationId,
779+
op_handle.operationType,
780+
False,
781+
op_handle.modifiedRowCount,
782+
),
783+
maxRows=max_rows,
784+
maxBytes=max_bytes,
785+
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
786+
includeResultSetMetadata=True,
787+
)
788+
789+
resp = self.make_request(self._client.FetchResults, req)
790+
791+
t_result_set_metadata_resp = resp.resultSetMetaData
792+
793+
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
794+
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
795+
has_more_rows = resp.hasMoreRows
796+
description = self._hive_schema_to_description(
797+
t_result_set_metadata_resp.schema
798+
)
799+
800+
if pyarrow:
801+
schema_bytes = (
802+
t_result_set_metadata_resp.arrowSchema
803+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
804+
.serialize()
805+
.to_pybytes()
806+
)
807+
else:
808+
schema_bytes = None
809+
810+
queue = ResultSetQueueFactory.build_queue(
811+
row_set_type=resp.resultSetMetadata.resultFormat,
812+
t_row_set=resp.results,
813+
arrow_schema_bytes=schema_bytes,
814+
max_download_threads=self.max_download_threads,
815+
lz4_compressed=lz4_compressed,
816+
description=description,
817+
ssl_options=self._ssl_options,
818+
)
819+
820+
return ExecuteResponse(
821+
arrow_queue=queue,
822+
status=resp.status,
823+
has_been_closed_server_side=has_been_closed_server_side,
824+
has_more_rows=has_more_rows,
825+
lz4_compressed=lz4_compressed,
826+
is_staging_operation=is_staging_operation,
827+
command_handle=resp.operationHandle,
828+
description=description,
829+
arrow_schema_bytes=schema_bytes,
830+
)
831+
772832
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
773833
if initial_operation_status_resp:
774834
self._check_command_not_in_error_or_closed_state(
@@ -787,6 +847,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
787847
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
788848
return operation_state
789849

850+
def get_query_status(self, op_handle):
851+
poll_resp = self._poll_for_status(op_handle)
852+
operation_state = poll_resp.status
853+
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
854+
return operation_state
855+
790856
@staticmethod
791857
def _check_direct_results_for_error(t_spark_direct_results):
792858
if t_spark_direct_results:
@@ -848,7 +914,10 @@ def execute_command(
848914
)
849915
resp = self.make_request(self._client.ExecuteStatement, req)
850916

851-
return self._handle_execute_response(resp, cursor, perform_async)
917+
if perform_async:
918+
return self._handle_execute_response_async(resp, cursor)
919+
else:
920+
return self._handle_execute_response(resp, cursor)
852921

853922
def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
854923
assert session_handle is not None
@@ -936,19 +1005,34 @@ def get_columns(
9361005
resp = self.make_request(self._client.GetColumns, req)
9371006
return self._handle_execute_response(resp, cursor)
9381007

939-
def _handle_execute_response(self, resp, cursor, perform_async=False):
1008+
def _handle_execute_response(self, resp, cursor):
9401009
cursor.active_op_handle = resp.operationHandle
9411010
self._check_direct_results_for_error(resp.directResults)
9421011

943-
if perform_async:
944-
final_operation_state=ttypes.TStatusCode.STILL_EXECUTING_STATUS
945-
else:
946-
final_operation_state=self._wait_until_command_done(
1012+
final_operation_state = self._wait_until_command_done(
9471013
resp.operationHandle,
948-
resp.directResults and resp.directResults.operationStatus)
1014+
resp.directResults and resp.directResults.operationStatus,
1015+
)
9491016

9501017
return self._results_message_to_execute_response(resp, final_operation_state)
9511018

1019+
def _handle_execute_response_async(self, resp, cursor):
1020+
cursor.active_op_handle = resp.operationHandle
1021+
self._check_direct_results_for_error(resp.directResults)
1022+
operation_status = resp.status.statusCode
1023+
1024+
return ExecuteResponse(
1025+
arrow_queue=None,
1026+
status=operation_status,
1027+
has_been_closed_server_side=None,
1028+
has_more_rows=None,
1029+
lz4_compressed=None,
1030+
is_staging_operation=None,
1031+
command_handle=resp.operationHandle,
1032+
description=None,
1033+
arrow_schema_bytes=None,
1034+
)
1035+
9521036
def fetch_results(
9531037
self,
9541038
op_handle,

0 commit comments

Comments
 (0)