1+ """
2+ Tests for the ResultSetFilter class.
3+
4+ This module contains tests for the ResultSetFilter class, which provides
5+ filtering capabilities for result sets returned by different backends.
6+ """
7+
8+ import pytest
9+ from unittest .mock import patch , MagicMock , Mock
10+
11+ from databricks .sql .backend .filters import ResultSetFilter
12+ from databricks .sql .result_set import SeaResultSet
13+ from databricks .sql .backend .types import CommandId , CommandState , BackendType
14+
15+
class TestResultSetFilter:
    """Exercise ResultSetFilter against a SeaResultSet built from a canned SEA payload.

    The payload holds four rows in schema ``schema1`` / catalog ``catalog1``
    whose ``tableType`` values are TABLE, VIEW, SYSTEM TABLE and EXTERNAL.
    Each test inspects the filtered rows through
    ``_response["result"]["data_array"]`` on the object the filter returns.
    """

    # Positions of the columns the assertions read, per the fixture schema.
    _TABLE_NAME_COL = 1
    _TABLE_TYPE_COL = 5

    @staticmethod
    def _build_result_set(connection, sea_client, response):
        """Return a new SeaResultSet wrapping a private deep copy of ``response``.

        The deep copy keeps the shared fixture dict from being touched by
        whatever the result set or the filter does with its payload.
        """
        import copy

        return SeaResultSet(
            connection=connection,
            sea_client=sea_client,
            sea_response=copy.deepcopy(response),
            buffer_size_bytes=1000,
            arraysize=100,
        )

    @staticmethod
    def _rows(result_set):
        """Return the raw row list stored in a result set's SEA response."""
        return result_set._response["result"]["data_array"]

    @pytest.fixture
    def mock_connection(self):
        """Provide a mock connection whose ``open`` attribute is True."""
        return Mock(open=True)

    @pytest.fixture
    def mock_sea_client(self):
        """Provide a bare mock standing in for the SEA client."""
        return Mock()

    @pytest.fixture
    def sea_response_with_tables(self):
        """Provide a SEA response dict describing four tables of distinct types."""
        column_specs = (
            ("namespace", "STRING"),
            ("tableName", "STRING"),
            ("isTemporary", "BOOLEAN"),
            ("information", "STRING"),
            ("catalogName", "STRING"),
            ("tableType", "STRING"),
            ("remarks", "STRING"),
        )
        table_specs = (
            ("table1", "TABLE"),
            ("table2", "VIEW"),
            ("table3", "SYSTEM TABLE"),
            ("table4", "EXTERNAL"),
        )

        columns = [
            {"name": name, "type_text": type_text, "position": position}
            for position, (name, type_text) in enumerate(column_specs)
        ]
        data_array = [
            ["schema1", table_name, False, None, "catalog1", table_type, None]
            for table_name, table_type in table_specs
        ]

        return {
            "statement_id": "test-statement-123",
            "status": {"state": "SUCCEEDED"},
            "manifest": {
                "format": "JSON_ARRAY",
                "schema": {
                    "column_count": len(columns),
                    "columns": columns,
                },
                "total_row_count": len(data_array),
            },
            "result": {"data_array": data_array},
        }

    @pytest.fixture
    def sea_result_set_with_tables(
        self, mock_connection, mock_sea_client, sea_response_with_tables
    ):
        """Provide a SeaResultSet backed by its own copy of the table response."""
        return self._build_result_set(
            mock_connection, mock_sea_client, sea_response_with_tables
        )

    def test_filter_tables_by_type_default(self, sea_result_set_with_tables):
        """With no explicit types, TABLE, VIEW and SYSTEM TABLE rows are kept."""
        outcome = ResultSetFilter.filter_tables_by_type(sea_result_set_with_tables)

        # EXTERNAL is the only row expected to be dropped.
        assert len(self._rows(outcome)) == 3
        kinds = [row[self._TABLE_TYPE_COL] for row in self._rows(outcome)]
        assert all(kind in kinds for kind in ("TABLE", "VIEW", "SYSTEM TABLE"))
        assert "EXTERNAL" not in kinds

    def test_filter_tables_by_type_custom(self, sea_result_set_with_tables):
        """An explicit type list keeps exactly the requested table types."""
        wanted = ["TABLE", "EXTERNAL"]
        outcome = ResultSetFilter.filter_tables_by_type(
            sea_result_set_with_tables, table_types=wanted
        )

        assert len(self._rows(outcome)) == 2
        kinds = [row[self._TABLE_TYPE_COL] for row in self._rows(outcome)]
        assert all(kind in kinds for kind in ("TABLE", "EXTERNAL"))
        assert not any(kind in kinds for kind in ("VIEW", "SYSTEM TABLE"))

    def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables):
        """Lowercase requested types still match the uppercase stored values."""
        wanted = ["table", "view"]
        outcome = ResultSetFilter.filter_tables_by_type(
            sea_result_set_with_tables, table_types=wanted
        )

        assert len(self._rows(outcome)) == 2
        kinds = [row[self._TABLE_TYPE_COL] for row in self._rows(outcome)]
        assert all(kind in kinds for kind in ("TABLE", "VIEW"))
        assert not any(kind in kinds for kind in ("SYSTEM TABLE", "EXTERNAL"))

    def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables):
        """An empty type list is expected to behave like the default type set."""
        outcome = ResultSetFilter.filter_tables_by_type(
            sea_result_set_with_tables, table_types=[]
        )

        assert len(self._rows(outcome)) == 3
        kinds = [row[self._TABLE_TYPE_COL] for row in self._rows(outcome)]
        assert all(kind in kinds for kind in ("TABLE", "VIEW", "SYSTEM TABLE"))
        assert "EXTERNAL" not in kinds

    def test_filter_by_column_values(self, sea_result_set_with_tables):
        """Rows are kept when the chosen column holds one of the allowed values."""
        # Column 0 (namespace) is "schema1" on every row, so nothing is dropped.
        by_namespace = ResultSetFilter.filter_by_column_values(
            sea_result_set_with_tables, column_index=0, allowed_values=["schema1"]
        )
        assert len(self._rows(by_namespace)) == 4

        # Column 1 (tableName): only the two named tables should survive.
        by_name = ResultSetFilter.filter_by_column_values(
            sea_result_set_with_tables,
            column_index=1,
            allowed_values=["table1", "table3"],
        )
        assert len(self._rows(by_name)) == 2
        names = [row[self._TABLE_NAME_COL] for row in self._rows(by_name)]
        assert all(name in names for name in ("table1", "table3"))
        assert not any(name in names for name in ("table2", "table4"))

    def test_filter_by_column_values_case_sensitive(
        self, mock_connection, mock_sea_client, sea_response_with_tables
    ):
        """With case_sensitive=True only exact-case values match."""
        # Each scenario gets its own result set so the two cannot interfere.
        lowercase_source = self._build_result_set(
            mock_connection, mock_sea_client, sea_response_with_tables
        )
        lowercase_outcome = ResultSetFilter.filter_by_column_values(
            lowercase_source,
            column_index=5,  # tableType column
            allowed_values=["table", "view"],  # lowercase: stored values are uppercase
            case_sensitive=True,
        )
        assert len(self._rows(lowercase_outcome)) == 0

        exact_source = self._build_result_set(
            mock_connection, mock_sea_client, sea_response_with_tables
        )
        exact_outcome = ResultSetFilter.filter_by_column_values(
            exact_source,
            column_index=5,  # tableType column
            allowed_values=["TABLE", "VIEW"],  # same case as the stored values
            case_sensitive=True,
        )
        assert len(self._rows(exact_outcome)) == 2

        kinds = [row[self._TABLE_TYPE_COL] for row in self._rows(exact_outcome)]
        assert all(kind in kinds for kind in ("TABLE", "VIEW"))

    def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables):
        """A column index beyond the row width yields an empty result."""
        # Rows have 7 columns, so index 10 does not exist.
        outcome = ResultSetFilter.filter_by_column_values(
            sea_result_set_with_tables, column_index=10, allowed_values=["value"]
        )

        assert len(self._rows(outcome)) == 0
0 commit comments