|
1 | 1 | import logging |
2 | 2 | import re |
3 | | -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING |
| 3 | +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set |
4 | 4 |
|
5 | 5 | if TYPE_CHECKING: |
6 | 6 | from databricks.sql.client import Cursor |
|
9 | 9 | from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType |
10 | 10 | from databricks.sql.exc import ServerOperationError |
11 | 11 | from databricks.sql.backend.sea.utils.http_client import SeaHttpClient |
| 12 | +from databricks.sql.backend.sea.utils.constants import ( |
| 13 | + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, |
| 14 | +) |
12 | 15 | from databricks.sql.thrift_api.TCLIService import ttypes |
13 | 16 | from databricks.sql.types import SSLOptions |
14 | 17 |
|
|
21 | 24 | logger = logging.getLogger(__name__) |
22 | 25 |
|
23 | 26 |
|
| 27 | +def _filter_session_configuration( |
| 28 | + session_configuration: Optional[Dict[str, str]] |
| 29 | +) -> Optional[Dict[str, str]]: |
| 30 | + if not session_configuration: |
| 31 | + return None |
| 32 | + |
| 33 | + filtered_session_configuration = {} |
| 34 | + ignored_configs: Set[str] = set() |
| 35 | + |
| 36 | + for key, value in session_configuration.items(): |
| 37 | + if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: |
| 38 | + filtered_session_configuration[key.lower()] = value |
| 39 | + else: |
| 40 | + ignored_configs.add(key) |
| 41 | + |
| 42 | + if ignored_configs: |
| 43 | + logger.warning( |
| 44 | + "Some session configurations were ignored because they are not supported: %s", |
| 45 | + ignored_configs, |
| 46 | + ) |
| 47 | + logger.warning( |
| 48 | + "Supported session configurations are: %s", |
| 49 | + list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()), |
| 50 | + ) |
| 51 | + |
| 52 | + return filtered_session_configuration |
| 53 | + |
| 54 | + |
24 | 55 | class SeaDatabricksClient(DatabricksClient): |
25 | 56 | """ |
26 | 57 | Statement Execution API (SEA) implementation of the DatabricksClient interface. |
@@ -111,7 +142,8 @@ def _extract_warehouse_id(self, http_path: str) -> str: |
111 | 142 | error_message = ( |
112 | 143 | f"Could not extract warehouse ID from http_path: {http_path}. " |
113 | 144 | f"Expected format: /path/to/warehouses/{{warehouse_id}} or " |
114 | | - f"/path/to/endpoints/{{warehouse_id}}" |
| 145 | + f"/path/to/endpoints/{{warehouse_id}}." |
| 146 | + f"Note: SEA only works for warehouses." |
115 | 147 | ) |
116 | 148 | logger.error(error_message) |
117 | 149 | raise ValueError(error_message) |
@@ -152,6 +184,8 @@ def open_session( |
152 | 184 | schema, |
153 | 185 | ) |
154 | 186 |
|
| 187 | + session_configuration = _filter_session_configuration(session_configuration) |
| 188 | + |
155 | 189 | request_data = CreateSessionRequest( |
156 | 190 | warehouse_id=self.warehouse_id, |
157 | 191 | session_confs=session_configuration, |
@@ -205,6 +239,29 @@ def close_session(self, session_id: SessionId) -> None: |
205 | 239 | data=request_data.to_dict(), |
206 | 240 | ) |
207 | 241 |
|
| 242 | + @staticmethod |
| 243 | + def get_default_session_configuration_value(name: str) -> Optional[str]: |
| 244 | + """ |
| 245 | + Get the default value for a session configuration parameter. |
| 246 | +
|
| 247 | + Args: |
| 248 | + name: The name of the session configuration parameter |
| 249 | +
|
| 250 | + Returns: |
| 251 | + The default value if the parameter is supported, None otherwise |
| 252 | + """ |
| 253 | + return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) |
| 254 | + |
| 255 | + @staticmethod |
| 256 | + def get_allowed_session_configurations() -> List[str]: |
| 257 | + """ |
| 258 | + Get the list of allowed session configuration parameters. |
| 259 | +
|
| 260 | + Returns: |
| 261 | + List of allowed session configuration parameter names |
| 262 | + """ |
| 263 | + return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) |
| 264 | + |
208 | 265 | # == Not Implemented Operations == |
209 | 266 | # These methods will be implemented in future iterations |
210 | 267 |
|
|
0 commit comments