Commit 9fb0444

Merge branch 'main' into col-normalisation
2 parents b2ae83c + 36d3ec4 commit 9fb0444

8 files changed: +971 -240 lines changed

src/databricks/sql/auth/retry.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def __init__(
             total=_attempts_remaining,
             respect_retry_after_header=True,
             backoff_factor=self.delay_min,
-            allowed_methods=["POST"],
+            allowed_methods=["POST", "GET", "DELETE"],
             status_forcelist=[429, 503, *self.force_dangerous_codes],
         )
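
This one-line change widens urllib3's retry allowlist so GET and DELETE requests are retried alongside POST. A minimal standalone sketch of the same Retry policy; the pool manager and the concrete total/backoff values are illustrative stand-ins, not part of this commit:

from urllib3 import PoolManager
from urllib3.util.retry import Retry

# Mirror the policy above: retry on 429/503 and honour any
# Retry-After header the server sends back.
retry = Retry(
    total=5,  # stand-in for _attempts_remaining
    respect_retry_after_header=True,
    backoff_factor=1.0,  # stand-in for self.delay_min
    allowed_methods=["POST", "GET", "DELETE"],  # previously ["POST"] only
    status_forcelist=[429, 503],
)
http = PoolManager(retries=retry)
# GET and DELETE requests issued through this pool are now retried too.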

src/databricks/sql/backend/sea/backend.py

Lines changed: 5 additions & 4 deletions
@@ -159,6 +159,7 @@ def __init__(
         )

         self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)

         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -695,7 +696,7 @@ def get_catalogs(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -731,7 +732,7 @@ def get_schemas(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -775,7 +776,7 @@ def get_tables(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -825,7 +826,7 @@ def get_columns(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
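
With this change the SEA metadata commands (get_catalogs, get_schemas, get_tables, get_columns) inherit the connection-level cloud fetch setting instead of hard-coding use_cloud_fetch=False. A hedged sketch of the kwargs plumbing involved, using a hypothetical MetadataClient class in place of SeaDatabricksClient:

class MetadataClient:
    """Hypothetical stand-in for SeaDatabricksClient's kwargs handling."""

    def __init__(self, http_path: str, **kwargs):
        # Both flags default to True when the caller does not pass them.
        self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)

    def get_catalogs(self):
        # The per-command hard-coded False is replaced by the stored flag,
        # so metadata results may now arrive as cloud-fetched Arrow data.
        return self._execute("SHOW CATALOGS", use_cloud_fetch=self.use_cloud_fetch)

    def _execute(self, statement: str, use_cloud_fetch: bool):
        return (statement, use_cloud_fetch)  # placeholder for the real backend call

client = MetadataClient("/sql/1.0/warehouses/abc", use_cloud_fetch=False)
assert client.get_catalogs() == ("SHOW CATALOGS", False)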

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 183 additions & 57 deletions
@@ -6,12 +6,12 @@

 from __future__ import annotations

+import io
 import logging
 from typing import (
     List,
     Optional,
     Any,
-    Callable,
     cast,
     TYPE_CHECKING,
 )
@@ -20,6 +20,16 @@
     from databricks.sql.backend.sea.result_set import SeaResultSet

 from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.sea.models.base import ResultData
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.utils import CloudFetchQueue, ArrowQueue
+
+try:
+    import pyarrow
+    import pyarrow.compute as pc
+except ImportError:
+    pyarrow = None
+    pc = None

 logger = logging.getLogger(__name__)

@@ -30,32 +40,18 @@ class ResultSetFilter:
     """

     @staticmethod
-    def _filter_sea_result_set(
-        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
-    ) -> SeaResultSet:
+    def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
         """
-        Filter a SEA result set using the provided filter function.
+        Create an ExecuteResponse with parameters from the original result set.

         Args:
-            result_set: The SEA result set to filter
-            filter_func: Function that takes a row and returns True if the row should be included
+            result_set: Original result set to copy parameters from

         Returns:
-            A filtered SEA result set
+            ExecuteResponse: New execute response object
         """
-
-        # Get all remaining rows
-        all_rows = result_set.results.remaining_rows()
-
-        # Filter rows
-        filtered_rows = [row for row in all_rows if filter_func(row)]
-
-        # Reuse the command_id from the original result set
-        command_id = result_set.command_id
-
-        # Create an ExecuteResponse for the filtered data
-        execute_response = ExecuteResponse(
-            command_id=command_id,
+        return ExecuteResponse(
+            command_id=result_set.command_id,
             status=result_set.status,
             description=result_set.description,
             has_been_closed_server_side=result_set.has_been_closed_server_side,
@@ -64,39 +60,145 @@ def _filter_sea_result_set(
             is_staging_operation=False,
         )

-        # Create a new ResultData object with filtered data
-        from databricks.sql.backend.sea.models.base import ResultData
+    @staticmethod
+    def _update_manifest(result_set: SeaResultSet, new_row_count: int):
+        """
+        Create a copy of the manifest with updated row count.

-        result_data = ResultData(data=filtered_rows, external_links=None)
+        Args:
+            result_set: Original result set to copy manifest from
+            new_row_count: New total row count for filtered data
+
+        Returns:
+            Updated manifest copy
+        """
+        filtered_manifest = result_set.manifest
+        filtered_manifest.total_row_count = new_row_count
+        return filtered_manifest
+
+    @staticmethod
+    def _create_filtered_result_set(
+        result_set: SeaResultSet,
+        result_data: ResultData,
+        row_count: int,
+    ) -> "SeaResultSet":
+        """
+        Create a new filtered SeaResultSet with the provided data.

-        from databricks.sql.backend.sea.backend import SeaDatabricksClient
+        Args:
+            result_set: Original result set to copy parameters from
+            result_data: New result data for the filtered set
+            row_count: Number of rows in the filtered data
+
+        Returns:
+            New filtered SeaResultSet
+        """
         from databricks.sql.backend.sea.result_set import SeaResultSet

-        # Create a new SeaResultSet with the filtered data
-        manifest = result_set.manifest
-        manifest.total_row_count = len(filtered_rows)
+        execute_response = ResultSetFilter._create_execute_response(result_set)
+        filtered_manifest = ResultSetFilter._update_manifest(result_set, row_count)

-        filtered_result_set = SeaResultSet(
+        return SeaResultSet(
             connection=result_set.connection,
             execute_response=execute_response,
             sea_client=cast(SeaDatabricksClient, result_set.backend),
             result_data=result_data,
-            manifest=manifest,
+            manifest=filtered_manifest,
             buffer_size_bytes=result_set.buffer_size_bytes,
             arraysize=result_set.arraysize,
         )

-        # Preserve metadata columns setup from original result set
-        if hasattr(result_set, "_metadata_columns") and result_set._metadata_columns:
-            filtered_result_set._metadata_columns = result_set._metadata_columns
-            filtered_result_set._column_index_mapping = result_set._column_index_mapping
-            # Update the description to match the original prepared description
-            filtered_result_set.description = result_set.description
+    @staticmethod
+    def _filter_arrow_table(
+        table: Any,  # pyarrow.Table
+        column_name: str,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> Any:  # returns pyarrow.Table
+        """
+        Filter a PyArrow table by column values.
+
+        Args:
+            table: The PyArrow table to filter
+            column_name: The name of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered PyArrow table
+        """
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow table filtering")
+
+        if table.num_rows == 0:
+            return table
+
+        # Handle case-insensitive filtering by normalizing both column and allowed values
+        if not case_sensitive:
+            # Convert allowed values to uppercase
+            allowed_values = [v.upper() for v in allowed_values]
+            # Get column values as uppercase
+            column = pc.utf8_upper(table[column_name])
+        else:
+            # Use column as-is
+            column = table[column_name]
+
+        # Convert allowed_values to PyArrow Array
+        allowed_array = pyarrow.array(allowed_values)
+
+        # Construct a boolean mask: True where column is in allowed_list
+        mask = pc.is_in(column, value_set=allowed_array)
+        return table.filter(mask)
+
+    @staticmethod
+    def _filter_arrow_result_set(
+        result_set: SeaResultSet,
+        column_index: int,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> SeaResultSet:
+        """
+        Filter a SEA result set that contains Arrow tables.
+
+        Args:
+            result_set: The SEA result set to filter (containing Arrow data)
+            column_index: The index of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered SEA result set
+        """
+        # Validate column index and get column name
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")
+        column_name = result_set.description[column_index][0]
+
+        # Get all remaining rows as Arrow table and filter it
+        arrow_table = result_set.results.remaining_rows()
+        filtered_table = ResultSetFilter._filter_arrow_table(
+            arrow_table, column_name, allowed_values, case_sensitive
+        )
+
+        # Convert the filtered table to Arrow stream format for ResultData
+        sink = io.BytesIO()
+        with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
+            writer.write_table(filtered_table)
+        arrow_stream_bytes = sink.getvalue()
+
+        # Create ResultData with attachment containing the filtered data
+        result_data = ResultData(
+            data=None,  # No JSON data
+            external_links=None,  # No external links
+            attachment=arrow_stream_bytes,  # Arrow data as attachment
+        )

-        return filtered_result_set
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, filtered_table.num_rows
+        )

     @staticmethod
-    def filter_by_column_values(
+    def _filter_json_result_set(
         result_set: SeaResultSet,
         column_index: int,
         allowed_values: List[str],
@@ -114,22 +216,35 @@ def filter_by_column_values(
         Returns:
             A filtered result set
         """
+        # Validate column index (optional - not in arrow version but good practice)
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")

-        # Convert to uppercase for case-insensitive comparison if needed
+        # Extract rows
+        all_rows = result_set.results.remaining_rows()
+
+        # Convert allowed values if case-insensitive
         if not case_sensitive:
             allowed_values = [v.upper() for v in allowed_values]
+        # Helper lambda to get column value based on case sensitivity
+        get_column_value = (
+            lambda row: row[column_index].upper()
+            if not case_sensitive
+            else row[column_index]
+        )
+
+        # Filter rows based on allowed values
+        filtered_rows = [
+            row
+            for row in all_rows
+            if len(row) > column_index and get_column_value(row) in allowed_values
+        ]
+
+        # Create filtered result set
+        result_data = ResultData(data=filtered_rows, external_links=None)

-        return ResultSetFilter._filter_sea_result_set(
-            result_set,
-            lambda row: (
-                len(row) > column_index
-                and (
-                    row[column_index].upper()
-                    if not case_sensitive
-                    else row[column_index]
-                )
-                in allowed_values
-            ),
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, len(filtered_rows)
         )

     @staticmethod
@@ -150,14 +265,25 @@ def filter_tables_by_type(
         Returns:
             A filtered result set containing only tables of the specified types
         """
-
         # Default table types if none specified
         DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
-        valid_types = (
-            table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
-        )
+        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

+        # Check if we have an Arrow table (cloud fetch) or JSON data
         # Table type is the 6th column (index 5)
-        return ResultSetFilter.filter_by_column_values(
-            result_set, 5, valid_types, case_sensitive=True
-        )
+        if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
+            # For Arrow tables, we need to handle filtering differently
+            return ResultSetFilter._filter_arrow_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
+        else:
+            # For JSON data, use the existing filter method
+            return ResultSetFilter._filter_json_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
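
The new _filter_arrow_table helper builds a boolean mask with pyarrow.compute.is_in and applies it with Table.filter, upper-casing both the column and the allowed values for case-insensitive matches. A self-contained sketch of that technique; the sample table and values are invented for illustration:

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "TABLE_NAME": ["t1", "t2", "t3"],
    "TABLE_TYPE": ["TABLE", "view", "SYSTEM TABLE"],
})
allowed = ["table", "system table"]

# Case-insensitive: normalize both sides to uppercase before matching.
column = pc.utf8_upper(table["TABLE_TYPE"])
value_set = pa.array([v.upper() for v in allowed])
mask = pc.is_in(column, value_set=value_set)  # True where the row should be kept
filtered = table.filter(mask)

print(filtered.column("TABLE_NAME").to_pylist())  # ['t1', 't3']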

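_filter_arrow_result_set then serializes the filtered table into Arrow IPC stream bytes so they can travel in ResultData's attachment field. A round-trip sketch of that serialization; the reader half is an assumption about how such an attachment would be consumed, not code from this commit:

import io
import pyarrow as pa

table = pa.table({"n": [1, 2, 3]})

# Write: Table -> IPC stream bytes (what the filter stores as the attachment).
sink = io.BytesIO()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
attachment = sink.getvalue()

# Read: bytes -> Table (assumed consumer side).
reader = pa.ipc.open_stream(attachment)
restored = reader.read_all()

assert restored.equals(table)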