 import logging
+import uuid
 import time
 import re
-from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
+from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set

-from databricks.sql.backend.sea.models.base import ResultManifest
+from databricks.sql.backend.sea.models.base import ExternalLink
 from databricks.sql.backend.sea.utils.constants import (
     ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
     ResultFormat,
     ResultDisposition,
     ResultCompression,
     WaitTimeout,
-    MetadataCommands,
 )

 if TYPE_CHECKING:
     BackendType,
     ExecuteResponse,
 )
-from databricks.sql.exc import DatabaseError, ServerOperationError
+from databricks.sql.exc import ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
+from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import SSLOptions

 from databricks.sql.backend.sea.models import (
     ExecuteStatementResponse,
     GetStatementResponse,
     CreateSessionResponse,
+    GetChunksResponse,
 )
 from databricks.sql.backend.sea.models.responses import (
-    _parse_status,
-    _parse_manifest,
-    _parse_result,
+    parse_status,
+    parse_manifest,
+    parse_result,
 )

 logger = logging.getLogger(__name__)
@@ -90,9 +92,7 @@ class SeaDatabricksClient(DatabricksClient):
     STATEMENT_PATH = BASE_PATH + "statements"
     STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
     CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
-
-    # SEA constants
-    POLL_INTERVAL_SECONDS = 0.2
+    CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"

     def __init__(
         self,
@@ -124,7 +124,7 @@ def __init__(
             http_path,
         )

-        self._max_download_threads = kwargs.get("max_download_threads", 10)
+        super().__init__(ssl_options, **kwargs)

         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -136,7 +136,7 @@ def __init__(
             http_path=http_path,
             http_headers=http_headers,
             auth_provider=auth_provider,
-            ssl_options=ssl_options,
+            ssl_options=self._ssl_options,
             **kwargs,
         )

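The constructor now delegates ssl_options and the remaining kwargs to the shared DatabricksClient base class and reads the stored copy back as self._ssl_options. A minimal sketch of what that base __init__ is assumed to provide, inferred only from this diff; the stored attribute names are assumptions, not confirmed against the base class:

    class DatabricksClient:
        def __init__(self, ssl_options, **kwargs):
            # Assumed: the base class keeps the SSL options for subclasses to reuse
            self._ssl_options = ssl_options
            # Assumed: the download-thread default moves here, replacing the removed
            # per-subclass self._max_download_threads line
            self._max_download_threads = kwargs.get("max_download_threads", 10)
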
@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
         """
         return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())

-    def _extract_description_from_manifest(
-        self, manifest: ResultManifest
-    ) -> Optional[List]:
+    def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
         """
-        Extract column description from a manifest object, in the format defined by
-        the spec: https://peps.python.org/pep-0249/#description
+        Extract column description from a manifest object.

         Args:
-            manifest: The ResultManifest object containing schema information
+            manifest_obj: The ResultManifest object containing schema information

         Returns:
             Optional[List]: A list of column tuples or None if no columns are found
         """

-        schema_data = manifest.schema
+        schema_data = manifest_obj.schema
         columns_data = schema_data.get("columns", [])

         if not columns_data:
             return None

         columns = []
         for col_data in columns_data:
+            if not isinstance(col_data, dict):
+                continue
+
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             columns.append(
                 (
@@ -328,9 +328,38 @@ def _extract_description_from_manifest(

         return columns if columns else None

-    def _results_message_to_execute_response(
-        self, response: GetStatementResponse
-    ) -> ExecuteResponse:
+    def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
+        """
+        Get the external link for the chunk at the specified index.
+
+        Args:
+            statement_id: The statement ID
+            chunk_index: The index of the chunk to fetch the link for
+
+        Returns:
+            ExternalLink: External link for the chunk
+        """
+
+        response_data = self.http_client._make_request(
+            method="GET",
+            path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
+        )
+        response = GetChunksResponse.from_dict(response_data)
+
+        links = response.external_links
+        link = next((l for l in links if l.chunk_index == chunk_index), None)
+        if not link:
+            raise ServerOperationError(
+                f"No link found for chunk index {chunk_index}",
+                {
+                    "operation-id": statement_id,
+                    "diagnostic-info": None,
+                },
+            )
+
+        return link
+
+    def _results_message_to_execute_response(self, sea_response, command_id):
         """
         Convert a SEA response to an ExecuteResponse and extract result data.

@@ -339,65 +368,33 @@ def _results_message_to_execute_response(
             command_id: The command ID

         Returns:
-            ExecuteResponse: The normalized execute response
+            tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
+                result data object, and manifest object
         """

+        # Parse the response
+        status = parse_status(sea_response)
+        manifest_obj = parse_manifest(sea_response)
+        result_data_obj = parse_result(sea_response)
+
         # Extract description from manifest schema
-        description = self._extract_description_from_manifest(response.manifest)
+        description = self._extract_description_from_manifest(manifest_obj)

         # Check for compression
-        lz4_compressed = (
-            response.manifest.result_compression == ResultCompression.LZ4_FRAME
-        )
+        lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"

         execute_response = ExecuteResponse(
-            command_id=CommandId.from_sea_statement_id(response.statement_id),
-            status=response.status.state,
+            command_id=command_id,
+            status=status.state,
             description=description,
             has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=False,
             arrow_schema_bytes=None,
-            result_format=response.manifest.format,
+            result_format=manifest_obj.format,
         )

-        return execute_response
-
-    def _check_command_not_in_failed_or_closed_state(
-        self, state: CommandState, command_id: CommandId
-    ) -> None:
-        if state == CommandState.CLOSED:
-            raise DatabaseError(
-                "Command {} unexpectedly closed server side".format(command_id),
-                {
-                    "operation-id": command_id,
-                },
-            )
-        if state == CommandState.FAILED:
-            raise ServerOperationError(
-                "Command {} failed".format(command_id),
-                {
-                    "operation-id": command_id,
-                },
-            )
-
-    def _wait_until_command_done(
-        self, response: ExecuteStatementResponse
-    ) -> CommandState:
-        """
-        Wait until a command is done.
-        """
-
-        state = response.status.state
-        command_id = CommandId.from_sea_statement_id(response.statement_id)
-
-        while state in [CommandState.PENDING, CommandState.RUNNING]:
-            time.sleep(self.POLL_INTERVAL_SECONDS)
-            state = self.get_query_state(command_id)
-
-        self._check_command_not_in_failed_or_closed_state(state, command_id)
-
-        return state
+        return execute_response, result_data_obj, manifest_obj

     def execute_command(
         self,
@@ -408,7 +405,7 @@ def execute_command(
         lz4_compression: bool,
         cursor: "Cursor",
         use_cloud_fetch: bool,
-        parameters: List[Dict[str, Any]],
+        parameters: List,
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
     ) -> Union["ResultSet", None]:
@@ -442,9 +439,9 @@ def execute_command(
         for param in parameters:
             sea_parameters.append(
                 StatementParameter(
-                    name=param["name"],
-                    value=param["value"],
-                    type=param["type"] if "type" in param else None,
+                    name=param.name,
+                    value=param.value,
+                    type=param.type if hasattr(param, "type") else None,
                 )
             )

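Parameters are now read as attributes rather than dict keys. A minimal sketch of an object shape that satisfies this access pattern; the class name is hypothetical and only illustrates the assumption, it is not part of this change:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ParamStub:
        # Stand-in for the parameter objects execute_command now expects:
        # name and value are read directly, and "type" is optional because the
        # loop above guards the access with hasattr().
        name: str
        value: Optional[str] = None
        type: Optional[str] = None
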
@@ -496,7 +493,24 @@ def execute_command(
         if async_op:
             return None

-        self._wait_until_command_done(response)
+        # For synchronous operation, wait for the statement to complete
+        status = response.status
+        state = status.state
+
+        # Keep polling until we reach a terminal state
+        while state in [CommandState.PENDING, CommandState.RUNNING]:
+            time.sleep(0.5)  # add a small delay to avoid excessive API calls
+            state = self.get_query_state(command_id)
+
+        if state != CommandState.SUCCEEDED:
+            raise ServerOperationError(
+                f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
+                {
+                    "operation-id": command_id.to_sea_statement_id(),
+                    "diagnostic-info": None,
+                },
+            )
+
         return self.get_execution_result(command_id, cursor)

     def cancel_command(self, command_id: CommandId) -> None:
@@ -608,21 +622,25 @@ def get_execution_result(
             path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
             data=request.to_dict(),
         )
-        response = GetStatementResponse.from_dict(response_data)

         # Create and return a SeaResultSet
         from databricks.sql.result_set import SeaResultSet

-        execute_response = self._results_message_to_execute_response(response)
+        # Convert the response to an ExecuteResponse and extract result data
+        (
+            execute_response,
+            result_data,
+            manifest,
+        ) = self._results_message_to_execute_response(response_data, command_id)

         return SeaResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
             sea_client=self,
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
-            result_data=response.result,
-            manifest=response.manifest,
+            result_data=result_data,
+            manifest=manifest,
         )

     # == Metadata Operations ==
@@ -636,7 +654,7 @@ def get_catalogs(
     ) -> "ResultSet":
         """Get available catalogs by executing 'SHOW CATALOGS'."""
         result = self.execute_command(
-            operation=MetadataCommands.SHOW_CATALOGS.value,
+            operation="SHOW CATALOGS",
             session_id=session_id,
             max_rows=max_rows,
             max_bytes=max_bytes,
@@ -663,10 +681,10 @@ def get_schemas(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_schemas")

-        operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
+        operation = f"SHOW SCHEMAS IN `{catalog_name}`"

         if schema_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name)
+            operation += f" LIKE '{schema_name}'"

         result = self.execute_command(
             operation=operation,
@@ -698,19 +716,17 @@ def get_tables(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_tables")

-        operation = (
-            MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
+        operation = "SHOW TABLES IN " + (
+            "ALL CATALOGS"
             if catalog_name in [None, "*", "%"]
-            else MetadataCommands.SHOW_TABLES.value.format(
-                MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
-            )
+            else f"CATALOG `{catalog_name}`"
         )

         if schema_name:
-            operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
+            operation += f" SCHEMA LIKE '{schema_name}'"

         if table_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(table_name)
+            operation += f" LIKE '{table_name}'"

         result = self.execute_command(
             operation=operation,
@@ -726,7 +742,7 @@ def get_tables(
         )
         assert result is not None, "execute_command returned None in synchronous mode"

-        # Apply client-side filtering by table_types
+        # Apply client-side filtering by table_types if specified
         from databricks.sql.backend.filters import ResultSetFilter

         result = ResultSetFilter.filter_tables_by_type(result, table_types)
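The table-type filter is applied client-side on the SHOW TABLES result. A hedged usage sketch; the signature of filter_tables_by_type is inferred from the call above, and the type values are example inputs rather than part of this change:

    from databricks.sql.backend.filters import ResultSetFilter

    # result is the ResultSet returned by execute_command above
    requested_types = ["TABLE", "VIEW"]  # example values
    result = ResultSetFilter.filter_tables_by_type(result, requested_types)
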
@@ -748,16 +764,16 @@ def get_columns(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_columns")

-        operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)
+        operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"

         if schema_name:
-            operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
+            operation += f" SCHEMA LIKE '{schema_name}'"

         if table_name:
-            operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name)
+            operation += f" TABLE LIKE '{table_name}'"

         if column_name:
-            operation += MetadataCommands.LIKE_PATTERN.value.format(column_name)
+            operation += f" LIKE '{column_name}'"

         result = self.execute_command(
             operation=operation,