Skip to content

Commit 2871e05

Browse files
stronger typing on ResultSet
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 7fdc01d commit 2871e05

File tree

1 file changed

+55
-42
lines changed

1 file changed

+55
-42
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,24 @@
55
"""
66

77
import logging
8-
from typing import List, Optional, Any, Dict, Callable, TypeVar, Generic, cast
8+
from typing import (
9+
List,
10+
Optional,
11+
Any,
12+
Dict,
13+
Callable,
14+
TypeVar,
15+
Generic,
16+
cast,
17+
TYPE_CHECKING,
18+
)
919

1020
# Import SeaResultSet for type checking
1121
from databricks.sql.backend.sea_result_set import SeaResultSet
1222

23+
if TYPE_CHECKING:
24+
from databricks.sql.result_set import ResultSet
25+
1326
# Type variable for the result set type
1427
T = TypeVar("T")
1528

@@ -24,9 +37,48 @@ class ResultSetFilter:
2437
similar to the client-side filtering in the JDBC connector.
2538
"""
2639

40+
@staticmethod
41+
def _filter_sea_result_set(
42+
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
43+
) -> SeaResultSet:
44+
"""
45+
Filter a SEA result set using the provided filter function.
46+
47+
Args:
48+
result_set: The SEA result set to filter
49+
filter_func: Function that takes a row and returns True if the row should be included
50+
51+
Returns:
52+
A filtered SEA result set
53+
"""
54+
# Create a filtered version of the result set
55+
filtered_response = result_set._response.copy()
56+
57+
# If there's a result with rows, filter them
58+
if (
59+
"result" in filtered_response
60+
and "data_array" in filtered_response["result"]
61+
):
62+
rows = filtered_response["result"]["data_array"]
63+
filtered_rows = [row for row in rows if filter_func(row)]
64+
filtered_response["result"]["data_array"] = filtered_rows
65+
66+
# Update row count if present
67+
if "row_count" in filtered_response["result"]:
68+
filtered_response["result"]["row_count"] = len(filtered_rows)
69+
70+
# Create a new result set with the filtered data
71+
return SeaResultSet(
72+
connection=result_set.connection,
73+
sea_response=filtered_response,
74+
sea_client=result_set.backend,
75+
buffer_size_bytes=result_set.buffer_size_bytes,
76+
arraysize=result_set.arraysize,
77+
)
78+
2779
@staticmethod
2880
def filter_by_column_values(
29-
result_set: Any,
81+
result_set: "ResultSet",
3082
column_index: int,
3183
allowed_values: List[str],
3284
case_sensitive: bool = False,
@@ -71,7 +123,7 @@ def filter_by_column_values(
71123

72124
@staticmethod
73125
def filter_tables_by_type(
74-
result_set: Any, table_types: Optional[List[str]] = None
126+
result_set: "ResultSet", table_types: Optional[List[str]] = None
75127
) -> Any:
76128
"""
77129
Filter a result set of tables by the specified table types.
@@ -97,42 +149,3 @@ def filter_tables_by_type(
97149
return ResultSetFilter.filter_by_column_values(
98150
result_set, 3, valid_types, case_sensitive=False
99151
)
100-
101-
@staticmethod
102-
def _filter_sea_result_set(
103-
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
104-
) -> SeaResultSet:
105-
"""
106-
Filter a SEA result set using the provided filter function.
107-
108-
Args:
109-
result_set: The SEA result set to filter
110-
filter_func: Function that takes a row and returns True if the row should be included
111-
112-
Returns:
113-
A filtered SEA result set
114-
"""
115-
# Create a filtered version of the result set
116-
filtered_response = result_set._response.copy()
117-
118-
# If there's a result with rows, filter them
119-
if (
120-
"result" in filtered_response
121-
and "data_array" in filtered_response["result"]
122-
):
123-
rows = filtered_response["result"]["data_array"]
124-
filtered_rows = [row for row in rows if filter_func(row)]
125-
filtered_response["result"]["data_array"] = filtered_rows
126-
127-
# Update row count if present
128-
if "row_count" in filtered_response["result"]:
129-
filtered_response["result"]["row_count"] = len(filtered_rows)
130-
131-
# Create a new result set with the filtered data
132-
return SeaResultSet(
133-
connection=result_set.connection,
134-
sea_response=filtered_response,
135-
sea_client=result_set.backend,
136-
buffer_size_bytes=result_set.buffer_size_bytes,
137-
arraysize=result_set.arraysize,
138-
)

0 commit comments

Comments
 (0)