66
77from databricks .sql .backend .sea .utils .constants import (
88 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
9+ ResultFormat ,
10+ ResultDisposition ,
11+ ResultCompression ,
12+ WaitTimeout ,
913)
1014
1115if TYPE_CHECKING :
1216 from databricks .sql .client import Cursor
1317 from databricks .sql .result_set import ResultSet
14- from databricks .sql .backend .sea .models .responses import GetChunksResponse
1518
1619from databricks .sql .backend .databricks_client import DatabricksClient
1720from databricks .sql .backend .types import (
2124 BackendType ,
2225 ExecuteResponse ,
2326)
24- from databricks .sql .exc import Error , NotSupportedError , ServerOperationError
27+ from databricks .sql .exc import ServerOperationError
2528from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
2629from databricks .sql .thrift_api .TCLIService import ttypes
2730from databricks .sql .types import SSLOptions
28- from databricks .sql .utils import SeaResultSetQueueFactory
29- from databricks .sql .backend .sea .models .base import (
30- ResultData ,
31- ExternalLink ,
32- ResultManifest ,
33- )
3431
3532from databricks .sql .backend .sea .models import (
3633 ExecuteStatementRequest ,
4542 CreateSessionResponse ,
4643 GetChunksResponse ,
4744)
45+ from databricks .sql .backend .sea .models .responses import (
46+ parse_status ,
47+ parse_manifest ,
48+ parse_result ,
49+ )
4850
4951logger = logging .getLogger (__name__ )
5052
@@ -80,9 +82,6 @@ def _filter_session_configuration(
8082class SeaDatabricksClient (DatabricksClient ):
8183 """
8284 Statement Execution API (SEA) implementation of the DatabricksClient interface.
83-
84- This implementation provides session management functionality for SEA,
85- while other operations raise NotImplementedError.
8685 """
8786
8887 # SEA API paths
@@ -92,8 +91,6 @@ class SeaDatabricksClient(DatabricksClient):
9291 STATEMENT_PATH = BASE_PATH + "statements"
9392 STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9493 CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
95- CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks"
96- CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9794
9895 def __init__ (
9996 self ,
@@ -126,7 +123,6 @@ def __init__(
126123 )
127124
128125 self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
129- self .ssl_options = ssl_options
130126
131127 # Extract warehouse ID from http_path
132128 self .warehouse_id = self ._extract_warehouse_id (http_path )
@@ -283,19 +279,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]:
283279 """
284280 return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .get (name .upper ())
285281
286- @staticmethod
287- def is_session_configuration_parameter_supported (name : str ) -> bool :
288- """
289- Check if a session configuration parameter is supported.
290-
291- Args:
292- name: The name of the session configuration parameter
293-
294- Returns:
295- True if the parameter is supported, False otherwise
296- """
297- return name .upper () in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
298-
299282 @staticmethod
300283 def get_allowed_session_configurations () -> List [str ]:
301284 """
@@ -343,92 +326,27 @@ def _results_message_to_execute_response(self, sea_response, command_id):
343326 tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
344327 result data object, and manifest object
345328 """
346- # Extract status
347- status_data = sea_response .get ("status" , {})
348- state = CommandState .from_sea_state (status_data .get ("state" , "" ))
349-
350- # Extract description from manifest
351- description = None
352- manifest_data = sea_response .get ("manifest" , {})
353- schema_data = manifest_data .get ("schema" , {})
354- columns_data = schema_data .get ("columns" , [])
355-
356- if columns_data :
357- columns = []
358- for col_data in columns_data :
359- if not isinstance (col_data , dict ):
360- continue
361-
362- # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
363- columns .append (
364- (
365- col_data .get ("name" , "" ), # name
366- col_data .get ("type_name" , "" ), # type_code
367- None , # display_size (not provided by SEA)
368- None , # internal_size (not provided by SEA)
369- col_data .get ("precision" ), # precision
370- col_data .get ("scale" ), # scale
371- col_data .get ("nullable" , True ), # null_ok
372- )
373- )
374- description = columns if columns else None
375329
376- # Check for compression
377- lz4_compressed = manifest_data .get ("result_compression" ) == "LZ4_FRAME"
378-
379- # Initialize result_data_obj and manifest_obj
380- result_data_obj = None
381- manifest_obj = None
382-
383- result_data = sea_response .get ("result" , {})
384- if result_data :
385- # Convert external links
386- external_links = None
387- if "external_links" in result_data :
388- external_links = []
389- for link_data in result_data ["external_links" ]:
390- external_links .append (
391- ExternalLink (
392- external_link = link_data .get ("external_link" , "" ),
393- expiration = link_data .get ("expiration" , "" ),
394- chunk_index = link_data .get ("chunk_index" , 0 ),
395- byte_count = link_data .get ("byte_count" , 0 ),
396- row_count = link_data .get ("row_count" , 0 ),
397- row_offset = link_data .get ("row_offset" , 0 ),
398- next_chunk_index = link_data .get ("next_chunk_index" ),
399- next_chunk_internal_link = link_data .get (
400- "next_chunk_internal_link"
401- ),
402- http_headers = link_data .get ("http_headers" , {}),
403- )
404- )
330+ # Parse the response
331+ status = parse_status (sea_response )
332+ manifest_obj = parse_manifest (sea_response )
333+ result_data_obj = parse_result (sea_response )
405334
406- # Create the result data object
407- result_data_obj = ResultData (
408- data = result_data .get ("data_array" ), external_links = external_links
409- )
335+ # Extract description from manifest schema
336+ description = self ._extract_description_from_manifest (manifest_obj )
410337
411- # Create the manifest object
412- manifest_obj = ResultManifest (
413- format = manifest_data .get ("format" , "" ),
414- schema = manifest_data .get ("schema" , {}),
415- total_row_count = manifest_data .get ("total_row_count" , 0 ),
416- total_byte_count = manifest_data .get ("total_byte_count" , 0 ),
417- total_chunk_count = manifest_data .get ("total_chunk_count" , 0 ),
418- truncated = manifest_data .get ("truncated" , False ),
419- chunks = manifest_data .get ("chunks" ),
420- result_compression = manifest_data .get ("result_compression" ),
421- )
338+ # Check for compression
339+ lz4_compressed = manifest_obj .result_compression == "LZ4_FRAME"
422340
423341 execute_response = ExecuteResponse (
424342 command_id = command_id ,
425- status = state ,
343+ status = status . state ,
426344 description = description ,
427345 has_been_closed_server_side = False ,
428346 lz4_compressed = lz4_compressed ,
429347 is_staging_operation = False ,
430348 arrow_schema_bytes = None ,
431- result_format = manifest_data . get ( " format" ) ,
349+ result_format = manifest_obj . format ,
432350 )
433351
434352 return execute_response , result_data_obj , manifest_obj
@@ -464,6 +382,7 @@ def execute_command(
464382 Returns:
465383 ResultSet: A SeaResultSet instance for the executed command
466384 """
385+
467386 if session_id .backend_type != BackendType .SEA :
468387 raise ValueError ("Not a valid SEA session ID" )
469388
@@ -481,17 +400,25 @@ def execute_command(
481400 )
482401 )
483402
484- format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
485- disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
486- result_compression = "LZ4_FRAME" if lz4_compression else None
403+ format = (
404+ ResultFormat .ARROW_STREAM if use_cloud_fetch else ResultFormat .JSON_ARRAY
405+ ).value
406+ disposition = (
407+ ResultDisposition .EXTERNAL_LINKS
408+ if use_cloud_fetch
409+ else ResultDisposition .INLINE
410+ ).value
411+ result_compression = (
412+ ResultCompression .LZ4_FRAME if lz4_compression else ResultCompression .NONE
413+ ).value
487414
488415 request = ExecuteStatementRequest (
489416 warehouse_id = self .warehouse_id ,
490417 session_id = sea_session_id ,
491418 statement = operation ,
492419 disposition = disposition ,
493420 format = format ,
494- wait_timeout = "0s" if async_op else "10s" ,
421+ wait_timeout = ( WaitTimeout . ASYNC if async_op else WaitTimeout . SYNC ). value ,
495422 on_wait_timeout = "CONTINUE" ,
496423 row_limit = max_rows ,
497424 parameters = sea_parameters if sea_parameters else None ,
@@ -517,12 +444,11 @@ def execute_command(
517444 # Store the command ID in the cursor
518445 cursor .active_command_id = command_id
519446
520- # If async operation, return None and let the client poll for results
447+ # If async operation, return and let the client poll for results
521448 if async_op :
522449 return None
523450
524451 # For synchronous operation, wait for the statement to complete
525- # Poll until the statement is done
526452 status = response .status
527453 state = status .state
528454
@@ -552,6 +478,7 @@ def cancel_command(self, command_id: CommandId) -> None:
552478 Raises:
553479 ValueError: If the command ID is invalid
554480 """
481+
555482 if command_id .backend_type != BackendType .SEA :
556483 raise ValueError ("Not a valid SEA command ID" )
557484
@@ -574,6 +501,7 @@ def close_command(self, command_id: CommandId) -> None:
574501 Raises:
575502 ValueError: If the command ID is invalid
576503 """
504+
577505 if command_id .backend_type != BackendType .SEA :
578506 raise ValueError ("Not a valid SEA command ID" )
579507
@@ -599,6 +527,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
599527 Raises:
600528 ValueError: If the command ID is invalid
601529 """
530+
602531 if command_id .backend_type != BackendType .SEA :
603532 raise ValueError ("Not a valid SEA command ID" )
604533
@@ -633,6 +562,7 @@ def get_execution_result(
633562 Raises:
634563 ValueError: If the command ID is invalid
635564 """
565+
636566 if command_id .backend_type != BackendType .SEA :
637567 raise ValueError ("Not a valid SEA command ID" )
638568
0 commit comments