1111from databricks .sql .backend .sea_result_set import SeaResultSet
1212
# Type variable for the result set type
T = TypeVar("T")

# Module-level logger, named after this module
logger = logging.getLogger(__name__)
1717
1818
class ResultSetFilter:
    """
    A general-purpose filter for result sets that can be applied to any backend.

    This class provides methods to filter result sets based on various criteria,
    similar to the client-side filtering in the JDBC connector.

    Only ``SeaResultSet`` instances are actually filtered; any other result set
    type is returned unchanged after a warning is logged.
    """

    @staticmethod
    def filter_by_column_values(
        result_set: Any,
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool = False,
    ) -> Any:
        """
        Filter a result set by values in a specific column.

        Args:
            result_set: The result set to filter
            column_index: The index of the column to filter on
            allowed_values: List of allowed values for the column
            case_sensitive: Whether to perform case-sensitive comparison

        Returns:
            A new filtered ``SeaResultSet`` when ``result_set`` is a
            ``SeaResultSet``; otherwise ``result_set`` itself, unfiltered.
        """
        # Build the predicate first so the allowed values are normalised
        # exactly once, regardless of how many rows are later inspected.
        row_matches = ResultSetFilter._build_column_predicate(
            column_index, allowed_values, case_sensitive
        )

        # Determine the type of result set and apply appropriate filtering
        if isinstance(result_set, SeaResultSet):
            return ResultSetFilter._filter_sea_result_set(result_set, row_matches)

        # For other result set types, return the original (should be handled by
        # specific implementations)
        logger.warning(
            f"Filtering not implemented for result set type: {type(result_set).__name__} "
        )
        return result_set

    @staticmethod
    def filter_tables_by_type(
        result_set: Any, table_types: Optional[List[str]] = None
    ) -> Any:
        """
        Filter a result set of tables by the specified table types.

        This is a client-side filter that processes the result set after it has
        been retrieved from the server. It filters out tables whose type does not
        match any of the types in the table_types list.

        Args:
            result_set: The original result set containing tables
            table_types: List of table types to include (e.g., ["TABLE", "VIEW"]).
                ``None`` or an empty list selects the default types.

        Returns:
            A filtered result set containing only tables of the specified types
        """
        # Default table types if none specified
        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
        valid_types = table_types or DEFAULT_TABLE_TYPES

        # Table type is typically in the 4th column (index 3)
        return ResultSetFilter.filter_by_column_values(
            result_set, 3, valid_types, case_sensitive=False
        )

    @staticmethod
    def _build_column_predicate(
        column_index: int, allowed_values: List[str], case_sensitive: bool
    ) -> Callable[[List[Any]], bool]:
        """
        Build a row predicate that checks one column against a set of values.

        Args:
            column_index: The index of the column to inspect in each row
            allowed_values: Values the column may take for the row to be kept
            case_sensitive: When False, both sides are upper-cased before comparing

        Returns:
            A function that takes a row and returns True only if the row is long
            enough, the cell is a string, and the cell is one of the allowed values.
        """
        # Normalise once up front; a frozenset gives constant-time membership.
        if case_sensitive:
            allowed = frozenset(allowed_values)
        else:
            allowed = frozenset(value.upper() for value in allowed_values)

        def row_matches(row: List[Any]) -> bool:
            # Rows too short to contain the column can never match.
            if len(row) <= column_index:
                return False
            cell = row[column_index]
            # Non-string cells (e.g. None) are excluded rather than coerced.
            if not isinstance(cell, str):
                return False
            return (cell if case_sensitive else cell.upper()) in allowed

        return row_matches

    @staticmethod
    def _filter_response_rows(
        response: dict, filter_func: Callable[[List[Any]], bool]
    ) -> dict:
        """
        Return a copy of a response dict with its data rows filtered.

        The input is never modified: the top-level dict and the nested
        ``"result"`` dict are both copied before anything is written.

        Args:
            response: Response dict, optionally holding ``response["result"]["data_array"]``
            filter_func: Function that takes a row and returns True if the row should be included

        Returns:
            A new dict. If there is no ``"result"`` / ``"data_array"`` (or either
            is None) the copy is returned with its contents unfiltered.
        """
        filtered_response = response.copy()

        result = filtered_response.get("result")
        if result is None or "data_array" not in result:
            return filtered_response

        rows = result["data_array"]
        if rows is None:
            # Nothing to filter; leave the payload as it was received.
            return filtered_response

        # The top-level copy above is shallow, so the nested dict is still shared
        # with the caller's response. Copy it before writing to avoid mutating
        # the original result set's data.
        filtered_result = dict(result)
        filtered_rows = [row for row in rows if filter_func(row)]
        filtered_result["data_array"] = filtered_rows

        # Update row count if present
        if "row_count" in filtered_result:
            filtered_result["row_count"] = len(filtered_rows)

        filtered_response["result"] = filtered_result
        return filtered_response

    @staticmethod
    def _filter_sea_result_set(
        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
    ) -> "SeaResultSet":
        """
        Filter a SEA result set using the provided filter function.

        Args:
            result_set: The SEA result set to filter
            filter_func: Function that takes a row and returns True if the row should be included

        Returns:
            A new SEA result set built from a filtered copy of the response; the
            response held by ``result_set`` is left unmodified.
        """
        # Create a filtered version of the response without touching the original
        filtered_response = ResultSetFilter._filter_response_rows(
            result_set._response, filter_func
        )

        # Create a new result set with the filtered data
        return SeaResultSet(
            connection=result_set.connection,
            sea_response=filtered_response,
            sea_client=result_set.backend,
            buffer_size_bytes=result_set.buffer_size_bytes,
            arraysize=result_set.arraysize,
        )
0 commit comments