2020from databricks .sql .utils import (
2121 ColumnTable ,
2222 ColumnQueue ,
23+ ResultSetQueue ,
2324)
2425from databricks .sql .backend .types import CommandId , CommandState , ExecuteResponse
2526
@@ -36,14 +37,12 @@ class ResultSet(ABC):
3637 def __init__ (
3738 self ,
3839 connection : "Connection" ,
39- backend : "DatabricksClient" ,
4040 arraysize : int ,
4141 buffer_size_bytes : int ,
4242 command_id : CommandId ,
4343 status : CommandState ,
4444 has_been_closed_server_side : bool = False ,
4545 is_direct_results : bool = False ,
46- results_queue = None ,
4746 description : List [Tuple ] = [],
4847 is_staging_operation : bool = False ,
4948 lz4_compressed : bool = False ,
@@ -54,32 +53,30 @@ def __init__(
5453
5554 Parameters:
5655 :param connection: The parent connection
57- :param backend: The backend client
5856 :param arraysize: The max number of rows to fetch at a time (PEP-249)
5957 :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
6058 :param command_id: The command ID
6159 :param status: The command status
6260 :param has_been_closed_server_side: Whether the command has been closed on the server
6361 :param is_direct_results: Whether the command has more rows
64- :param results_queue: The results queue
6562 :param description: column description of the results
6663 :param is_staging_operation: Whether the command is a staging operation
6764 """
6865
69- self .connection = connection
70- self .backend = backend
71- self .arraysize = arraysize
72- self .buffer_size_bytes = buffer_size_bytes
73- self ._next_row_index = 0
74- self .description = description
75- self .command_id = command_id
76- self .status = status
77- self .has_been_closed_server_side = has_been_closed_server_side
78- self .is_direct_results = is_direct_results
79- self .results = results_queue
80- self ._is_staging_operation = is_staging_operation
81- self .lz4_compressed = lz4_compressed
82- self ._arrow_schema_bytes = arrow_schema_bytes
66+ self .connection : "Connection" = connection
67+ self .backend : DatabricksClient = connection . session . backend
68+ self .arraysize : int = arraysize
69+ self .buffer_size_bytes : int = buffer_size_bytes
70+ self ._next_row_index : int = 0
71+ self .description : List [ Tuple ] = description
72+ self .command_id : CommandId = command_id
73+ self .status : CommandState = status
74+ self .has_been_closed_server_side : bool = has_been_closed_server_side
75+ self .is_direct_results : bool = is_direct_results
76+ self .results : Optional [ ResultSetQueue ] = None # Children will set this
77+ self ._is_staging_operation : bool = is_staging_operation
78+ self .lz4_compressed : bool = lz4_compressed
79+ self ._arrow_schema_bytes : Optional [ bytes ] = arrow_schema_bytes
8380
8481 def __iter__ (self ):
8582 while True :
@@ -190,7 +187,6 @@ def __init__(
190187 self ,
191188 connection : "Connection" ,
192189 execute_response : "ExecuteResponse" ,
193- thrift_client : "ThriftDatabricksClient" ,
194190 buffer_size_bytes : int = 104857600 ,
195191 arraysize : int = 10000 ,
196192 use_cloud_fetch : bool = True ,
@@ -205,7 +201,6 @@ def __init__(
205201 Parameters:
206202 :param connection: The parent connection
207203 :param execute_response: Response from the execute command
208- :param thrift_client: The ThriftDatabricksClient instance for direct access
209204 :param buffer_size_bytes: Buffer size for fetching results
210205 :param arraysize: Default number of rows to fetch
211206 :param use_cloud_fetch: Whether to use cloud fetch for retrieving results
@@ -238,20 +233,28 @@ def __init__(
238233 # Call parent constructor with common attributes
239234 super ().__init__ (
240235 connection = connection ,
241- backend = thrift_client ,
242236 arraysize = arraysize ,
243237 buffer_size_bytes = buffer_size_bytes ,
244238 command_id = execute_response .command_id ,
245239 status = execute_response .status ,
246240 has_been_closed_server_side = execute_response .has_been_closed_server_side ,
247241 is_direct_results = is_direct_results ,
248- results_queue = results_queue ,
249242 description = execute_response .description ,
250243 is_staging_operation = execute_response .is_staging_operation ,
251244 lz4_compressed = execute_response .lz4_compressed ,
252245 arrow_schema_bytes = execute_response .arrow_schema_bytes ,
253246 )
254247
248+ # Assert that the backend is of the correct type
249+ from databricks .sql .backend .thrift_backend import ThriftDatabricksClient
250+
251+ assert isinstance (
252+ self .backend , ThriftDatabricksClient
253+ ), "Backend must be a ThriftDatabricksClient"
254+
255+ # Set the results queue
256+ self .results = results_queue
257+
255258 # Initialize results queue if not provided
256259 if not self .results :
257260 self ._fill_results_buffer ()
@@ -307,6 +310,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
307310 """
308311 if size < 0 :
309312 raise ValueError ("size argument for fetchmany is %s but must be >= 0" , size )
313+
314+ if self .results is None :
315+ raise RuntimeError ("Results queue is not initialized" )
316+
310317 results = self .results .next_n_rows (size )
311318 n_remaining_rows = size - results .num_rows
312319 self ._next_row_index += results .num_rows
@@ -332,6 +339,9 @@ def fetchmany_columnar(self, size: int):
332339 if size < 0 :
333340 raise ValueError ("size argument for fetchmany is %s but must be >= 0" , size )
334341
342+ if self .results is None :
343+ raise RuntimeError ("Results queue is not initialized" )
344+
335345 results = self .results .next_n_rows (size )
336346 n_remaining_rows = size - results .num_rows
337347 self ._next_row_index += results .num_rows
@@ -351,6 +361,9 @@ def fetchmany_columnar(self, size: int):
351361
352362 def fetchall_arrow (self ) -> "pyarrow.Table" :
353363 """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
364+ if self .results is None :
365+ raise RuntimeError ("Results queue is not initialized" )
366+
354367 results = self .results .remaining_rows ()
355368 self ._next_row_index += results .num_rows
356369
@@ -377,6 +390,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":
377390
378391 def fetchall_columnar (self ):
379392 """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
393+ if self .results is None :
394+ raise RuntimeError ("Results queue is not initialized" )
395+
380396 results = self .results .remaining_rows ()
381397 self ._next_row_index += results .num_rows
382398
@@ -393,6 +409,9 @@ def fetchone(self) -> Optional[Row]:
393409 Fetch the next row of a query result set, returning a single sequence,
394410 or None when no more data is available.
395411 """
412+ if self .results is None :
413+ raise RuntimeError ("Results queue is not initialized" )
414+
396415 if isinstance (self .results , ColumnQueue ):
397416 res = self ._convert_columnar_table (self .fetchmany_columnar (1 ))
398417 else :
0 commit comments