Skip to content

Commit cf8a629

Browse files
use normalised CommandState type in ExecuteResponse
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent db73ecf commit cf8a629

File tree

6 files changed

+13
-20
lines changed

6 files changed

+13
-20
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
163163
command_id: The command identifier to check
164164
165165
Returns:
166-
ttypes.TOperationState: The current state of the command
166+
CommandState: The current state of the command
167167
168168
Raises:
169169
ValueError: If the command ID is invalid

src/databricks/sql/backend/thrift_backend.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@
8686

8787

8888
class ThriftDatabricksClient(DatabricksClient):
89-
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
90-
ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
89+
CLOSED_OP_STATE = CommandState.CLOSED
90+
ERROR_OP_STATE = CommandState.FAILED
9191

9292
_retry_delay_min: float
9393
_retry_delay_max: float
@@ -351,6 +351,7 @@ def make_request(self, method, request, retryable=True):
351351
Will stop retry attempts if total elapsed time + next retry delay would exceed
352352
_retry_stop_after_attempts_duration.
353353
"""
354+
354355
# basic strategy: build range iterator rep'ing number of available
355356
# retries. bounds can be computed from there. iterate over it with
356357
# retries until success or final failure achieved.
@@ -798,7 +799,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
798799

799800
return ExecuteResponse(
800801
arrow_queue=arrow_queue_opt,
801-
status=operation_state,
802+
status=CommandState.from_thrift_state(operation_state),
802803
has_been_closed_server_side=has_been_closed_server_side,
803804
has_more_rows=has_more_rows,
804805
lz4_compressed=lz4_compressed,
@@ -863,7 +864,7 @@ def get_execution_result(
863864

864865
execute_response = ExecuteResponse(
865866
arrow_queue=queue,
866-
status=resp.status,
867+
status=CommandState.from_thrift_state(resp.status),
867868
has_been_closed_server_side=False,
868869
has_more_rows=has_more_rows,
869870
lz4_compressed=lz4_compressed,

src/databricks/sql/result_set.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,12 @@ def __init__(
155155
arraysize: Default number of rows to fetch
156156
use_cloud_fetch: Whether to use cloud fetch for retrieving results
157157
"""
158-
command_id = execute_response.command_id
159-
op_state = CommandState.from_thrift_state(execute_response.status)
160-
has_been_closed_server_side = execute_response.has_been_closed_server_side
161158
super().__init__(
162159
connection,
163160
thrift_client,
164-
command_id,
165-
op_state,
166-
has_been_closed_server_side,
161+
execute_response.command_id,
162+
execute_response.status,
163+
execute_response.has_been_closed_server_side,
167164
arraysize,
168165
buffer_size_bytes,
169166
)

tests/e2e/test_driver.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -827,10 +827,7 @@ def test_close_connection_closes_cursors(self):
827827
getProgressUpdate=False,
828828
)
829829
op_status_at_server = ars.backend._client.GetOperationStatus(status_request)
830-
assert (
831-
op_status_at_server.operationState
832-
!= ttypes.TOperationState.CLOSED_STATE
833-
)
830+
assert op_status_at_server.operationState != CommandState.CLOSED
834831

835832
conn.close()
836833

tests/unit/test_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
9696

9797
mock_execute_response.command_id = Mock(spec=CommandId)
9898
mock_execute_response.status = (
99-
TOperationState.FINISHED_STATE
100-
if not closed
101-
else TOperationState.CLOSED_STATE
99+
CommandState.SUCCEEDED if not closed else CommandState.CLOSED
102100
)
103101
mock_execute_response.has_been_closed_server_side = closed
104102
mock_execute_response.is_staging_operation = False

tests/unit/test_thrift_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from databricks.sql.auth.authenticators import AuthProvider
2020
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
2121
from databricks.sql.result_set import ResultSet, ThriftResultSet
22-
from databricks.sql.backend.types import CommandId, SessionId, BackendType
22+
from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType
2323

2424

2525
def retry_policy_factory():
@@ -883,7 +883,7 @@ def test_handle_execute_response_can_handle_without_direct_results(
883883
)
884884
self.assertEqual(
885885
results_message_response.status,
886-
ttypes.TOperationState.FINISHED_STATE,
886+
CommandState.SUCCEEDED,
887887
)
888888

889889
def test_handle_execute_response_can_handle_with_direct_results(self):

0 commit comments

Comments
 (0)