Skip to content

Commit 390c1e7

Browse files
add client side filtering for session confs, add note on warehouses over endoints
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 46104e2 commit 390c1e7

File tree

3 files changed

+117
-6
lines changed

3 files changed

+117
-6
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING
3+
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set
44

55
if TYPE_CHECKING:
66
from databricks.sql.client import Cursor
@@ -9,6 +9,9 @@
99
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1010
from databricks.sql.exc import ServerOperationError
1111
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
12+
from databricks.sql.backend.sea.utils.constants import (
13+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
14+
)
1215
from databricks.sql.thrift_api.TCLIService import ttypes
1316
from databricks.sql.types import SSLOptions
1417

@@ -21,6 +24,34 @@
2124
logger = logging.getLogger(__name__)
2225

2326

27+
def _filter_session_configuration(
28+
session_configuration: Optional[Dict[str, str]]
29+
) -> Optional[Dict[str, str]]:
30+
if not session_configuration:
31+
return None
32+
33+
filtered_session_configuration = {}
34+
ignored_configs: Set[str] = set()
35+
36+
for key, value in session_configuration.items():
37+
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
38+
filtered_session_configuration[key.lower()] = value
39+
else:
40+
ignored_configs.add(key)
41+
42+
if ignored_configs:
43+
logger.warning(
44+
"Some session configurations were ignored because they are not supported: %s",
45+
ignored_configs,
46+
)
47+
logger.warning(
48+
"Supported session configurations are: %s",
49+
list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()),
50+
)
51+
52+
return filtered_session_configuration
53+
54+
2455
class SeaDatabricksClient(DatabricksClient):
2556
"""
2657
Statement Execution API (SEA) implementation of the DatabricksClient interface.
@@ -111,7 +142,8 @@ def _extract_warehouse_id(self, http_path: str) -> str:
111142
error_message = (
112143
f"Could not extract warehouse ID from http_path: {http_path}. "
113144
f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
114-
f"/path/to/endpoints/{{warehouse_id}}"
145+
f"/path/to/endpoints/{{warehouse_id}}."
146+
f"Note: SEA only works for warehouses."
115147
)
116148
logger.error(error_message)
117149
raise ValueError(error_message)
@@ -152,6 +184,8 @@ def open_session(
152184
schema,
153185
)
154186

187+
session_configuration = _filter_session_configuration(session_configuration)
188+
155189
request_data = CreateSessionRequest(
156190
warehouse_id=self.warehouse_id,
157191
session_confs=session_configuration,
@@ -205,6 +239,29 @@ def close_session(self, session_id: SessionId) -> None:
205239
data=request_data.to_dict(),
206240
)
207241

242+
@staticmethod
243+
def get_default_session_configuration_value(name: str) -> Optional[str]:
244+
"""
245+
Get the default value for a session configuration parameter.
246+
247+
Args:
248+
name: The name of the session configuration parameter
249+
250+
Returns:
251+
The default value if the parameter is supported, None otherwise
252+
"""
253+
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
254+
255+
@staticmethod
256+
def get_allowed_session_configurations() -> List[str]:
257+
"""
258+
Get the list of allowed session configuration parameters.
259+
260+
Returns:
261+
List of allowed session configuration parameter names
262+
"""
263+
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
264+
208265
# == Not Implemented Operations ==
209266
# These methods will be implemented in future iterations
210267

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Constants for the Statement Execution API (SEA) backend.
3+
"""
4+
5+
from typing import Dict
6+
7+
# from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
8+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = {
9+
"ANSI_MODE": "true",
10+
"ENABLE_PHOTON": "true",
11+
"LEGACY_TIME_PARSER_POLICY": "Exception",
12+
"MAX_FILE_PARTITION_BYTES": "128m",
13+
"READ_ONLY_EXTERNAL_METASTORE": "false",
14+
"STATEMENT_TIMEOUT": "0",
15+
"TIMEZONE": "UTC",
16+
"USE_CACHED_RESULT": "true",
17+
}

tests/unit/test_sea_backend.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,12 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client):
101101
# Set up mock response
102102
mock_http_client._make_request.return_value = {"session_id": "test-session-456"}
103103

104-
# Call the method with all parameters
105-
session_config = {"ANSI_MODE": "FALSE", "STATEMENT_TIMEOUT": "3600"}
104+
# Call the method with all parameters, including both supported and unsupported configurations
105+
session_config = {
106+
"ANSI_MODE": "FALSE", # Supported parameter
107+
"STATEMENT_TIMEOUT": "3600", # Supported parameter
108+
"unsupported_param": "value", # Unsupported parameter
109+
}
106110
catalog = "test_catalog"
107111
schema = "test_schema"
108112

@@ -113,10 +117,14 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client):
113117
assert session_id.backend_type == BackendType.SEA
114118
assert session_id.guid == "test-session-456"
115119

116-
# Verify the HTTP request
120+
# Verify the HTTP request - only supported parameters should be included
121+
# and keys should be in lowercase
117122
expected_data = {
118123
"warehouse_id": "abc123",
119-
"session_confs": session_config,
124+
"session_confs": {
125+
"ansi_mode": "FALSE",
126+
"statement_timeout": "3600",
127+
},
120128
"catalog": catalog,
121129
"schema": schema,
122130
}
@@ -166,3 +174,32 @@ def test_close_session_invalid_id_type(self, sea_client):
166174
sea_client.close_session(session_id)
167175

168176
assert "Not a valid SEA session ID" in str(excinfo.value)
177+
178+
def test_session_configuration_helpers(self):
179+
"""Test the session configuration helper methods."""
180+
# Test getting default value for a supported parameter
181+
default_value = SeaDatabricksClient.get_default_session_configuration_value(
182+
"ANSI_MODE"
183+
)
184+
assert default_value == "true"
185+
186+
# Test getting default value for an unsupported parameter
187+
default_value = SeaDatabricksClient.get_default_session_configuration_value(
188+
"UNSUPPORTED_PARAM"
189+
)
190+
assert default_value is None
191+
192+
# Test getting the list of allowed configurations
193+
allowed_configs = SeaDatabricksClient.get_allowed_session_configurations()
194+
195+
expected_keys = {
196+
"ANSI_MODE",
197+
"ENABLE_PHOTON",
198+
"LEGACY_TIME_PARSER_POLICY",
199+
"MAX_FILE_PARTITION_BYTES",
200+
"READ_ONLY_EXTERNAL_METASTORE",
201+
"STATEMENT_TIMEOUT",
202+
"TIMEZONE",
203+
"USE_CACHED_RESULT",
204+
}
205+
assert set(allowed_configs) == expected_keys

0 commit comments

Comments
 (0)