Skip to content

Commit 9395141

Browse files
preliminary table filtering
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent edd4c87 commit 9395141

File tree

3 files changed

+248
-67
lines changed

3 files changed

+248
-67
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Client-side filtering utilities for Databricks SQL connector.
3+
4+
This module provides filtering capabilities for result sets returned by different backends.
5+
"""
6+
7+
import logging
8+
from typing import List, Optional, Any, Dict, Callable, TypeVar, Generic, cast
9+
10+
# Import SeaResultSet for type checking
11+
from databricks.sql.backend.sea_result_set import SeaResultSet
12+
13+
# Type variable for the result set type
14+
T = TypeVar('T')
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class ResultSetFilter:
    """
    Client-side filters for result sets returned by a backend.

    All methods are static; the class only acts as a namespace. Only
    ``SeaResultSet`` inputs are actually filtered. Any other result set type
    is returned unchanged after a warning has been logged.

    Filtering never modifies the result set (or the response dictionary) that
    was passed in: a new ``SeaResultSet`` is built around a filtered copy.
    """

    # Table types kept by ``filter_tables_by_type`` when the caller gives none.
    DEFAULT_TABLE_TYPES = ("TABLE", "VIEW", "SYSTEM TABLE")

    # Table type is typically in the 4th column (index 3) of a tables result.
    TABLE_TYPE_COLUMN_INDEX = 3

    @staticmethod
    def filter_by_column_values(
        result_set: Any,
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool = False,
    ) -> Any:
        """
        Filter a result set by the values found in one column.

        A row is kept only when it has a column at ``column_index``, the value
        in that column is a string, and that string is one of
        ``allowed_values``.

        Args:
            result_set: The result set to filter
            column_index: The index of the column to filter on
            allowed_values: List of allowed values for the column
            case_sensitive: Whether to perform case-sensitive comparison

        Returns:
            A new, filtered ``SeaResultSet`` when ``result_set`` is a
            ``SeaResultSet``; otherwise ``result_set`` itself, unchanged.
        """
        if isinstance(result_set, SeaResultSet):
            return ResultSetFilter._filter_sea_result_set(
                result_set,
                ResultSetFilter._build_row_predicate(
                    column_index, allowed_values, case_sensitive
                ),
            )

        # For other result set types, return the original (should be handled
        # by specific implementations).
        logger.warning(
            "Filtering not implemented for result set type: %s",
            type(result_set).__name__,
        )
        return result_set

    @staticmethod
    def filter_tables_by_type(
        result_set: Any,
        table_types: Optional[List[str]] = None,
    ) -> Any:
        """
        Filter a result set of tables by the specified table types.

        This is a client-side filter that processes the result set after it
        has been retrieved from the server. Rows whose table type does not
        match any entry of ``table_types`` (compared case-insensitively) are
        dropped.

        Args:
            result_set: The original result set containing tables
            table_types: List of table types to include (e.g.,
                ["TABLE", "VIEW"]). ``None`` or an empty list selects
                ``DEFAULT_TABLE_TYPES``.

        Returns:
            A filtered result set containing only tables of the specified types
        """
        valid_types = list(table_types or ResultSetFilter.DEFAULT_TABLE_TYPES)
        return ResultSetFilter.filter_by_column_values(
            result_set,
            ResultSetFilter.TABLE_TYPE_COLUMN_INDEX,
            valid_types,
            case_sensitive=False,
        )

    @staticmethod
    def _build_row_predicate(
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool,
    ) -> Callable[[List[Any]], bool]:
        """Return a function telling whether a row's column holds an allowed value."""
        # Normalise the allowed values once, up front, instead of once per row;
        # a frozenset also makes each membership test constant-time.
        if case_sensitive:
            accepted = frozenset(allowed_values)
        else:
            accepted = frozenset(value.upper() for value in allowed_values)

        def row_matches(row: List[Any]) -> bool:
            # Rows that are too short to have the column never match.
            if len(row) <= column_index:
                return False
            value = row[column_index]
            # Non-string cells (e.g. None) never match.
            if not isinstance(value, str):
                return False
            return (value if case_sensitive else value.upper()) in accepted

        return row_matches

    @staticmethod
    def _filter_response_rows(
        response: Dict[str, Any],
        filter_func: Callable[[List[Any]], bool],
    ) -> Dict[str, Any]:
        """Return a copy of ``response`` keeping only the rows accepted by ``filter_func``."""
        filtered_response = response.copy()

        result_section = filtered_response.get("result")
        if not isinstance(result_section, dict) or "data_array" not in result_section:
            # Nothing to filter.
            return filtered_response

        rows = result_section["data_array"]
        if rows is None:
            return filtered_response

        filtered_rows = [row for row in rows if filter_func(row)]

        # ``response.copy()`` is shallow, so the nested "result" dictionary is
        # still shared with the caller. Copy it before writing to it so the
        # original response keeps all of its rows.
        new_result = result_section.copy()
        new_result["data_array"] = filtered_rows

        # Update row count if present
        if "row_count" in new_result:
            new_result["row_count"] = len(filtered_rows)

        filtered_response["result"] = new_result
        return filtered_response

    @staticmethod
    def _filter_sea_result_set(
        result_set: Any, filter_func: Callable[[List[Any]], bool]
    ) -> Any:
        """
        Filter a SEA result set using the provided filter function.

        Args:
            result_set: The SEA result set to filter
            filter_func: Function that takes a row and returns True if the row
                should be included

        Returns:
            A new ``SeaResultSet`` built from a filtered copy of the response,
            reusing the connection, client, buffer size and array size of
            ``result_set``. Anything that is not a ``SeaResultSet`` is
            returned unchanged.
        """
        if not isinstance(result_set, SeaResultSet):
            return result_set

        filtered_response = ResultSetFilter._filter_response_rows(
            result_set._response, filter_func  # type: ignore
        )

        # Create a new result set with the filtered data
        return SeaResultSet(
            connection=result_set.connection,
            sea_response=filtered_response,
            sea_client=result_set._sea_client,  # type: ignore
            buffer_size_bytes=result_set._buffer_size_bytes,  # type: ignore
            arraysize=result_set._arraysize,  # type: ignore
        )

src/databricks/sql/backend/sea_backend.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -554,76 +554,11 @@ def get_tables(
554554
assert result is not None, "execute_command returned None in synchronous mode"
555555

556556
# Apply client-side filtering by table_types (default table types are used when none are specified)
557-
if table_types and len(table_types) > 0:
558-
result = self._filter_tables_by_type(result, table_types)
557+
from databricks.sql.backend.filters import ResultSetFilter
558+
result = ResultSetFilter.filter_tables_by_type(result, table_types)
559559

560560
return result
561561

562-
def _filter_tables_by_type(
563-
self, result_set: "ResultSet", table_types: List[str]
564-
) -> "ResultSet":
565-
"""
566-
Filter a result set of tables by the specified table types.
567-
568-
This is a client-side filter that processes the result set after it has been
569-
retrieved from the server. It filters out tables whose type does not match
570-
any of the types in the table_types list.
571-
572-
Args:
573-
result_set: The original result set containing tables
574-
table_types: List of table types to include (e.g., ["TABLE", "VIEW"])
575-
576-
Returns:
577-
A filtered result set containing only tables of the specified types
578-
"""
579-
# Default table types if none specified
580-
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
581-
valid_types = table_types if table_types else DEFAULT_TABLE_TYPES
582-
583-
# Convert to uppercase for case-insensitive comparison
584-
valid_types_upper = [t.upper() for t in valid_types]
585-
586-
# Create a filtered version of the result set
587-
from databricks.sql.backend.sea_result_set import SeaResultSet
588-
589-
if isinstance(result_set, SeaResultSet):
590-
# For SEA result sets, we need to filter the rows in the response
591-
filtered_response = result_set._response.copy()
592-
593-
# If there's a result with rows, filter them
594-
if (
595-
"result" in filtered_response
596-
and "data_array" in filtered_response["result"]
597-
):
598-
rows = filtered_response["result"]["data_array"]
599-
# Table type is typically in the 4th column (index 3)
600-
filtered_rows = [
601-
row
602-
for row in rows
603-
if len(row) > 3
604-
and (
605-
isinstance(row[3], str) and row[3].upper() in valid_types_upper
606-
)
607-
]
608-
filtered_response["result"]["data_array"] = filtered_rows
609-
610-
# Update row count if present
611-
if "row_count" in filtered_response["result"]:
612-
filtered_response["result"]["row_count"] = len(filtered_rows)
613-
614-
# Create a new result set with the filtered data
615-
return SeaResultSet(
616-
connection=result_set.connection,
617-
sea_response=filtered_response,
618-
sea_client=result_set._sea_client,
619-
buffer_size_bytes=result_set._buffer_size_bytes,
620-
arraysize=result_set._arraysize,
621-
)
622-
623-
# For other result set types, return the original (should not happen with SEA)
624-
logger.warning("Table type filtering not implemented for this result set type")
625-
return result_set
626-
627562
def get_columns(
628563
self,
629564
session_id: SessionId,

tests/unit/test_filters.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Tests for the ResultSetFilter class.
3+
"""
4+
5+
import unittest
6+
from unittest.mock import MagicMock, patch
7+
import sys
8+
from typing import List, Dict, Any
9+
10+
# Make the repository's ``src`` directory importable so the filter module can
# be loaded. The location is derived from this file's own path
# (tests/unit/ -> repository root -> src) rather than from a developer-specific
# absolute path, so the tests run on any checkout.
import os

sys.path.append(
    os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "src")
    )
)
12+
13+
from databricks.sql.backend.filters import ResultSetFilter
14+
15+
16+
class TestResultSetFilter(unittest.TestCase):
    """Tests for the ResultSetFilter class."""

    def setUp(self):
        """Build a stand-in SEA result set holding one row per table type."""
        table_rows = [
            ["catalog1", "schema1", "table1", "TABLE", ""],
            ["catalog1", "schema1", "table2", "VIEW", ""],
            ["catalog1", "schema1", "table3", "SYSTEM TABLE", ""],
            ["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""],
        ]

        # Create a mock SeaResultSet carrying the response and the attributes
        # that are read when the filtered result set is constructed.
        fake_result_set = MagicMock()
        fake_result_set._response = {
            "result": {"data_array": table_rows, "row_count": 4}
        }
        fake_result_set.connection = MagicMock()
        fake_result_set._sea_client = MagicMock()
        fake_result_set._buffer_size_bytes = 1000
        fake_result_set._arraysize = 100
        self.mock_sea_result_set = fake_result_set

    def _filter_as_sea_result_set(self, table_types):
        """Run filter_tables_by_type with the mock treated as a SeaResultSet."""
        # Make the mock_sea_result_set appear to be a SeaResultSet
        with patch("databricks.sql.backend.filters.isinstance", return_value=True):
            return ResultSetFilter.filter_tables_by_type(
                self.mock_sea_result_set, table_types
            )

    def _captured_result(self, result_set_class):
        """Return the "result" section of the response passed to the constructor."""
        # Verify the filter was applied correctly
        result_set_class.assert_called_once()
        constructor_kwargs = result_set_class.call_args[1]
        return constructor_kwargs["sea_response"]["result"]

    @patch("databricks.sql.backend.filters.SeaResultSet")
    def test_filter_tables_by_type(self, mock_sea_result_set_class):
        """Test filtering tables by type."""
        # Set up the mock to return a new mock when instantiated
        mock_sea_result_set_class.return_value = MagicMock()

        self._filter_as_sea_result_set(["TABLE", "VIEW"])
        captured = self._captured_result(mock_sea_result_set_class)

        # Check that the filtered response contains only TABLE and VIEW
        kept_types = [row[3] for row in captured["data_array"]]
        self.assertEqual(kept_types, ["TABLE", "VIEW"])

        # Check row count was updated
        self.assertEqual(captured["row_count"], 2)

    @patch("databricks.sql.backend.filters.SeaResultSet")
    def test_filter_tables_by_type_case_insensitive(self, mock_sea_result_set_class):
        """Test filtering tables by type with case insensitivity."""
        # Set up the mock to return a new mock when instantiated
        mock_sea_result_set_class.return_value = MagicMock()

        # Test with lowercase table types
        self._filter_as_sea_result_set(["table", "view"])
        captured = self._captured_result(mock_sea_result_set_class)

        # Check that the filtered response contains only TABLE and VIEW
        kept_types = [row[3] for row in captured["data_array"]]
        self.assertEqual(kept_types, ["TABLE", "VIEW"])

    @patch("databricks.sql.backend.filters.SeaResultSet")
    def test_filter_tables_by_type_default(self, mock_sea_result_set_class):
        """Test filtering tables by type with default types."""
        # Set up the mock to return a new mock when instantiated
        mock_sea_result_set_class.return_value = MagicMock()

        self._filter_as_sea_result_set(None)
        captured = self._captured_result(mock_sea_result_set_class)

        # Check that the filtered response contains TABLE, VIEW, and SYSTEM TABLE
        kept_types = [row[3] for row in captured["data_array"]]
        self.assertEqual(kept_types, ["TABLE", "VIEW", "SYSTEM TABLE"])
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main()

0 commit comments

Comments
 (0)