Skip to content

Commit 8374a75

Browse files
reduce additional network call after wait
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 1dd3877 commit 8374a75

File tree

2 files changed

+28
-37
lines changed

2 files changed

+28
-37
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def _check_command_not_in_failed_or_closed_state(
398398

399399
def _wait_until_command_done(
400400
self, response: ExecuteStatementResponse
401-
) -> CommandState:
401+
) -> ExecuteStatementResponse:
402402
"""
403403
Wait until a command is done.
404404
"""
@@ -408,11 +408,12 @@ def _wait_until_command_done(
408408

409409
while state in [CommandState.PENDING, CommandState.RUNNING]:
410410
time.sleep(self.POLL_INTERVAL_SECONDS)
411-
state = self.get_query_state(command_id)
411+
response = self._poll_query(command_id)
412+
state = response.status.state
412413

413414
self._check_command_not_in_failed_or_closed_state(state, command_id)
414415

415-
return state
416+
return response
416417

417418
def execute_command(
418419
self,
@@ -516,8 +517,8 @@ def execute_command(
516517
# if the response succeeded within the wait_timeout, return the results immediately
517518
return self._response_to_result_set(response, cursor)
518519

519-
self._wait_until_command_done(response)
520-
return self.get_execution_result(command_id, cursor)
520+
response = self._wait_until_command_done(response)
521+
return self._response_to_result_set(response, cursor)
521522

522523
def cancel_command(self, command_id: CommandId) -> None:
523524
"""
@@ -569,18 +570,9 @@ def close_command(self, command_id: CommandId) -> None:
569570
data=request.to_dict(),
570571
)
571572

572-
def get_query_state(self, command_id: CommandId) -> CommandState:
573+
def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse:
573574
"""
574-
Get the state of a running query.
575-
576-
Args:
577-
command_id: Command identifier
578-
579-
Returns:
580-
CommandState: The current state of the command
581-
582-
Raises:
583-
ProgrammingError: If the command ID is invalid
575+
Poll for the current command info.
584576
"""
585577

586578
if command_id.backend_type != BackendType.SEA:
@@ -596,9 +588,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
596588
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
597589
data=request.to_dict(),
598590
)
599-
600-
# Parse the response
601591
response = ExecuteStatementResponse.from_dict(response_data)
592+
593+
return response
594+
595+
def get_query_state(self, command_id: CommandId) -> CommandState:
596+
"""
597+
Get the state of a running query.
598+
599+
Args:
600+
command_id: Command identifier
601+
602+
Returns:
603+
CommandState: The current state of the command
604+
605+
Raises:
606+
ProgrammingError: If the command ID is invalid
607+
"""
608+
609+
response = self._poll_query(command_id)
602610
return response.status.state
603611

604612
def get_execution_result(
@@ -620,24 +628,7 @@ def get_execution_result(
620628
ValueError: If the command ID is invalid
621629
"""
622630

623-
if command_id.backend_type != BackendType.SEA:
624-
raise ValueError("Not a valid SEA command ID")
625-
626-
sea_statement_id = command_id.to_sea_statement_id()
627-
if sea_statement_id is None:
628-
raise ValueError("Not a valid SEA command ID")
629-
630-
# Create the request model
631-
request = GetStatementRequest(statement_id=sea_statement_id)
632-
633-
# Get the statement result
634-
response_data = self.http_client._make_request(
635-
method="GET",
636-
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
637-
data=request.to_dict(),
638-
)
639-
response = ExecuteStatementResponse.from_dict(response_data)
640-
631+
response = self._poll_query(command_id)
641632
return self._response_to_result_set(response, cursor)
642633

643634
# == Metadata Operations ==

tests/unit/test_sea_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def test_command_execution_advanced(
327327
mock_http_client._make_request.side_effect = [initial_response, poll_response]
328328

329329
with patch.object(
330-
sea_client, "get_execution_result", return_value="mock_result_set"
330+
sea_client, "_response_to_result_set", return_value="mock_result_set"
331331
) as mock_get_result:
332332
with patch("time.sleep"):
333333
result = sea_client.execute_command(

0 commit comments

Comments
 (0)