55import re
66from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
77
8- from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
8+ from databricks .sql .backend .sea .models .base import (
9+ ExternalLink ,
10+ ResultManifest ,
11+ StatementStatus ,
12+ )
13+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
914from databricks .sql .backend .sea .utils .constants import (
1015 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
1116 ResultFormat ,
1419 WaitTimeout ,
1520 MetadataCommands ,
1621)
22+ from databricks .sql .backend .sea .utils .normalize import normalize_sea_type_to_thrift
1723from databricks .sql .thrift_api .TCLIService import ttypes
1824
1925if TYPE_CHECKING :
4551 GetStatementResponse ,
4652 CreateSessionResponse ,
4753)
48- from databricks .sql .backend .sea .models .responses import GetChunksResponse
4954
5055logger = logging .getLogger (__name__ )
5156
5257
5358def _filter_session_configuration (
5459 session_configuration : Optional [Dict [str , Any ]],
5560) -> Dict [str , str ]:
61+ """
62+ Filter and normalise the provided session configuration parameters.
63+
64+ The Statement Execution API supports only a subset of SQL session
65+ configuration options. This helper validates the supplied
66+ ``session_configuration`` dictionary against the allow-list defined in
67+ ``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
68+ dictionary that contains **only** the supported parameters.
69+
70+ Args:
71+ session_configuration: Optional mapping of session configuration
72+ names to their desired values. Key comparison is
73+ case-insensitive.
74+
75+ Returns:
76+ Dict[str, str]: A dictionary containing only the supported
77+ configuration parameters with lower-case keys and string values. If
78+ *session_configuration* is ``None`` or empty, an empty dictionary is
79+ returned.
80+ """
81+
5682 if not session_configuration :
5783 return {}
5884
@@ -143,7 +169,7 @@ def __init__(
143169 http_path = http_path ,
144170 http_headers = http_headers ,
145171 auth_provider = auth_provider ,
146- ssl_options = self . _ssl_options ,
172+ ssl_options = ssl_options ,
147173 ** kwargs ,
148174 )
149175
@@ -275,29 +301,6 @@ def close_session(self, session_id: SessionId) -> None:
275301 data = request_data .to_dict (),
276302 )
277303
278- @staticmethod
279- def get_default_session_configuration_value (name : str ) -> Optional [str ]:
280- """
281- Get the default value for a session configuration parameter.
282-
283- Args:
284- name: The name of the session configuration parameter
285-
286- Returns:
287- The default value if the parameter is supported, None otherwise
288- """
289- return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .get (name .upper ())
290-
291- @staticmethod
292- def get_allowed_session_configurations () -> List [str ]:
293- """
294- Get the list of allowed session configuration parameters.
295-
296- Returns:
297- List of allowed session configuration parameter names
298- """
299- return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
300-
301304 def _extract_description_from_manifest (
302305 self , manifest : ResultManifest
303306 ) -> List [Tuple ]:
@@ -309,7 +312,7 @@ def _extract_description_from_manifest(
309312 manifest: The ResultManifest object containing schema information
310313
311314 Returns:
312- List[Tuple ]: A list of column tuples
315+ Optional[List ]: A list of column tuples or None if no columns are found
313316 """
314317
315318 schema_data = manifest .schema
@@ -318,15 +321,28 @@ def _extract_description_from_manifest(
318321 columns = []
319322 for col_data in columns_data :
320323 # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
324+ name = col_data .get ("name" , "" )
325+ type_name = col_data .get ("type_name" , "" )
326+
327+ # Normalize SEA type to Thrift conventions before any processing
328+ type_name = normalize_sea_type_to_thrift (type_name , col_data )
329+
330+ # Now strip _TYPE suffix and convert to lowercase
331+ type_name = (
332+ type_name [:- 5 ] if type_name .endswith ("_TYPE" ) else type_name
333+ ).lower ()
334+ precision = col_data .get ("type_precision" )
335+ scale = col_data .get ("type_scale" )
336+
321337 columns .append (
322338 (
323- col_data . get ( " name" , "" ) , # name
324- col_data . get ( " type_name" , "" ) , # type_code
339+ name , # name
340+ type_name , # type_code
325341 None , # display_size (not provided by SEA)
326342 None , # internal_size (not provided by SEA)
327- col_data . get ( " precision" ) , # precision
328- col_data . get ( " scale" ) , # scale
329- col_data . get ( "nullable" , True ) , # null_ok
343+ precision , # precision
344+ scale , # scale
345+ None , # null_ok
330346 )
331347 )
332348
@@ -389,8 +405,9 @@ def _response_to_result_set(
389405 )
390406
391407 def _check_command_not_in_failed_or_closed_state (
392- self , state : CommandState , command_id : CommandId
408+ self , status : StatementStatus , command_id : CommandId
393409 ) -> None :
410+ state = status .state
394411 if state == CommandState .CLOSED :
395412 raise DatabaseError (
396413 "Command {} unexpectedly closed server side" .format (command_id ),
@@ -399,8 +416,11 @@ def _check_command_not_in_failed_or_closed_state(
399416 },
400417 )
401418 if state == CommandState .FAILED :
419+ error = status .error
420+ error_code = error .error_code if error else "UNKNOWN_ERROR_CODE"
421+ error_message = error .message if error else "UNKNOWN_ERROR_MESSAGE"
402422 raise ServerOperationError (
403- "Command {} failed " .format (command_id ),
423+ "Command failed: {} - {} " .format (error_code , error_message ),
404424 {
405425 "operation-id" : command_id ,
406426 },
@@ -414,16 +434,18 @@ def _wait_until_command_done(
414434 """
415435
416436 final_response : Union [ExecuteStatementResponse , GetStatementResponse ] = response
417-
418- state = final_response .status .state
419437 command_id = CommandId .from_sea_statement_id (final_response .statement_id )
420438
421- while state in [CommandState .PENDING , CommandState .RUNNING ]:
439+ while final_response .status .state in [
440+ CommandState .PENDING ,
441+ CommandState .RUNNING ,
442+ ]:
422443 time .sleep (self .POLL_INTERVAL_SECONDS )
423444 final_response = self ._poll_query (command_id )
424- state = final_response .status .state
425445
426- self ._check_command_not_in_failed_or_closed_state (state , command_id )
446+ self ._check_command_not_in_failed_or_closed_state (
447+ final_response .status , command_id
448+ )
427449
428450 return final_response
429451
@@ -457,7 +479,7 @@ def execute_command(
457479 enforce_embedded_schema_correctness: Whether to enforce schema correctness
458480
459481 Returns:
460- SeaResultSet : A SeaResultSet instance for the executed command
482+ ResultSet : A SeaResultSet instance for the executed command
461483 """
462484
463485 if session_id .backend_type != BackendType .SEA :
@@ -513,14 +535,6 @@ def execute_command(
513535 )
514536 response = ExecuteStatementResponse .from_dict (response_data )
515537 statement_id = response .statement_id
516- if not statement_id :
517- raise ServerOperationError (
518- "Failed to execute command: No statement ID returned" ,
519- {
520- "operation-id" : None ,
521- "diagnostic-info" : None ,
522- },
523- )
524538
525539 command_id = CommandId .from_sea_statement_id (statement_id )
526540
@@ -552,8 +566,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552566 raise ValueError ("Not a valid SEA command ID" )
553567
554568 sea_statement_id = command_id .to_sea_statement_id ()
555- if sea_statement_id is None :
556- raise ValueError ("Not a valid SEA command ID" )
557569
558570 request = CancelStatementRequest (statement_id = sea_statement_id )
559571 self ._http_client ._make_request (
@@ -577,8 +589,6 @@ def close_command(self, command_id: CommandId) -> None:
577589 raise ValueError ("Not a valid SEA command ID" )
578590
579591 sea_statement_id = command_id .to_sea_statement_id ()
580- if sea_statement_id is None :
581- raise ValueError ("Not a valid SEA command ID" )
582592
583593 request = CloseStatementRequest (statement_id = sea_statement_id )
584594 self ._http_client ._make_request (
@@ -596,8 +606,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596606 raise ValueError ("Not a valid SEA command ID" )
597607
598608 sea_statement_id = command_id .to_sea_statement_id ()
599- if sea_statement_id is None :
600- raise ValueError ("Not a valid SEA command ID" )
601609
602610 request = GetStatementRequest (statement_id = sea_statement_id )
603611 response_data = self ._http_client ._make_request (
@@ -620,7 +628,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620628 CommandState: The current state of the command
621629
622630 Raises:
623- ProgrammingError : If the command ID is invalid
631+ ValueError : If the command ID is invalid
624632 """
625633
626634 response = self ._poll_query (command_id )
0 commit comments