11from abc import ABC , abstractmethod
2- from typing import List , Optional , Any , Union
2+ from typing import List , Optional , Any , Union , TYPE_CHECKING
33
44import logging
55import time
66import pandas
77
8+ from databricks .sql .backend .types import CommandId , CommandState
9+
810try :
911 import pyarrow
1012except ImportError :
1113 pyarrow = None
1214
15+ if TYPE_CHECKING :
16+ from databricks .sql .backend .databricks_client import DatabricksClient
17+ from databricks .sql .backend .thrift_backend import ThriftDatabricksClient
18+ from databricks .sql .client import Connection
19+
1320from databricks .sql .thrift_api .TCLIService import ttypes
1421from databricks .sql .types import Row
1522from databricks .sql .exc import Error , RequestError , CursorAlreadyClosedError
@@ -25,10 +32,30 @@ class ResultSet(ABC):
2532 This class defines the interface that all concrete result set implementations must follow.
2633 """
2734
28- def __init__ (self , connection , backend , arraysize : int , buffer_size_bytes : int ):
29- """Initialize the base ResultSet with common properties."""
35+ def __init__ (
36+ self ,
37+ connection : "Connection" ,
38+ backend : "DatabricksClient" ,
39+ command_id : CommandId ,
40+ op_state : Optional [CommandState ],
41+ has_been_closed_server_side : bool ,
42+ arraysize : int ,
43+ buffer_size_bytes : int ,
44+ ):
45+ """
46+ A ResultSet manages the results of a single command.
47+
48+ :param connection: The parent connection that was used to execute this command
49+ :param backend: The specialised backend client to be invoked in the fetch phase
50+ :param execute_response: A `ExecuteResponse` class returned by a command execution
51+ :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount
52+ :param arraysize: The max number of rows to fetch at a time (PEP-249)
53+ """
54+ self .command_id = command_id
55+ self .op_state = op_state
56+ self .has_been_closed_server_side = has_been_closed_server_side
3057 self .connection = connection
31- self .backend = backend # Store the backend client directly
58+ self .backend = backend
3259 self .arraysize = arraysize
3360 self .buffer_size_bytes = buffer_size_bytes
3461 self ._next_row_index = 0
@@ -83,20 +110,36 @@ def fetchall_arrow(self) -> Any:
83110 """Fetch all remaining rows as an Arrow table."""
84111 pass
85112
86- @abstractmethod
def close(self) -> None:
    """
    Close the result set.

    If the connection is still open and the command has not already been
    closed (locally or on the server), ask the backend to close it.

    Closing is best-effort: a ``RequestError`` raised by the backend is
    logged and swallowed rather than propagated. Whatever happens, the
    result set is marked as closed afterwards.
    """
    try:
        if (
            self.op_state != CommandState.CLOSED
            and not self.has_been_closed_server_side
            and self.connection.open
        ):
            self.backend.close_command(self.command_id)
    except RequestError as e:
        # The handler inspects the second positional argument of the error for
        # the underlying cause. Guard the lookup so a RequestError constructed
        # with fewer arguments cannot raise IndexError from inside the handler.
        cause = e.args[1] if len(e.args) > 1 else None
        if isinstance(cause, CursorAlreadyClosedError):
            logger.info("Operation was canceled by a prior request")
        else:
            # Still best-effort, but leave a trace instead of dropping the
            # failure silently.
            logger.warning("Failed to close command on the server: %s", e)
    finally:
        # Mark the result set closed even when the request failed, so a
        # repeated close() does not issue another request.
        self.has_been_closed_server_side = True
        self.op_state = CommandState.CLOSED
90133
91134
92135class ThriftResultSet (ResultSet ):
93136 """ResultSet implementation for the Thrift backend."""
94137
95138 def __init__ (
96139 self ,
97- connection ,
140+ connection : "Connection" ,
98141 execute_response : ExecuteResponse ,
99- thrift_client , # Pass the specific ThriftDatabricksClient instance
142+ thrift_client : " ThriftDatabricksClient" ,
100143 buffer_size_bytes : int = 104857600 ,
101144 arraysize : int = 10000 ,
102145 use_cloud_fetch : bool = True ,
@@ -112,11 +155,20 @@ def __init__(
112155 arraysize: Default number of rows to fetch
113156 use_cloud_fetch: Whether to use cloud fetch for retrieving results
114157 """
115- super ().__init__ (connection , thrift_client , arraysize , buffer_size_bytes )
158+ command_id = execute_response .command_id
159+ op_state = CommandState .from_thrift_state (execute_response .status )
160+ has_been_closed_server_side = execute_response .has_been_closed_server_side
161+ super ().__init__ (
162+ connection ,
163+ thrift_client ,
164+ command_id ,
165+ op_state ,
166+ has_been_closed_server_side ,
167+ arraysize ,
168+ buffer_size_bytes ,
169+ )
116170
117171 # Initialize ThriftResultSet-specific attributes
118- self .command_id = execute_response .command_id
119- self .op_state = execute_response .status
120172 self .has_been_closed_server_side = execute_response .has_been_closed_server_side
121173 self .has_more_rows = execute_response .has_more_rows
122174 self .lz4_compressed = execute_response .lz4_compressed
@@ -127,11 +179,15 @@ def __init__(
127179
128180 # Initialize results queue
129181 if execute_response .arrow_queue :
182+ # In this case the server has taken the fast path and returned an initial batch of
183+ # results
130184 self .results = execute_response .arrow_queue
131185 else :
186+ # In this case, there are results waiting on the server so we fetch now for simplicity
132187 self ._fill_results_buffer ()
133188
134189 def _fill_results_buffer (self ):
190+ # At initialization or if the server does not have cloud fetch result links available
135191 results , has_more_rows = self .backend .fetch_results (
136192 command_id = self .command_id ,
137193 max_rows = self .arraysize ,
@@ -336,28 +392,24 @@ def fetchmany(self, size: int) -> List[Row]:
336392 else :
337393 return self ._convert_arrow_table (self .fetchmany_arrow (size ))
338394
339- def close (self ) -> None :
340- """
341- Close the cursor.
342-
343- If the connection has not been closed, and the cursor has not already
344- been closed on the server for some other reason, issue a request to the server to close it.
345- """
346- try :
347- if (
348- self .op_state != ttypes .TOperationState .CLOSED_STATE
349- and not self .has_been_closed_server_side
350- and self .connection .open
351- ):
352- self .backend .close_command (self .command_id )
353- except RequestError as e :
354- if isinstance (e .args [1 ], CursorAlreadyClosedError ):
355- logger .info ("Operation was canceled by a prior request" )
356- finally :
357- self .has_been_closed_server_side = True
358- self .op_state = ttypes .TOperationState .CLOSED_STATE
359-
@property
def is_staging_operation(self) -> bool:
    """Return the staging-operation flag stored on this result set."""
    staging_flag = self._is_staging_operation
    return staging_flag
399+
400+ @staticmethod
401+ def _get_schema_description (table_schema_message ):
402+ """
403+ Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249
404+ """
405+
406+ def map_col_type (type_ ):
407+ if type_ .startswith ("decimal" ):
408+ return "decimal"
409+ else :
410+ return type_
411+
412+ return [
413+ (column .name , map_col_type (column .datatype ), None , None , None , None , None )
414+ for column in table_schema_message .columns
415+ ]
0 commit comments