55import re
66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
8- from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
8+ from databricks .sql .backend .sea .models .base import (
9+ ExternalLink ,
10+ ResultManifest ,
11+ StatementStatus ,
12+ )
13+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
914from databricks .sql .backend .sea .utils .constants import (
1015 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
1116 ResultFormat ,
4550 GetStatementResponse ,
4651 CreateSessionResponse ,
4752)
48- from databricks .sql .backend .sea .models .responses import GetChunksResponse
4953
5054logger = logging .getLogger (__name__ )
5155
5256
5357def _filter_session_configuration (
5458 session_configuration : Optional [Dict [str , Any ]],
5559) -> Dict [str , str ]:
60+ """
61+ Filter and normalise the provided session configuration parameters.
62+
63+ The Statement Execution API supports only a subset of SQL session
64+ configuration options. This helper validates the supplied
65+ ``session_configuration`` dictionary against the allow-list defined in
66+ ``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
67+ dictionary that contains **only** the supported parameters.
68+
69+ Args:
70+ session_configuration: Optional mapping of session configuration
71+ names to their desired values. Key comparison is
72+ case-insensitive.
73+
74+ Returns:
75+ Dict[str, str]: A dictionary containing only the supported
76+ configuration parameters with lower-case keys and string values. If
77+ *session_configuration* is ``None`` or empty, an empty dictionary is
78+ returned.
79+ """
80+
5681 if not session_configuration :
5782 return {}
5883
@@ -143,7 +168,7 @@ def __init__(
143168 http_path = http_path ,
144169 http_headers = http_headers ,
145170 auth_provider = auth_provider ,
146- ssl_options = self . _ssl_options ,
171+ ssl_options = ssl_options ,
147172 ** kwargs ,
148173 )
149174
@@ -275,29 +300,6 @@ def close_session(self, session_id: SessionId) -> None:
275300 data = request_data .to_dict (),
276301 )
277302
278- @staticmethod
279- def get_default_session_configuration_value (name : str ) -> Optional [str ]:
280- """
281- Get the default value for a session configuration parameter.
282-
283- Args:
284- name: The name of the session configuration parameter
285-
286- Returns:
287- The default value if the parameter is supported, None otherwise
288- """
289- return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .get (name .upper ())
290-
291- @staticmethod
292- def get_allowed_session_configurations () -> List [str ]:
293- """
294- Get the list of allowed session configuration parameters.
295-
296- Returns:
297- List of allowed session configuration parameter names
298- """
299- return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
300-
301303 def _extract_description_from_manifest (
302304 self , manifest : ResultManifest
303305 ) -> List [Tuple ]:
@@ -309,7 +311,7 @@ def _extract_description_from_manifest(
309311 manifest: The ResultManifest object containing schema information
310312
311313 Returns:
312- List[Tuple ]: A list of column tuples
314+ Optional[List ]: A list of column tuples or None if no columns are found
313315 """
314316
315317 schema_data = manifest .schema
@@ -397,8 +399,9 @@ def _response_to_result_set(
397399 )
398400
399401 def _check_command_not_in_failed_or_closed_state (
400- self , state : CommandState , command_id : CommandId
402+ self , status : StatementStatus , command_id : CommandId
401403 ) -> None :
404+ state = status .state
402405 if state == CommandState .CLOSED :
403406 raise DatabaseError (
404407 "Command {} unexpectedly closed server side" .format (command_id ),
@@ -407,8 +410,11 @@ def _check_command_not_in_failed_or_closed_state(
407410 },
408411 )
409412 if state == CommandState .FAILED :
413+ error = status .error
414+ error_code = error .error_code if error else "UNKNOWN_ERROR_CODE"
415+ error_message = error .message if error else "UNKNOWN_ERROR_MESSAGE"
410416 raise ServerOperationError (
411- "Command {} failed " .format (command_id ),
417+ "Command failed: {} - {} " .format (error_code , error_message ),
412418 {
413419 "operation-id" : command_id ,
414420 },
@@ -422,16 +428,18 @@ def _wait_until_command_done(
422428 """
423429
424430 final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
425-
426- state = final_response .status .state
427431 command_id = CommandId .from_sea_statement_id (final_response .statement_id )
428432
429- while state in [CommandState .PENDING , CommandState .RUNNING ]:
433+ while final_response .status .state in [
434+ CommandState .PENDING ,
435+ CommandState .RUNNING ,
436+ ]:
430437 time .sleep (self .POLL_INTERVAL_SECONDS )
431438 final_response = self ._poll_query (command_id )
432- state = final_response .status .state
433439
434- self ._check_command_not_in_failed_or_closed_state (state , command_id )
440+ self ._check_command_not_in_failed_or_closed_state (
441+ final_response .status , command_id
442+ )
435443
436444 return final_response
437445
@@ -465,7 +473,7 @@ def execute_command(
465473 enforce_embedded_schema_correctness: Whether to enforce schema correctness
466474
467475 Returns:
468- SeaResultSet : A SeaResultSet instance for the executed command
476+ ResultSet : A SeaResultSet instance for the executed command
469477 """
470478
471479 if session_id .backend_type != BackendType .SEA :
@@ -521,14 +529,6 @@ def execute_command(
521529 )
522530 response = ExecuteStatementResponse .from_dict (response_data )
523531 statement_id = response .statement_id
524- if not statement_id :
525- raise ServerOperationError (
526- "Failed to execute command: No statement ID returned" ,
527- {
528- "operation-id" : None ,
529- "diagnostic-info" : None ,
530- },
531- )
532532
533533 command_id = CommandId .from_sea_statement_id (statement_id )
534534
@@ -560,8 +560,6 @@ def cancel_command(self, command_id: CommandId) -> None:
560560 raise ValueError ("Not a valid SEA command ID" )
561561
562562 sea_statement_id = command_id .to_sea_statement_id ()
563- if sea_statement_id is None :
564- raise ValueError ("Not a valid SEA command ID" )
565563
566564 request = CancelStatementRequest (statement_id = sea_statement_id )
567565 self ._http_client ._make_request (
@@ -585,8 +583,6 @@ def close_command(self, command_id: CommandId) -> None:
585583 raise ValueError ("Not a valid SEA command ID" )
586584
587585 sea_statement_id = command_id .to_sea_statement_id ()
588- if sea_statement_id is None :
589- raise ValueError ("Not a valid SEA command ID" )
590586
591587 request = CloseStatementRequest (statement_id = sea_statement_id )
592588 self ._http_client ._make_request (
@@ -604,8 +600,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
604600 raise ValueError ("Not a valid SEA command ID" )
605601
606602 sea_statement_id = command_id .to_sea_statement_id ()
607- if sea_statement_id is None :
608- raise ValueError ("Not a valid SEA command ID" )
609603
610604 request = GetStatementRequest (statement_id = sea_statement_id )
611605 response_data = self ._http_client ._make_request (
@@ -628,7 +622,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
628622 CommandState: The current state of the command
629623
630624 Raises:
631- ProgrammingError : If the command ID is invalid
625+ ValueError : If the command ID is invalid
632626 """
633627
634628 response = self ._poll_query (command_id )
0 commit comments