Skip to content

Commit 8985c62

Browse files
[squashed from exec-sea] init execution func
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent ea9d456 commit 8985c62

File tree

15 files changed

+2166
-219
lines changed

15 files changed

+2166
-219
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,122 @@
66
logging.basicConfig(level=logging.DEBUG)
77
logger = logging.getLogger(__name__)
88

9+
10+
def _run_sea_query(
    server_hostname: str,
    http_path: str,
    access_token: str,
    catalog,
    use_cloud_fetch: bool,
    enable_lz4: bool,
) -> None:
    """
    Open a SEA connection, execute a trivial query, and close the connection.

    Helper for test_sea_query_exec; any failure propagates to the caller.

    Args:
        server_hostname: Databricks server hostname
        http_path: HTTP path for the SQL endpoint
        access_token: Personal access token for authentication
        catalog: Optional catalog name (may be None)
        use_cloud_fetch: Whether to enable cloud fetch
        enable_lz4: Whether to enable LZ4 result compression
    """
    state = "enabled" if enable_lz4 else "disabled"
    logger.info(f"Creating connection with LZ4 compression {state}")
    connection = Connection(
        server_hostname=server_hostname,
        http_path=http_path,
        access_token=access_token,
        catalog=catalog,
        schema="default",
        use_sea=True,
        user_agent_entry="SEA-Test-Client",
        # NOTE: cloud fetch is required for compression, so the two flags are
        # toggled together by the caller.
        use_cloud_fetch=use_cloud_fetch,
        enable_query_result_lz4_compression=enable_lz4,
    )

    logger.info(
        f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
    )
    logger.info(f"backend type: {type(connection.session.backend)}")

    # Execute a simple query to verify the execution path end to end.
    cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
    logger.info(f"Executing query with LZ4 compression {state}: SELECT 1 as test_value")
    cursor.execute("SELECT 1 as test_value")
    logger.info(f"Query with compression {state} executed successfully")
    cursor.close()
    connection.close()
    logger.info(f"Successfully closed SEA session with compression {state}")


def test_sea_query_exec():
    """
    Test executing a query using the SEA backend with result compression.

    This function connects to a Databricks SQL endpoint using the SEA backend,
    executes a simple query with result compression enabled and disabled,
    and verifies that execution completes successfully.

    Exits the process with status 1 if required environment variables are
    missing or if query execution fails.
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        sys.exit(1)

    try:
        # Pass 1: compression enabled (requires cloud fetch).
        _run_sea_query(
            server_hostname,
            http_path,
            access_token,
            catalog,
            use_cloud_fetch=True,
            enable_lz4=True,
        )
        # Pass 2: compression disabled (cloud fetch off as well).
        _run_sea_query(
            server_hostname,
            http_path,
            access_token,
            catalog,
            use_cloud_fetch=False,
            enable_lz4=False,
        )
    except Exception as e:
        logger.error(f"Error during SEA query execution test: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        sys.exit(1)

    logger.info("SEA query execution test with compression completed successfully")
996
def test_sea_session():
    """
    Test opening and closing a SEA session using the connector.

    This function connects to a Databricks SQL endpoint using the SEA backend,
    opens a session, and then closes it.

    Required environment variables:
    - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
    - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
    - DATABRICKS_TOKEN: Personal access token for authentication
    """
    env = os.environ.get
    server_hostname = env("DATABRICKS_SERVER_HOSTNAME")
    http_path = env("DATABRICKS_HTTP_PATH")
    access_token = env("DATABRICKS_TOKEN")
    catalog = env("DATABRICKS_CATALOG")

    # Bail out early when any mandatory credential is absent.
    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        sys.exit(1)

    logger.info(f"Connecting to {server_hostname}")
    logger.info(f"HTTP Path: {http_path}")
    if catalog:
        logger.info(f"Using catalog: {catalog}")

    try:
        logger.info("Creating connection with SEA backend...")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",  # add custom user agent
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )
        logger.info(f"backend type: {type(connection.session.backend)}")

        # Tear the session down again to exercise the close path.
        logger.info("Closing the SEA session...")
        connection.close()
        logger.info("Successfully closed SEA session")

    except Exception as e:
        logger.error(f"Error testing SEA session: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        sys.exit(1)

    logger.info("SEA session test completed successfully")
156+
65157
if __name__ == "__main__":
    # Exercise session lifecycle first, then query execution with compression.
    test_sea_session()
    test_sea_query_exec()

src/databricks/sql/backend/databricks_client.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,6 @@ def execute_command(
8686
async_op: bool,
8787
enforce_embedded_schema_correctness: bool,
8888
) -> Union["ResultSet", None]:
89-
"""
90-
Executes a SQL command or query within the specified session.
91-
92-
This method sends a SQL command to the server for execution and handles
93-
the response. It can operate in both synchronous and asynchronous modes.
94-
95-
Args:
96-
operation: The SQL command or query to execute
97-
session_id: The session identifier in which to execute the command
98-
max_rows: Maximum number of rows to fetch in a single fetch batch
99-
max_bytes: Maximum number of bytes to fetch in a single fetch batch
100-
lz4_compression: Whether to use LZ4 compression for result data
101-
cursor: The cursor object that will handle the results
102-
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
103-
parameters: List of parameters to bind to the query
104-
async_op: Whether to execute the command asynchronously
105-
enforce_embedded_schema_correctness: Whether to enforce schema correctness
106-
107-
Returns:
108-
If async_op is False, returns a ResultSet object containing the
109-
query results and metadata. If async_op is True, returns None and the
110-
results must be fetched later using get_execution_result().
111-
112-
Raises:
113-
ValueError: If the session ID is invalid
114-
OperationalError: If there's an error executing the command
115-
ServerOperationError: If the server encounters an error during execution
116-
"""
11789
pass
11890

11991
@abstractmethod
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
Client-side filtering utilities for Databricks SQL connector.
3+
4+
This module provides filtering capabilities for result sets returned by different backends.
5+
"""
6+
7+
import logging
8+
from typing import (
9+
List,
10+
Optional,
11+
Any,
12+
Callable,
13+
TYPE_CHECKING,
14+
)
15+
16+
if TYPE_CHECKING:
17+
from databricks.sql.result_set import ResultSet
18+
19+
from databricks.sql.result_set import SeaResultSet
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class ResultSetFilter:
    """
    A general-purpose filter for result sets that can be applied to any backend.

    This class provides methods to filter result sets based on various criteria,
    similar to the client-side filtering in the JDBC connector.
    """

    @staticmethod
    def _filter_sea_result_set(
        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
    ) -> "SeaResultSet":
        """
        Filter a SEA result set using the provided filter function.

        Args:
            result_set: The SEA result set to filter
            filter_func: Function that takes a row and returns True if the row should be included

        Returns:
            A new, filtered SEA result set; the input result set is not mutated.
        """
        # dict.copy() is shallow: the nested "result" dict is still shared with
        # the original response, so it must be copied too before being modified.
        filtered_response = result_set._response.copy()

        # If there's a result with rows, filter them
        if (
            "result" in filtered_response
            and "data_array" in filtered_response["result"]
        ):
            result = filtered_response["result"].copy()
            filtered_rows = [row for row in result["data_array"] if filter_func(row)]
            result["data_array"] = filtered_rows

            # Keep row_count consistent with the filtered data, if present
            if "row_count" in result:
                result["row_count"] = len(filtered_rows)

            filtered_response["result"] = result

        # Create a new result set wrapping the filtered response
        return SeaResultSet(
            connection=result_set.connection,
            sea_response=filtered_response,
            sea_client=result_set.backend,
            buffer_size_bytes=result_set.buffer_size_bytes,
            arraysize=result_set.arraysize,
        )

    @staticmethod
    def filter_by_column_values(
        result_set: "ResultSet",
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool = False,
    ) -> "ResultSet":
        """
        Filter a result set by values in a specific column.

        Args:
            result_set: The result set to filter
            column_index: The index of the column to filter on
            allowed_values: List of allowed values for the column
            case_sensitive: Whether to perform case-sensitive comparison

        Returns:
            A filtered result set (the original result set when filtering is not
            implemented for its backend type)
        """
        # Normalize once up front; a set gives O(1) membership tests per row.
        if not case_sensitive:
            allowed_values = [v.upper() for v in allowed_values]
        allowed = set(allowed_values)

        def _row_matches(row: List[Any]) -> bool:
            # Guard against short rows and non-string cells before comparing.
            if len(row) <= column_index or not isinstance(row[column_index], str):
                return False
            value = row[column_index]
            if not case_sensitive:
                value = value.upper()
            return value in allowed

        # Determine the type of result set and apply appropriate filtering
        if isinstance(result_set, SeaResultSet):
            return ResultSetFilter._filter_sea_result_set(result_set, _row_matches)

        # For other result set types, return the original (should be handled by specific implementations)
        logger.warning(
            f"Filtering not implemented for result set type: {type(result_set).__name__}"
        )
        return result_set

    @staticmethod
    def filter_tables_by_type(
        result_set: "ResultSet", table_types: Optional[List[str]] = None
    ) -> "ResultSet":
        """
        Filter a result set of tables by the specified table types.

        This is a client-side filter that processes the result set after it has been
        retrieved from the server. It filters out tables whose type does not match
        any of the types in the table_types list.

        Args:
            result_set: The original result set containing tables
            table_types: List of table types to include (e.g., ["TABLE", "VIEW"])

        Returns:
            A filtered result set containing only tables of the specified types
        """
        # Default table types if none specified (or an empty list is given)
        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

        # Table type is typically in the 6th column (index 5)
        return ResultSetFilter.filter_by_column_values(
            result_set, 5, valid_types, case_sensitive=False
        )

0 commit comments

Comments (0)