Skip to content

Commit edd4c87

Browse files
client side table_types filtering
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 611d79f commit edd4c87

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

src/databricks/sql/backend/models/requests.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, List, Any, Optional, Union
88
from dataclasses import dataclass, field
99

10+
1011
@dataclass
1112
class StatementParameter:
1213
"""Parameter for a SQL statement."""
@@ -57,7 +58,7 @@ def to_dict(self) -> Dict[str, Any]:
5758

5859
if self.schema:
5960
result["schema"] = self.schema
60-
61+
6162
if self.result_compression:
6263
result["result_compression"] = self.result_compression
6364

src/databricks/sql/backend/sea_backend.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def execute_command(
278278
format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
279279
disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
280280
result_compression = "LZ4_FRAME" if lz4_compression else "NONE"
281-
281+
282282
request = ExecuteStatementRequest(
283283
warehouse_id=self.warehouse_id,
284284
session_id=sea_session_id,
@@ -552,8 +552,78 @@ def get_tables(
552552
enforce_embedded_schema_correctness=False,
553553
)
554554
assert result is not None, "execute_command returned None in synchronous mode"
555+
556+
# Apply client-side filtering by table_types if specified
557+
if table_types and len(table_types) > 0:
558+
result = self._filter_tables_by_type(result, table_types)
559+
555560
return result
556561

562+
def _filter_tables_by_type(
563+
self, result_set: "ResultSet", table_types: List[str]
564+
) -> "ResultSet":
565+
"""
566+
Filter a result set of tables by the specified table types.
567+
568+
This is a client-side filter that processes the result set after it has been
569+
retrieved from the server. It filters out tables whose type does not match
570+
any of the types in the table_types list.
571+
572+
Args:
573+
result_set: The original result set containing tables
574+
table_types: List of table types to include (e.g., ["TABLE", "VIEW"])
575+
576+
Returns:
577+
A filtered result set containing only tables of the specified types
578+
"""
579+
# Default table types if none specified
580+
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
581+
valid_types = table_types if table_types else DEFAULT_TABLE_TYPES
582+
583+
# Convert to uppercase for case-insensitive comparison
584+
valid_types_upper = [t.upper() for t in valid_types]
585+
586+
# Create a filtered version of the result set
587+
from databricks.sql.backend.sea_result_set import SeaResultSet
588+
589+
if isinstance(result_set, SeaResultSet):
590+
# For SEA result sets, we need to filter the rows in the response
591+
filtered_response = result_set._response.copy()
592+
593+
# If there's a result with rows, filter them
594+
if (
595+
"result" in filtered_response
596+
and "data_array" in filtered_response["result"]
597+
):
598+
rows = filtered_response["result"]["data_array"]
599+
# Table type is typically in the 4th column (index 3)
600+
filtered_rows = [
601+
row
602+
for row in rows
603+
if len(row) > 3
604+
and (
605+
isinstance(row[3], str) and row[3].upper() in valid_types_upper
606+
)
607+
]
608+
filtered_response["result"]["data_array"] = filtered_rows
609+
610+
# Update row count if present
611+
if "row_count" in filtered_response["result"]:
612+
filtered_response["result"]["row_count"] = len(filtered_rows)
613+
614+
# Create a new result set with the filtered data
615+
return SeaResultSet(
616+
connection=result_set.connection,
617+
sea_response=filtered_response,
618+
sea_client=result_set._sea_client,
619+
buffer_size_bytes=result_set._buffer_size_bytes,
620+
arraysize=result_set._arraysize,
621+
)
622+
623+
# For other result set types, return the original (should not happen with SEA)
624+
logger.warning("Table type filtering not implemented for this result set type")
625+
return result_set
626+
557627
def get_columns(
558628
self,
559629
session_id: SessionId,

0 commit comments

Comments
 (0)