
Commit 89f8680

stop passing client to ResultSet, infer from connection
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent abf9aab commit 89f8680
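The net effect at every call site: the backend client is no longer threaded through as a constructor argument, because the result set can recover it from the connection it already receives. A minimal before/after sketch of the SeaResultSet call shape (arguments abbreviated; the connection.session.backend lookup is the mechanism the new base class uses, per the diff below):

    # Before this commit: the backend client was passed explicitly.
    result_set = SeaResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        sea_client=self,  # redundant: the connection already knows its backend
        result_data=response.result,
        manifest=response.manifest,
    )

    # After: ResultSet.__init__ infers the client via connection.session.backend.
    result_set = SeaResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        result_data=response.result,
        manifest=response.manifest,
    )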

9 files changed: +159 additions, -99 deletions


src/databricks/sql/backend/sea/backend.py

Lines changed: 0 additions & 1 deletion
@@ -620,7 +620,6 @@ def get_execution_result(
         return SeaResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            sea_client=self,
             result_data=response.result,
             manifest=response.manifest,
             buffer_size_bytes=cursor.buffer_size_bytes,

src/databricks/sql/backend/sea/result_set.py

Lines changed: 24 additions & 14 deletions
@@ -31,7 +31,6 @@ def __init__(
         self,
         connection: Connection,
         execute_response: ExecuteResponse,
-        sea_client: SeaDatabricksClient,
         result_data: ResultData,
         manifest: ResultManifest,
         buffer_size_bytes: int = 104857600,

@@ -43,7 +42,6 @@ def __init__(
         Args:
             connection: The parent connection
             execute_response: Response from the execute command
-            sea_client: The SeaDatabricksClient instance for direct access
             buffer_size_bytes: Buffer size for fetching results
             arraysize: Default number of rows to fetch
             result_data: Result data from SEA response

@@ -56,32 +54,38 @@ def __init__(
         if statement_id is None:
             raise ValueError("Command ID is not a SEA statement ID")

-        results_queue = SeaResultSetQueueFactory.build_queue(
-            result_data,
-            self.manifest,
-            statement_id,
-            description=execute_response.description,
-            max_download_threads=sea_client.max_download_threads,
-            sea_client=sea_client,
-            lz4_compressed=execute_response.lz4_compressed,
-        )
-
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=sea_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
             has_been_closed_server_side=execute_response.has_been_closed_server_side,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )

+        # Assert that the backend is of the correct type
+        assert isinstance(
+            self.backend, SeaDatabricksClient
+        ), "Backend must be a SeaDatabricksClient"
+
+        results_queue = SeaResultSetQueueFactory.build_queue(
+            result_data,
+            self.manifest,
+            statement_id,
+            description=execute_response.description,
+            max_download_threads=self.backend.max_download_threads,
+            sea_client=self.backend,
+            lz4_compressed=execute_response.lz4_compressed,
+        )
+
+        # Set the results queue
+        self.results = results_queue
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
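Note the reordering above: the queue build moved below super().__init__ because it now reads the client off self.backend, and that attribute only exists once the parent constructor has run. Schematically (names from the diff; a sketch, not the full method):

    super().__init__(connection=connection, ...)  # sets self.backend from connection.session.backend
    assert isinstance(self.backend, SeaDatabricksClient)  # narrow DatabricksClient to the SEA client
    self.results = SeaResultSetQueueFactory.build_queue(
        result_data,
        self.manifest,
        statement_id,
        max_download_threads=self.backend.max_download_threads,
        sea_client=self.backend,
    )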
@@ -160,6 +164,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         self._next_row_index += len(results)

@@ -173,6 +180,9 @@ def fetchall_json(self) -> List[List[str]]:
             Columnar table containing all remaining rows
         """

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += len(results)

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 4 additions & 3 deletions
@@ -12,14 +12,13 @@
     Optional,
     Any,
     Callable,
-    cast,
     TYPE_CHECKING,
 )

 if TYPE_CHECKING:
     from databricks.sql.backend.sea.result_set import SeaResultSet

-from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState

 logger = logging.getLogger(__name__)

@@ -45,6 +44,9 @@ def _filter_sea_result_set(
         """

         # Get all remaining rows
+        if result_set.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         all_rows = result_set.results.remaining_rows()

         # Filter rows

@@ -79,7 +81,6 @@ def _filter_sea_result_set(
         filtered_result_set = SeaResultSet(
             connection=result_set.connection,
             execute_response=execute_response,
-            sea_client=cast(SeaDatabricksClient, result_set.backend),
             result_data=result_data,
             manifest=manifest,
             buffer_size_bytes=result_set.buffer_size_bytes,
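The dropped cast import falls out of the same change: callers previously had to coerce the statically typed result_set.backend before passing it along, whereas the narrowing now happens once inside SeaResultSet.__init__. Roughly (both lines illustrate the two styles; neither is code from this commit):

    sea_client = cast(SeaDatabricksClient, result_set.backend)  # old: per-call-site coercion
    assert isinstance(self.backend, SeaDatabricksClient)        # new: one runtime check in __init__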

src/databricks/sql/backend/thrift_backend.py

Lines changed: 0 additions & 6 deletions
@@ -856,7 +856,6 @@ def get_execution_result(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,

@@ -987,7 +986,6 @@ def execute_command(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=use_cloud_fetch,

@@ -1027,7 +1025,6 @@ def get_catalogs(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,

@@ -1071,7 +1068,6 @@ def get_schemas(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,

@@ -1119,7 +1115,6 @@ def get_tables(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,

@@ -1167,7 +1162,6 @@ def get_columns(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,

src/databricks/sql/result_set.py

Lines changed: 41 additions & 22 deletions
@@ -20,6 +20,7 @@
 from databricks.sql.utils import (
     ColumnTable,
     ColumnQueue,
+    ResultSetQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse

@@ -36,14 +37,12 @@ class ResultSet(ABC):
     def __init__(
         self,
         connection: "Connection",
-        backend: "DatabricksClient",
         arraysize: int,
         buffer_size_bytes: int,
         command_id: CommandId,
         status: CommandState,
         has_been_closed_server_side: bool = False,
         is_direct_results: bool = False,
-        results_queue=None,
         description: List[Tuple] = [],
         is_staging_operation: bool = False,
         lz4_compressed: bool = False,

@@ -54,32 +53,30 @@ def __init__(

         Parameters:
             :param connection: The parent connection
-            :param backend: The backend client
             :param arraysize: The max number of rows to fetch at a time (PEP-249)
             :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
             :param command_id: The command ID
             :param status: The command status
             :param has_been_closed_server_side: Whether the command has been closed on the server
             :param is_direct_results: Whether the command has more rows
-            :param results_queue: The results queue
             :param description: column description of the results
             :param is_staging_operation: Whether the command is a staging operation
         """

-        self.connection = connection
-        self.backend = backend
-        self.arraysize = arraysize
-        self.buffer_size_bytes = buffer_size_bytes
-        self._next_row_index = 0
-        self.description = description
-        self.command_id = command_id
-        self.status = status
-        self.has_been_closed_server_side = has_been_closed_server_side
-        self.is_direct_results = is_direct_results
-        self.results = results_queue
-        self._is_staging_operation = is_staging_operation
-        self.lz4_compressed = lz4_compressed
-        self._arrow_schema_bytes = arrow_schema_bytes
+        self.connection: "Connection" = connection
+        self.backend: DatabricksClient = connection.session.backend
+        self.arraysize: int = arraysize
+        self.buffer_size_bytes: int = buffer_size_bytes
+        self._next_row_index: int = 0
+        self.description: List[Tuple] = description
+        self.command_id: CommandId = command_id
+        self.status: CommandState = status
+        self.has_been_closed_server_side: bool = has_been_closed_server_side
+        self.is_direct_results: bool = is_direct_results
+        self.results: Optional[ResultSetQueue] = None  # Children will set this
+        self._is_staging_operation: bool = is_staging_operation
+        self.lz4_compressed: bool = lz4_compressed
+        self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes

     def __iter__(self):
         while True:
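The key new line is self.backend = connection.session.backend, which assumes every Connection reaches its client through a session object. A minimal sketch of the shape being relied on (class bodies are illustrative stand-ins, not the library's actual definitions):

    class Session:
        def __init__(self, backend: DatabricksClient):
            self.backend = backend  # the Thrift or SEA client owned by this session

    class Connection:
        def __init__(self, session: Session):
            self.session = session

    # ResultSet.__init__ can then recover the client with no extra parameter:
    # self.backend = connection.session.backend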
@@ -190,7 +187,6 @@ def __init__(
         self,
         connection: "Connection",
         execute_response: "ExecuteResponse",
-        thrift_client: "ThriftDatabricksClient",
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,

@@ -205,7 +201,6 @@ def __init__(
         Parameters:
             :param connection: The parent connection
             :param execute_response: Response from the execute command
-            :param thrift_client: The ThriftDatabricksClient instance for direct access
             :param buffer_size_bytes: Buffer size for fetching results
             :param arraysize: Default number of rows to fetch
             :param use_cloud_fetch: Whether to use cloud fetch for retrieving results

@@ -238,20 +233,28 @@ def __init__(
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=thrift_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
             has_been_closed_server_side=execute_response.has_been_closed_server_side,
             is_direct_results=is_direct_results,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )

+        # Assert that the backend is of the correct type
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
+
+        assert isinstance(
+            self.backend, ThriftDatabricksClient
+        ), "Backend must be a ThriftDatabricksClient"
+
+        # Set the results queue
+        self.results = results_queue
+
         # Initialize results queue if not provided
         if not self.results:
             self._fill_results_buffer()

@@ -307,6 +310,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
+
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows

@@ -332,6 +339,9 @@ def fetchmany_columnar(self, size: int):
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows

@@ -351,6 +361,9 @@ def fetchmany_columnar(self, size: int):

     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

@@ -377,6 +390,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":

     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

@@ -393,6 +409,9 @@ def fetchone(self) -> Optional[Row]:
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         if isinstance(self.results, ColumnQueue):
             res = self._convert_columnar_table(self.fetchmany_columnar(1))
         else:
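A quick way to convince yourself the wiring holds after this change (a hypothetical smoke check, not a test from this commit):

    # The inferred backend should be the very object the session already holds.
    rs = ThriftResultSet(connection=conn, execute_response=resp)
    assert rs.backend is conn.session.backend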
