55"""
66
77import logging
8- from typing import List , Optional , Any , Dict , Callable , TypeVar , Generic , cast
8+ from typing import (
9+ List ,
10+ Optional ,
11+ Any ,
12+ Dict ,
13+ Callable ,
14+ TypeVar ,
15+ Generic ,
16+ cast ,
17+ TYPE_CHECKING ,
18+ )
919
1020# Import SeaResultSet for type checking
1121from databricks .sql .backend .sea_result_set import SeaResultSet
1222
23+ if TYPE_CHECKING :
24+ from databricks .sql .result_set import ResultSet
25+
1326# Type variable for the result set type
1427T = TypeVar ("T" )
1528
@@ -24,9 +37,48 @@ class ResultSetFilter:
2437 similar to the client-side filtering in the JDBC connector.
2538 """
2639
40+ @staticmethod
41+ def _filter_sea_result_set (
42+ result_set : SeaResultSet , filter_func : Callable [[List [Any ]], bool ]
43+ ) -> SeaResultSet :
44+ """
45+ Filter a SEA result set using the provided filter function.
46+
47+ Args:
48+ result_set: The SEA result set to filter
49+ filter_func: Function that takes a row and returns True if the row should be included
50+
51+ Returns:
52+ A filtered SEA result set
53+ """
54+ # Create a filtered version of the result set
55+ filtered_response = result_set ._response .copy ()
56+
57+ # If there's a result with rows, filter them
58+ if (
59+ "result" in filtered_response
60+ and "data_array" in filtered_response ["result" ]
61+ ):
62+ rows = filtered_response ["result" ]["data_array" ]
63+ filtered_rows = [row for row in rows if filter_func (row )]
64+ filtered_response ["result" ]["data_array" ] = filtered_rows
65+
66+ # Update row count if present
67+ if "row_count" in filtered_response ["result" ]:
68+ filtered_response ["result" ]["row_count" ] = len (filtered_rows )
69+
70+ # Create a new result set with the filtered data
71+ return SeaResultSet (
72+ connection = result_set .connection ,
73+ sea_response = filtered_response ,
74+ sea_client = result_set .backend ,
75+ buffer_size_bytes = result_set .buffer_size_bytes ,
76+ arraysize = result_set .arraysize ,
77+ )
78+
2779 @staticmethod
2880 def filter_by_column_values (
29- result_set : Any ,
81+ result_set : "ResultSet" ,
3082 column_index : int ,
3183 allowed_values : List [str ],
3284 case_sensitive : bool = False ,
@@ -71,7 +123,7 @@ def filter_by_column_values(
71123
72124 @staticmethod
73125 def filter_tables_by_type (
74- result_set : Any , table_types : Optional [List [str ]] = None
126+ result_set : "ResultSet" , table_types : Optional [List [str ]] = None
75127 ) -> Any :
76128 """
77129 Filter a result set of tables by the specified table types.
@@ -97,42 +149,3 @@ def filter_tables_by_type(
97149 return ResultSetFilter .filter_by_column_values (
98150 result_set , 3 , valid_types , case_sensitive = False
99151 )
100-
101- @staticmethod
102- def _filter_sea_result_set (
103- result_set : SeaResultSet , filter_func : Callable [[List [Any ]], bool ]
104- ) -> SeaResultSet :
105- """
106- Filter a SEA result set using the provided filter function.
107-
108- Args:
109- result_set: The SEA result set to filter
110- filter_func: Function that takes a row and returns True if the row should be included
111-
112- Returns:
113- A filtered SEA result set
114- """
115- # Create a filtered version of the result set
116- filtered_response = result_set ._response .copy ()
117-
118- # If there's a result with rows, filter them
119- if (
120- "result" in filtered_response
121- and "data_array" in filtered_response ["result" ]
122- ):
123- rows = filtered_response ["result" ]["data_array" ]
124- filtered_rows = [row for row in rows if filter_func (row )]
125- filtered_response ["result" ]["data_array" ] = filtered_rows
126-
127- # Update row count if present
128- if "row_count" in filtered_response ["result" ]:
129- filtered_response ["result" ]["row_count" ] = len (filtered_rows )
130-
131- # Create a new result set with the filtered data
132- return SeaResultSet (
133- connection = result_set .connection ,
134- sea_response = filtered_response ,
135- sea_client = result_set .backend ,
136- buffer_size_bytes = result_set .buffer_size_bytes ,
137- arraysize = result_set .arraysize ,
138- )
0 commit comments