Skip to content

Commit d01f3e4

Browse files
add unit tests for SeaResultSet and ResultSetFilter
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 0d51d1b commit d01f3e4

File tree

3 files changed

+458
-6
lines changed

3 files changed

+458
-6
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
List,
1010
Optional,
1111
Any,
12-
Dict,
1312
Callable,
14-
TypeVar,
15-
Generic,
16-
cast,
1713
TYPE_CHECKING,
1814
)
1915

@@ -141,7 +137,7 @@ def filter_tables_by_type(
141137
table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
142138
)
143139

144-
# Table type is typically in the 4th column (index 3)
140+
# Table type is typically in the 6th column (index 5)
145141
return ResultSetFilter.filter_by_column_values(
146-
result_set, 3, valid_types, case_sensitive=False
142+
result_set, 5, valid_types, case_sensitive=False
147143
)
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
Tests for the ResultSetFilter class.
3+
4+
This module contains tests for the ResultSetFilter class, which provides
5+
filtering capabilities for result sets returned by different backends.
6+
"""
7+
8+
import pytest
9+
from unittest.mock import patch, MagicMock, Mock
10+
11+
from databricks.sql.backend.filters import ResultSetFilter
12+
from databricks.sql.result_set import SeaResultSet
13+
from databricks.sql.backend.types import CommandId, CommandState, BackendType
14+
15+
16+
class TestResultSetFilter:
17+
"""Test suite for the ResultSetFilter class."""
18+
19+
@pytest.fixture
20+
def mock_connection(self):
21+
"""Create a mock connection."""
22+
connection = Mock()
23+
connection.open = True
24+
return connection
25+
26+
@pytest.fixture
27+
def mock_sea_client(self):
28+
"""Create a mock SEA client."""
29+
return Mock()
30+
31+
@pytest.fixture
32+
def sea_response_with_tables(self):
33+
"""Create a sample SEA response with table data based on the server schema."""
34+
return {
35+
"statement_id": "test-statement-123",
36+
"status": {"state": "SUCCEEDED"},
37+
"manifest": {
38+
"format": "JSON_ARRAY",
39+
"schema": {
40+
"column_count": 7,
41+
"columns": [
42+
{"name": "namespace", "type_text": "STRING", "position": 0},
43+
{"name": "tableName", "type_text": "STRING", "position": 1},
44+
{"name": "isTemporary", "type_text": "BOOLEAN", "position": 2},
45+
{"name": "information", "type_text": "STRING", "position": 3},
46+
{"name": "catalogName", "type_text": "STRING", "position": 4},
47+
{"name": "tableType", "type_text": "STRING", "position": 5},
48+
{"name": "remarks", "type_text": "STRING", "position": 6},
49+
],
50+
},
51+
"total_row_count": 4,
52+
},
53+
"result": {
54+
"data_array": [
55+
["schema1", "table1", False, None, "catalog1", "TABLE", None],
56+
["schema1", "table2", False, None, "catalog1", "VIEW", None],
57+
["schema1", "table3", False, None, "catalog1", "SYSTEM TABLE", None],
58+
["schema1", "table4", False, None, "catalog1", "EXTERNAL", None],
59+
]
60+
},
61+
}
62+
63+
@pytest.fixture
64+
def sea_result_set_with_tables(self, mock_connection, mock_sea_client, sea_response_with_tables):
65+
"""Create a SeaResultSet with table data."""
66+
# Create a deep copy of the response to avoid test interference
67+
import copy
68+
sea_response_copy = copy.deepcopy(sea_response_with_tables)
69+
70+
return SeaResultSet(
71+
connection=mock_connection,
72+
sea_client=mock_sea_client,
73+
sea_response=sea_response_copy,
74+
buffer_size_bytes=1000,
75+
arraysize=100,
76+
)
77+
78+
def test_filter_tables_by_type_default(self, sea_result_set_with_tables):
79+
"""Test filtering tables by type with default types."""
80+
# Default types are TABLE, VIEW, SYSTEM TABLE
81+
filtered_result = ResultSetFilter.filter_tables_by_type(sea_result_set_with_tables)
82+
83+
# Verify that only the default types are included
84+
assert len(filtered_result._response["result"]["data_array"]) == 3
85+
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
86+
assert "TABLE" in table_types
87+
assert "VIEW" in table_types
88+
assert "SYSTEM TABLE" in table_types
89+
assert "EXTERNAL" not in table_types
90+
91+
def test_filter_tables_by_type_custom(self, sea_result_set_with_tables):
92+
"""Test filtering tables by type with custom types."""
93+
# Filter for only TABLE and EXTERNAL
94+
filtered_result = ResultSetFilter.filter_tables_by_type(
95+
sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"]
96+
)
97+
98+
# Verify that only the specified types are included
99+
assert len(filtered_result._response["result"]["data_array"]) == 2
100+
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
101+
assert "TABLE" in table_types
102+
assert "EXTERNAL" in table_types
103+
assert "VIEW" not in table_types
104+
assert "SYSTEM TABLE" not in table_types
105+
106+
def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables):
107+
"""Test that table type filtering is case-insensitive."""
108+
# Filter for lowercase "table" and "view"
109+
filtered_result = ResultSetFilter.filter_tables_by_type(
110+
sea_result_set_with_tables, table_types=["table", "view"]
111+
)
112+
113+
# Verify that the matching types are included despite case differences
114+
assert len(filtered_result._response["result"]["data_array"]) == 2
115+
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
116+
assert "TABLE" in table_types
117+
assert "VIEW" in table_types
118+
assert "SYSTEM TABLE" not in table_types
119+
assert "EXTERNAL" not in table_types
120+
121+
def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables):
122+
"""Test filtering tables with an empty type list (should use defaults)."""
123+
filtered_result = ResultSetFilter.filter_tables_by_type(
124+
sea_result_set_with_tables, table_types=[]
125+
)
126+
127+
# Verify that default types are used
128+
assert len(filtered_result._response["result"]["data_array"]) == 3
129+
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
130+
assert "TABLE" in table_types
131+
assert "VIEW" in table_types
132+
assert "SYSTEM TABLE" in table_types
133+
assert "EXTERNAL" not in table_types
134+
135+
def test_filter_by_column_values(self, sea_result_set_with_tables):
136+
"""Test filtering by values in a specific column."""
137+
# Filter by namespace in column index 0
138+
filtered_result = ResultSetFilter.filter_by_column_values(
139+
sea_result_set_with_tables, column_index=0, allowed_values=["schema1"]
140+
)
141+
142+
# All rows have schema1 in namespace, so all should be included
143+
assert len(filtered_result._response["result"]["data_array"]) == 4
144+
145+
# Filter by table name in column index 1
146+
filtered_result = ResultSetFilter.filter_by_column_values(
147+
sea_result_set_with_tables, column_index=1, allowed_values=["table1", "table3"]
148+
)
149+
150+
# Only rows with table1 or table3 should be included
151+
assert len(filtered_result._response["result"]["data_array"]) == 2
152+
table_names = [row[1] for row in filtered_result._response["result"]["data_array"]]
153+
assert "table1" in table_names
154+
assert "table3" in table_names
155+
assert "table2" not in table_names
156+
assert "table4" not in table_names
157+
158+
def test_filter_by_column_values_case_sensitive(self, mock_connection, mock_sea_client, sea_response_with_tables):
159+
"""Test case-sensitive filtering by column values."""
160+
import copy
161+
162+
# Create a fresh result set for the first test
163+
sea_response_copy1 = copy.deepcopy(sea_response_with_tables)
164+
result_set1 = SeaResultSet(
165+
connection=mock_connection,
166+
sea_client=mock_sea_client,
167+
sea_response=sea_response_copy1,
168+
buffer_size_bytes=1000,
169+
arraysize=100,
170+
)
171+
172+
# First test: Case-sensitive filtering with lowercase values (should find no matches)
173+
filtered_result = ResultSetFilter.filter_by_column_values(
174+
result_set1,
175+
column_index=5, # tableType column
176+
allowed_values=["table", "view"], # lowercase
177+
case_sensitive=True,
178+
)
179+
180+
# Verify no matches with lowercase values
181+
assert len(filtered_result._response["result"]["data_array"]) == 0
182+
183+
# Create a fresh result set for the second test
184+
sea_response_copy2 = copy.deepcopy(sea_response_with_tables)
185+
result_set2 = SeaResultSet(
186+
connection=mock_connection,
187+
sea_client=mock_sea_client,
188+
sea_response=sea_response_copy2,
189+
buffer_size_bytes=1000,
190+
arraysize=100,
191+
)
192+
193+
# Second test: Case-sensitive filtering with correct case (should find matches)
194+
filtered_result = ResultSetFilter.filter_by_column_values(
195+
result_set2,
196+
column_index=5, # tableType column
197+
allowed_values=["TABLE", "VIEW"], # correct case
198+
case_sensitive=True,
199+
)
200+
201+
# Verify matches with correct case
202+
assert len(filtered_result._response["result"]["data_array"]) == 2
203+
204+
# Extract the table types from the filtered results
205+
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
206+
assert "TABLE" in table_types
207+
assert "VIEW" in table_types
208+
209+
def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables):
210+
"""Test filtering with a column index that's out of bounds."""
211+
# Filter by column index 10 (out of bounds)
212+
filtered_result = ResultSetFilter.filter_by_column_values(
213+
sea_result_set_with_tables, column_index=10, allowed_values=["value"]
214+
)
215+
216+
# No rows should match
217+
assert len(filtered_result._response["result"]["data_array"]) == 0

0 commit comments

Comments
 (0)