55import re
66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
8- from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
8+ from databricks .sql .backend .sea .models .base import (
9+ ExternalLink ,
10+ ResultManifest ,
11+ StatementStatus ,
12+ )
13+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
914from databricks .sql .backend .sea .utils .constants import (
1015 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
1116 ResultFormat ,
4550 GetStatementResponse ,
4651 CreateSessionResponse ,
4752)
48- from databricks .sql .backend .sea .models .responses import GetChunksResponse
4953
5054logger = logging .getLogger (__name__ )
5155
5256
5357def _filter_session_configuration (
5458 session_configuration : Optional [Dict [str , Any ]],
5559) -> Dict [str , str ]:
60+ """
61+ Filter and normalise the provided session configuration parameters.
62+
63+ The Statement Execution API supports only a subset of SQL session
64+ configuration options. This helper validates the supplied
65+ ``session_configuration`` dictionary against the allow-list defined in
66+ ``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
67+ dictionary that contains **only** the supported parameters.
68+
69+ Args:
70+ session_configuration: Optional mapping of session configuration
71+ names to their desired values. Key comparison is
72+ case-insensitive.
73+
74+ Returns:
75+ Dict[str, str]: A dictionary containing only the supported
76+ configuration parameters with lower-case keys and string values. If
77+ *session_configuration* is ``None`` or empty, an empty dictionary is
78+ returned.
79+ """
80+
5681 if not session_configuration :
5782 return {}
5883
@@ -143,7 +168,7 @@ def __init__(
143168 http_path = http_path ,
144169 http_headers = http_headers ,
145170 auth_provider = auth_provider ,
146- ssl_options = self . _ssl_options ,
171+ ssl_options = ssl_options ,
147172 ** kwargs ,
148173 )
149174
@@ -275,29 +300,6 @@ def close_session(self, session_id: SessionId) -> None:
275300 data = request_data .to_dict (),
276301 )
277302
278- @staticmethod
279- def get_default_session_configuration_value (name : str ) -> Optional [str ]:
280- """
281- Get the default value for a session configuration parameter.
282-
283- Args:
284- name: The name of the session configuration parameter
285-
286- Returns:
287- The default value if the parameter is supported, None otherwise
288- """
289- return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .get (name .upper ())
290-
291- @staticmethod
292- def get_allowed_session_configurations () -> List [str ]:
293- """
294- Get the list of allowed session configuration parameters.
295-
296- Returns:
297- List of allowed session configuration parameter names
298- """
299- return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
300-
301303 def _extract_description_from_manifest (
302304 self , manifest : ResultManifest
303305 ) -> List [Tuple ]:
@@ -309,7 +311,7 @@ def _extract_description_from_manifest(
309311 manifest: The ResultManifest object containing schema information
310312
311313 Returns:
312- List[Tuple ]: A list of column tuples
314+ Optional[List ]: A list of column tuples or None if no columns are found
313315 """
314316
315317 schema_data = manifest .schema
@@ -318,15 +320,23 @@ def _extract_description_from_manifest(
318320 columns = []
319321 for col_data in columns_data :
320322 # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
323+ name = col_data .get ("name" , "" )
324+ type_name = col_data .get ("type_name" , "" )
325+ type_name = (
326+ type_name [:- 5 ] if type_name .endswith ("_TYPE" ) else type_name
327+ ).lower ()
328+ precision = col_data .get ("type_precision" )
329+ scale = col_data .get ("type_scale" )
330+
321331 columns .append (
322332 (
323- col_data . get ( " name" , "" ) , # name
324- col_data . get ( " type_name" , "" ) , # type_code
333+ name , # name
334+ type_name , # type_code
325335 None , # display_size (not provided by SEA)
326336 None , # internal_size (not provided by SEA)
327- col_data . get ( " precision" ) , # precision
328- col_data . get ( " scale" ) , # scale
329- col_data . get ( "nullable" , True ) , # null_ok
337+ precision , # precision
338+ scale , # scale
339+ None , # null_ok
330340 )
331341 )
332342
@@ -389,8 +399,9 @@ def _response_to_result_set(
389399 )
390400
391401 def _check_command_not_in_failed_or_closed_state (
392- self , state : CommandState , command_id : CommandId
402+ self , status : StatementStatus , command_id : CommandId
393403 ) -> None :
404+ state = status .state
394405 if state == CommandState .CLOSED :
395406 raise DatabaseError (
396407 "Command {} unexpectedly closed server side" .format (command_id ),
@@ -399,8 +410,11 @@ def _check_command_not_in_failed_or_closed_state(
399410 },
400411 )
401412 if state == CommandState .FAILED :
413+ error = status .error
414+ error_code = error .error_code if error else "UNKNOWN_ERROR_CODE"
415+ error_message = error .message if error else "UNKNOWN_ERROR_MESSAGE"
402416 raise ServerOperationError (
403- "Command {} failed " .format (command_id ),
417+ "Command failed: {} - {} " .format (error_code , error_message ),
404418 {
405419 "operation-id" : command_id ,
406420 },
@@ -414,16 +428,18 @@ def _wait_until_command_done(
414428 """
415429
416430 final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417-
418- state = final_response .status .state
419431 command_id = CommandId .from_sea_statement_id (final_response .statement_id )
420432
421- while state in [CommandState .PENDING , CommandState .RUNNING ]:
433+ while final_response .status .state in [
434+ CommandState .PENDING ,
435+ CommandState .RUNNING ,
436+ ]:
422437 time .sleep (self .POLL_INTERVAL_SECONDS )
423438 final_response = self ._poll_query (command_id )
424- state = final_response .status .state
425439
426- self ._check_command_not_in_failed_or_closed_state (state , command_id )
440+ self ._check_command_not_in_failed_or_closed_state (
441+ final_response .status , command_id
442+ )
427443
428444 return final_response
429445
@@ -457,7 +473,7 @@ def execute_command(
457473 enforce_embedded_schema_correctness: Whether to enforce schema correctness
458474
459475 Returns:
460- SeaResultSet : A SeaResultSet instance for the executed command
476+ ResultSet : A SeaResultSet instance for the executed command
461477 """
462478
463479 if session_id .backend_type != BackendType .SEA :
@@ -513,14 +529,6 @@ def execute_command(
513529 )
514530 response = ExecuteStatementResponse .from_dict (response_data )
515531 statement_id = response .statement_id
516- if not statement_id :
517- raise ServerOperationError (
518- "Failed to execute command: No statement ID returned" ,
519- {
520- "operation-id" : None ,
521- "diagnostic-info" : None ,
522- },
523- )
524532
525533 command_id = CommandId .from_sea_statement_id (statement_id )
526534
@@ -552,8 +560,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552560 raise ValueError ("Not a valid SEA command ID" )
553561
554562 sea_statement_id = command_id .to_sea_statement_id ()
555- if sea_statement_id is None :
556- raise ValueError ("Not a valid SEA command ID" )
557563
558564 request = CancelStatementRequest (statement_id = sea_statement_id )
559565 self ._http_client ._make_request (
@@ -577,8 +583,6 @@ def close_command(self, command_id: CommandId) -> None:
577583 raise ValueError ("Not a valid SEA command ID" )
578584
579585 sea_statement_id = command_id .to_sea_statement_id ()
580- if sea_statement_id is None :
581- raise ValueError ("Not a valid SEA command ID" )
582586
583587 request = CloseStatementRequest (statement_id = sea_statement_id )
584588 self ._http_client ._make_request (
@@ -596,8 +600,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596600 raise ValueError ("Not a valid SEA command ID" )
597601
598602 sea_statement_id = command_id .to_sea_statement_id ()
599- if sea_statement_id is None :
600- raise ValueError ("Not a valid SEA command ID" )
601603
602604 request = GetStatementRequest (statement_id = sea_statement_id )
603605 response_data = self ._http_client ._make_request (
@@ -620,7 +622,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620622 CommandState: The current state of the command
621623
622624 Raises:
623- ProgrammingError : If the command ID is invalid
625+ ValueError : If the command ID is invalid
624626 """
625627
626628 response = self ._poll_query (command_id )
0 commit comments