Skip to content

Commit 7fdc01d

Browse files
filters and sea_result_set unit tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 9395141 commit 7fdc01d

File tree

5 files changed

+379
-60
lines changed

5 files changed

+379
-60
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 49 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -11,120 +11,128 @@
1111
from databricks.sql.backend.sea_result_set import SeaResultSet
1212

1313
# Type variable for the result set type
14-
T = TypeVar('T')
14+
T = TypeVar("T")
1515

1616
logger = logging.getLogger(__name__)
1717

1818

1919
class ResultSetFilter:
2020
"""
2121
A general-purpose filter for result sets that can be applied to any backend.
22-
22+
2323
This class provides methods to filter result sets based on various criteria,
2424
similar to the client-side filtering in the JDBC connector.
2525
"""
26-
26+
2727
@staticmethod
2828
def filter_by_column_values(
2929
result_set: Any,
3030
column_index: int,
3131
allowed_values: List[str],
32-
case_sensitive: bool = False
32+
case_sensitive: bool = False,
3333
) -> Any:
3434
"""
3535
Filter a result set by values in a specific column.
36-
36+
3737
Args:
3838
result_set: The result set to filter
3939
column_index: The index of the column to filter on
4040
allowed_values: List of allowed values for the column
4141
case_sensitive: Whether to perform case-sensitive comparison
42-
42+
4343
Returns:
4444
A filtered result set
4545
"""
4646
# Convert to uppercase for case-insensitive comparison if needed
4747
if not case_sensitive:
4848
allowed_values = [v.upper() for v in allowed_values]
49-
49+
5050
# Determine the type of result set and apply appropriate filtering
5151
if isinstance(result_set, SeaResultSet):
5252
return ResultSetFilter._filter_sea_result_set(
53-
result_set,
53+
result_set,
5454
lambda row: (
55-
len(row) > column_index and
56-
isinstance(row[column_index], str) and
57-
(row[column_index].upper() if not case_sensitive else row[column_index]) in allowed_values
58-
)
55+
len(row) > column_index
56+
and isinstance(row[column_index], str)
57+
and (
58+
row[column_index].upper()
59+
if not case_sensitive
60+
else row[column_index]
61+
)
62+
in allowed_values
63+
),
5964
)
60-
65+
6166
# For other result set types, return the original (should be handled by specific implementations)
6267
logger.warning(
6368
f"Filtering not implemented for result set type: {type(result_set).__name__}"
6469
)
6570
return result_set
66-
71+
6772
@staticmethod
6873
def filter_tables_by_type(
69-
result_set: Any,
70-
table_types: Optional[List[str]] = None
74+
result_set: Any, table_types: Optional[List[str]] = None
7175
) -> Any:
7276
"""
7377
Filter a result set of tables by the specified table types.
74-
78+
7579
This is a client-side filter that processes the result set after it has been
7680
retrieved from the server. It filters out tables whose type does not match
7781
any of the types in the table_types list.
78-
82+
7983
Args:
8084
result_set: The original result set containing tables
8185
table_types: List of table types to include (e.g., ["TABLE", "VIEW"])
82-
86+
8387
Returns:
8488
A filtered result set containing only tables of the specified types
8589
"""
8690
# Default table types if none specified
8791
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
88-
valid_types = table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
89-
92+
valid_types = (
93+
table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
94+
)
95+
9096
# Table type is typically in the 4th column (index 3)
91-
return ResultSetFilter.filter_by_column_values(result_set, 3, valid_types, case_sensitive=False)
92-
97+
return ResultSetFilter.filter_by_column_values(
98+
result_set, 3, valid_types, case_sensitive=False
99+
)
100+
93101
@staticmethod
94-
def _filter_sea_result_set(result_set: Any, filter_func: Callable[[List[Any]], bool]) -> Any:
102+
def _filter_sea_result_set(
103+
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
104+
) -> SeaResultSet:
95105
"""
96106
Filter a SEA result set using the provided filter function.
97-
107+
98108
Args:
99109
result_set: The SEA result set to filter
100110
filter_func: Function that takes a row and returns True if the row should be included
101-
111+
102112
Returns:
103113
A filtered SEA result set
104114
"""
105-
# Type checking for SeaResultSet
106-
if not isinstance(result_set, SeaResultSet):
107-
return result_set
108-
109115
# Create a filtered version of the result set
110-
sea_result = cast(SeaResultSet, result_set)
111-
filtered_response = sea_result._response.copy() # type: ignore
112-
116+
filtered_response = result_set._response.copy()
117+
113118
# If there's a result with rows, filter them
114-
if "result" in filtered_response and "data_array" in filtered_response["result"]:
119+
if (
120+
"result" in filtered_response
121+
and "data_array" in filtered_response["result"]
122+
):
115123
rows = filtered_response["result"]["data_array"]
116124
filtered_rows = [row for row in rows if filter_func(row)]
117125
filtered_response["result"]["data_array"] = filtered_rows
118-
126+
119127
# Update row count if present
120128
if "row_count" in filtered_response["result"]:
121129
filtered_response["result"]["row_count"] = len(filtered_rows)
122-
130+
123131
# Create a new result set with the filtered data
124132
return SeaResultSet(
125-
connection=sea_result.connection,
133+
connection=result_set.connection,
126134
sea_response=filtered_response,
127-
sea_client=sea_result._sea_client, # type: ignore
128-
buffer_size_bytes=sea_result._buffer_size_bytes, # type: ignore
129-
arraysize=sea_result._arraysize, # type: ignore
130-
)
135+
sea_client=result_set.backend,
136+
buffer_size_bytes=result_set.buffer_size_bytes,
137+
arraysize=result_set.arraysize,
138+
)

src/databricks/sql/backend/sea_backend.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -555,6 +555,7 @@ def get_tables(
555555

556556
# Apply client-side filtering by table_types if specified
557557
from databricks.sql.backend.filters import ResultSetFilter
558+
558559
result = ResultSetFilter.filter_tables_by_type(result, table_types)
559560

560561
return result

src/databricks/sql/backend/sea_result_set.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,12 @@ def __init__(
3939
"""Initialize a SeaResultSet with the response from a SEA query execution."""
4040
super().__init__(connection, sea_client, arraysize, buffer_size_bytes)
4141

42+
# Store the original response for filtering operations
43+
self._response = sea_response
44+
self._sea_client = sea_client
45+
self._buffer_size_bytes = buffer_size_bytes
46+
self._arraysize = arraysize
47+
4248
# Extract and store SEA-specific properties
4349
self.statement_id = sea_response.get("statement_id")
4450

@@ -89,7 +95,7 @@ def __init__(
8995
result_data = sea_response.get("result")
9096
if result_data:
9197
self.result: Optional[ResultData] = ResultData(
92-
data=result_data.get("data"),
98+
data=result_data.get("data_array"), # Changed from data to data_array
9399
external_links=result_data.get("external_links"),
94100
)
95101
else:

tests/unit/test_filters.py

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,10 @@ def setUp(self):
2828
["catalog1", "schema1", "table3", "SYSTEM TABLE", ""],
2929
["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""],
3030
],
31-
"row_count": 4
31+
"row_count": 4,
3232
}
3333
}
34-
34+
3535
# Set up the connection and other required attributes
3636
self.mock_sea_result_set.connection = MagicMock()
3737
self.mock_sea_result_set._sea_client = MagicMock()
@@ -44,24 +44,26 @@ def test_filter_tables_by_type(self, mock_sea_result_set_class):
4444
# Set up the mock to return a new mock when instantiated
4545
mock_instance = MagicMock()
4646
mock_sea_result_set_class.return_value = mock_instance
47-
47+
4848
# Test with specific table types
4949
table_types = ["TABLE", "VIEW"]
50-
50+
5151
# Make the mock_sea_result_set appear to be a SeaResultSet
5252
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
53-
result = ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types)
54-
53+
result = ResultSetFilter.filter_tables_by_type(
54+
self.mock_sea_result_set, table_types
55+
)
56+
5557
# Verify the filter was applied correctly
5658
mock_sea_result_set_class.assert_called_once()
5759
call_kwargs = mock_sea_result_set_class.call_args[1]
58-
60+
5961
# Check that the filtered response contains only TABLE and VIEW
6062
filtered_data = call_kwargs["sea_response"]["result"]["data_array"]
6163
self.assertEqual(len(filtered_data), 2)
6264
self.assertEqual(filtered_data[0][3], "TABLE")
6365
self.assertEqual(filtered_data[1][3], "VIEW")
64-
66+
6567
# Check row count was updated
6668
self.assertEqual(call_kwargs["sea_response"]["result"]["row_count"], 2)
6769

@@ -71,18 +73,20 @@ def test_filter_tables_by_type_case_insensitive(self, mock_sea_result_set_class)
7173
# Set up the mock to return a new mock when instantiated
7274
mock_instance = MagicMock()
7375
mock_sea_result_set_class.return_value = mock_instance
74-
76+
7577
# Test with lowercase table types
7678
table_types = ["table", "view"]
77-
79+
7880
# Make the mock_sea_result_set appear to be a SeaResultSet
7981
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
80-
result = ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, table_types)
81-
82+
result = ResultSetFilter.filter_tables_by_type(
83+
self.mock_sea_result_set, table_types
84+
)
85+
8286
# Verify the filter was applied correctly
8387
mock_sea_result_set_class.assert_called_once()
8488
call_kwargs = mock_sea_result_set_class.call_args[1]
85-
89+
8690
# Check that the filtered response contains only TABLE and VIEW
8791
filtered_data = call_kwargs["sea_response"]["result"]["data_array"]
8892
self.assertEqual(len(filtered_data), 2)
@@ -95,15 +99,17 @@ def test_filter_tables_by_type_default(self, mock_sea_result_set_class):
9599
# Set up the mock to return a new mock when instantiated
96100
mock_instance = MagicMock()
97101
mock_sea_result_set_class.return_value = mock_instance
98-
102+
99103
# Make the mock_sea_result_set appear to be a SeaResultSet
100104
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
101-
result = ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
102-
105+
result = ResultSetFilter.filter_tables_by_type(
106+
self.mock_sea_result_set, None
107+
)
108+
103109
# Verify the filter was applied correctly
104110
mock_sea_result_set_class.assert_called_once()
105111
call_kwargs = mock_sea_result_set_class.call_args[1]
106-
112+
107113
# Check that the filtered response contains TABLE, VIEW, and SYSTEM TABLE
108114
filtered_data = call_kwargs["sea_response"]["result"]["data_array"]
109115
self.assertEqual(len(filtered_data), 3)
@@ -113,4 +119,4 @@ def test_filter_tables_by_type_default(self, mock_sea_result_set_class):
113119

114120

115121
if __name__ == "__main__":
116-
unittest.main()
122+
unittest.main()

0 commit comments

Comments
 (0)