Commit e939d9e

Merge branch 'ext-links-sea' into sea-norm-cols
2 parents: 8f12c33 + abef941

10 files changed, +196 -184 lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,7 @@ def execute_command(
         parameters: List,
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
+        row_limit: Optional[int] = None,
     ) -> Union["ResultSet", None]:
         """
         Executes a SQL command or query within the specified session.
@@ -112,6 +113,7 @@ def execute_command(
             parameters: List of parameters to bind to the query
             async_op: Whether to execute the command asynchronously
             enforce_embedded_schema_correctness: Whether to enforce schema correctness
+            row_limit: Maximum number of rows in the operation result.

         Returns:
             If async_op is False, returns a ResultSet object containing the
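For implementers, a minimal sketch of what the widened interface asks of a concrete backend: accept the new keyword and treat None as "no limit". The subclass name is made up, and all pre-existing parameters of execute_command are elided behind *args/**kwargs here; this is an illustration, not the full signature.

from typing import Optional

from databricks.sql.backend.databricks_client import DatabricksClient


class IllustrativeBackend(DatabricksClient):  # hypothetical subclass, for illustration only
    def execute_command(self, *args, row_limit: Optional[int] = None, **kwargs):
        # Forward row_limit to the server-side request; None keeps the old
        # uncapped behavior, so existing callers are unaffected.
        ...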

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 1 deletion
@@ -410,6 +410,7 @@ def execute_command(
         parameters: List[Dict[str, Any]],
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
+        row_limit: Optional[int] = None,
     ) -> Union[SeaResultSet, None]:
         """
         Execute a SQL command using the SEA backend.
@@ -467,7 +468,7 @@ def execute_command(
             format=format,
             wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
             on_wait_timeout="CONTINUE",
-            row_limit=max_rows,
+            row_limit=row_limit,
             parameters=sea_parameters if sea_parameters else None,
             result_compression=result_compression,
         )
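The one-line change in the second hunk is the substantive fix: the SEA statement request's row_limit caps the whole result, while max_rows is the driver's fetch batch size, so reusing max_rows here could silently truncate results larger than a fetch batch. A rough, self-contained sketch of the intended semantics (the function and values below are illustrative, not the driver's API):

def effective_rows(total_rows, row_limit, max_rows):
    # row_limit caps the statement result as a whole; None means uncapped.
    capped = total_rows if row_limit is None else min(total_rows, row_limit)
    # max_rows only shapes how many rows each fetch round-trip returns;
    # it should never change how many rows are ultimately available.
    return capped

assert effective_rows(1_000_000, None, max_rows=10_000) == 1_000_000
assert effective_rows(1_000_000, 100, max_rows=10_000) == 100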

src/databricks/sql/backend/sea/queue.py

Lines changed: 5 additions & 18 deletions
@@ -35,11 +35,11 @@ def build_queue(
         result_data: ResultData,
         manifest: ResultManifest,
         statement_id: str,
-        ssl_options: Optional[SSLOptions] = None,
-        description: List[Tuple] = [],
-        max_download_threads: Optional[int] = None,
-        sea_client: Optional[SeaDatabricksClient] = None,
-        lz4_compressed: bool = False,
+        ssl_options: SSLOptions,
+        description: List[Tuple],
+        max_download_threads: int,
+        sea_client: SeaDatabricksClient,
+        lz4_compressed: bool,
     ) -> ResultSetQueue:
         """
         Factory method to build a result set queue for SEA backend.
@@ -62,19 +62,6 @@ def build_queue(
             return JsonQueue(result_data.data)
         elif manifest.format == ResultFormat.ARROW_STREAM.value:
             # EXTERNAL_LINKS disposition
-            if not max_download_threads:
-                raise ValueError(
-                    "Max download threads is required for EXTERNAL_LINKS disposition"
-                )
-            if not ssl_options:
-                raise ValueError(
-                    "SSL options are required for EXTERNAL_LINKS disposition"
-                )
-            if not sea_client:
-                raise ValueError(
-                    "SEA client is required for EXTERNAL_LINKS disposition"
-                )
-
             return SeaCloudFetchQueue(
                 initial_links=result_data.external_links or [],
                 max_download_threads=max_download_threads,
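With the collaborators declared as required parameters, the hand-written checks above become redundant: a missing argument now fails immediately at the call site (a TypeError at runtime, and an error from a type checker such as mypy) rather than via a later ValueError. A toy sketch of the pattern, using made-up names rather than the driver's API:

def build_old(sea_client=None):
    if not sea_client:
        raise ValueError("SEA client is required for EXTERNAL_LINKS disposition")
    return sea_client

def build_new(sea_client):
    return sea_client

try:
    build_new()  # intentionally wrong call
except TypeError as err:
    print(err)  # build_new() missing 1 required positional argument: 'sea_client'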

src/databricks/sql/backend/thrift_backend.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,7 @@
 import math
 import time
 import threading
-from typing import List, Union, Any, TYPE_CHECKING
+from typing import List, Optional, Union, Any, TYPE_CHECKING

 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
@@ -925,6 +925,7 @@ def execute_command(
         parameters=[],
         async_op=False,
         enforce_embedded_schema_correctness=False,
+        row_limit: Optional[int] = None,
     ) -> Union["ResultSet", None]:
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
@@ -965,6 +966,7 @@ def execute_command(
             useArrowNativeTypes=spark_arrow_types,
             parameters=parameters,
             enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,
+            resultRowLimit=row_limit,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)

src/databricks/sql/client.py

Lines changed: 20 additions & 8 deletions
@@ -335,8 +335,14 @@ def cursor(
         self,
         arraysize: int = DEFAULT_ARRAY_SIZE,
         buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
+        row_limit: Optional[int] = None,
     ) -> "Cursor":
         """
+        Args:
+            arraysize: The maximum number of rows in direct results.
+            buffer_size_bytes: The maximum number of bytes in direct results.
+            row_limit: The maximum number of rows in the result.
+
         Return a new Cursor object using the connection.

         Will throw an Error if the connection has been closed.
@@ -349,6 +355,7 @@ def cursor(
             self.session.backend,
             arraysize=arraysize,
             result_buffer_size_bytes=buffer_size_bytes,
+            row_limit=row_limit,
         )
         self._cursors.append(cursor)
         return cursor
@@ -382,6 +389,7 @@ def __init__(
         backend: DatabricksClient,
         result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
         arraysize: int = DEFAULT_ARRAY_SIZE,
+        row_limit: Optional[int] = None,
     ) -> None:
         """
         These objects represent a database cursor, which is used to manage the context of a fetch
@@ -391,16 +399,18 @@ def __init__(
         visible by other cursors or connections.
         """

-        self.connection = connection
-        self.rowcount = -1  # Return -1 as this is not supported
-        self.buffer_size_bytes = result_buffer_size_bytes
+        self.connection: Connection = connection
+
+        self.rowcount: int = -1  # Return -1 as this is not supported
+        self.buffer_size_bytes: int = result_buffer_size_bytes
         self.active_result_set: Union[ResultSet, None] = None
-        self.arraysize = arraysize
+        self.arraysize: int = arraysize
+        self.row_limit: Optional[int] = row_limit
         # Note that Cursor closed => active result set closed, but not vice versa
-        self.open = True
-        self.executing_command_id = None
-        self.backend = backend
-        self.active_command_id = None
+        self.open: bool = True
+        self.executing_command_id: Optional[CommandId] = None
+        self.backend: DatabricksClient = backend
+        self.active_command_id: Optional[CommandId] = None
         self.escaper = ParamEscaper()
         self.lastrowid = None

@@ -779,6 +789,7 @@ def execute(
             parameters=prepared_params,
             async_op=False,
             enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
+            row_limit=self.row_limit,
         )

         if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -835,6 +846,7 @@ def execute_async(
             parameters=prepared_params,
             async_op=True,
             enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
+            row_limit=self.row_limit,
         )

         return self
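End to end, the new keyword becomes part of the public cursor API and applies to both execute and execute_async. A hedged usage sketch; the connection arguments and table name are placeholders, not part of this change:

from databricks import sql

with sql.connect(
    server_hostname="...", http_path="...", access_token="..."
) as connection:
    # Every statement run on this cursor is capped at 100 rows;
    # omitting row_limit (or passing None) keeps the previous uncapped behavior.
    with connection.cursor(row_limit=100) as cursor:
        cursor.execute("SELECT * FROM samples.nyctaxi.trips")
        rows = cursor.fetchall()
        assert len(rows) <= 100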

src/databricks/sql/utils.py

Lines changed: 46 additions & 79 deletions
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Dict, List, Optional, Union

 from dateutil import parser
 import datetime
@@ -9,34 +9,25 @@
 from collections.abc import Mapping
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
+from typing import Dict, List, Optional, Tuple, Union, Sequence
 import re

-import dateutil
 import lz4.frame

-from databricks.sql.backend.sea.backend import SeaDatabricksClient
-from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
-
 try:
     import pyarrow
 except ImportError:
     pyarrow = None

 from databricks.sql import OperationalError
-from databricks.sql.exc import ProgrammingError
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TRowSet,
     TSparkArrowResultLink,
     TSparkRowSetType,
 )
 from databricks.sql.types import SSLOptions
-from databricks.sql.backend.sea.models.base import (
-    ResultData,
-    ExternalLink,
-    ResultManifest,
-)
+
 from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter

 import logging
@@ -227,11 +218,12 @@ def __init__(
             lz4_compressed: Whether the data is LZ4 compressed
             description: Column descriptions
         """
+
+        self.schema_bytes = schema_bytes
+        self.max_download_threads = max_download_threads
         self.lz4_compressed = lz4_compressed
         self.description = description
-        self.schema_bytes = schema_bytes
         self._ssl_options = ssl_options
-        self.max_download_threads = max_download_threads

         # Table state
         self.table = None
@@ -240,104 +232,73 @@ def __init__(
         # Initialize download manager
         self.download_manager: Optional["ResultFileDownloadManager"] = None

-    def remaining_rows(self) -> "pyarrow.Table":
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """
-        Get all remaining rows of the cloud fetch Arrow dataframes.
+        Get up to the next n rows of the cloud fetch Arrow dataframes.

+        Args:
+            num_rows (int): Number of rows to retrieve.
         Returns:
             pyarrow.Table
         """
         if not self.table:
+            logger.debug("CloudFetchQueue: no more rows available")
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
-
-        results = pyarrow.Table.from_pydict({})  # Empty table
-        while self.table:
-            table_slice = self.table.slice(
-                self.table_row_index, self.table.num_rows - self.table_row_index
-            )
-            if results.num_rows > 0:
-                results = pyarrow.concat_tables([results, table_slice])
-            else:
-                results = table_slice
-
-            self.table_row_index += table_slice.num_rows
-            self.table = self._create_next_table()
-            self.table_row_index = 0
-
-        return results
-
-    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
-        """Get up to the next n rows of the cloud fetch Arrow dataframes."""
-        if not self.table:
-            # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
-
-        logger.info("SeaCloudFetchQueue: Retrieving up to {} rows".format(num_rows))
-        results = pyarrow.Table.from_pydict({})  # Empty table
-        rows_fetched = 0
-
+        logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
+        results = self.table.slice(0, 0)
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
-            logger.info(
-                "CloudFetchQueue: Slicing table from index {} for {} rows (table has {} rows total)".format(
-                    self.table_row_index, length, self.table.num_rows
-                )
-            )
             table_slice = self.table.slice(self.table_row_index, length)
-
-            # Concatenate results if we have any
-            if results.num_rows > 0:
-                logger.info(
-                    "CloudFetchQueue: Concatenating {} rows to existing {} rows".format(
-                        table_slice.num_rows, results.num_rows
-                    )
-                )
-                results = pyarrow.concat_tables([results, table_slice])
-            else:
-                results = table_slice
-
+            results = pyarrow.concat_tables([results, table_slice])
             self.table_row_index += table_slice.num_rows
-            rows_fetched += table_slice.num_rows
-
-            logger.info(
-                "CloudFetchQueue: After slice, table_row_index={}, rows_fetched={}".format(
-                    self.table_row_index, rows_fetched
-                )
-            )

             # Replace current table with the next table if we are at the end of the current table
             if self.table_row_index == self.table.num_rows:
-                logger.info(
-                    "CloudFetchQueue: Reached end of current table, fetching next"
-                )
                 self.table = self._create_next_table()
                 self.table_row_index = 0
-
             num_rows -= table_slice.num_rows

-        logger.info("CloudFetchQueue: Retrieved {} rows".format(results.num_rows))
+        logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
         return results

-    def _create_empty_table(self) -> "pyarrow.Table":
-        """Create a 0-row table with just the schema bytes."""
-        if not self.schema_bytes:
-            return pyarrow.Table.from_pydict({})
-        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
+    def remaining_rows(self) -> "pyarrow.Table":
+        """
+        Get all remaining rows of the cloud fetch Arrow dataframes.
+
+        Returns:
+            pyarrow.Table
+        """
+
+        if not self.table:
+            # Return empty pyarrow table to cause retry of fetch
+            return self._create_empty_table()
+        results = self.table.slice(0, 0)
+        while self.table:
+            table_slice = self.table.slice(
+                self.table_row_index, self.table.num_rows - self.table_row_index
+            )
+            results = pyarrow.concat_tables([results, table_slice])
+            self.table_row_index += table_slice.num_rows
+            self.table = self._create_next_table()
+            self.table_row_index = 0
+        return results

     def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
-        """Create next table by retrieving the logical next downloaded file."""
+        """Create next table at the given row offset"""
        # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
         if not self.download_manager:
             logger.debug("CloudFetchQueue: No download manager available")
             return None

         downloaded_file = self.download_manager.get_next_downloaded_file(offset)
         if not downloaded_file:
+            logger.debug(
+                "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset)
+            )
             # None signals no more Arrow tables can be built from the remaining handlers if any remain
             return None
-
         arrow_table = create_arrow_table_from_arrow_file(
             downloaded_file.file_bytes, self.description
         )
@@ -357,6 +318,12 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
         pass

+    def _create_empty_table(self) -> "pyarrow.Table":
+        """Create a 0-row table with just the schema bytes."""
+        if not self.schema_bytes:
+            return pyarrow.Table.from_pydict({})
+        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
+

 class ThriftCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
