Skip to content

Commit 54c7f6d

Browse files
Merge branch 'exec-resp-norm' into cloudfetch-sea
2 parents cdc7f42 + 2cd04df commit 54c7f6d

23 files changed

+764
-1925
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 15 additions & 611 deletions
Large diffs are not rendered by default.

src/databricks/sql/backend/databricks_client.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
44
Implementations of this class are responsible for:
55
- Managing connections to Databricks SQL services
6-
- Handling authentication
76
- Executing SQL queries and commands
87
- Retrieving query results
98
- Fetching metadata about catalogs, schemas, tables, and columns
10-
- Managing error handling and retries
119
"""
1210

1311
from abc import ABC, abstractmethod
@@ -178,29 +176,28 @@ def get_columns(
178176
table_name: Optional[str] = None,
179177
column_name: Optional[str] = None,
180178
) -> "ResultSet":
181-
pass
182-
183-
# == Properties ==
184-
@property
185-
@abstractmethod
186-
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
187179
"""
188-
Gets the allowed local paths for staging operations.
180+
Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns.
189181
190-
Returns:
191-
Union[None, str, List[str]]: The allowed local paths for staging operations,
192-
or None if staging is not allowed
193-
"""
194-
pass
182+
This method fetches metadata about columns available in the specified table,
183+
or all tables if not specified.
195184
196-
@property
197-
@abstractmethod
198-
def ssl_options(self) -> SSLOptions:
199-
"""
200-
Gets the SSL options for this client.
185+
Args:
186+
session_id: The session identifier
187+
max_rows: Maximum number of rows to fetch in a single batch
188+
max_bytes: Maximum number of bytes to fetch in a single batch
189+
cursor: The cursor object that will handle the results
190+
catalog_name: Optional catalog name pattern to filter by
191+
schema_name: Optional schema name pattern to filter by
192+
table_name: Optional table name pattern to filter by
193+
column_name: Optional column name pattern to filter by
201194
202195
Returns:
203-
SSLOptions: The SSL configuration options
196+
ResultSet: An object containing the column metadata
197+
198+
Raises:
199+
ValueError: If the session ID is invalid
200+
OperationalError: If there's an error retrieving the columns
204201
"""
205202
pass
206203

src/databricks/sql/backend/sea/backend.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __init__(
6969
http_headers: List[Tuple[str, str]],
7070
auth_provider,
7171
ssl_options: SSLOptions,
72-
staging_allowed_local_path: Union[None, str, List[str]] = None,
7372
**kwargs,
7473
):
7574
"""
@@ -82,25 +81,23 @@ def __init__(
8281
http_headers: List of HTTP headers to include in requests
8382
auth_provider: Authentication provider
8483
ssl_options: SSL configuration options
85-
staging_allowed_local_path: Allowed local paths for staging operations
8684
**kwargs: Additional keyword arguments
8785
"""
86+
8887
logger.debug(
89-
"SEADatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
88+
"SeaDatabricksClient.__init__(server_hostname=%s, port=%s, http_path=%s)",
9089
server_hostname,
9190
port,
9291
http_path,
9392
)
9493

95-
self._staging_allowed_local_path = staging_allowed_local_path
96-
self._ssl_options = ssl_options
9794
self._max_download_threads = kwargs.get("max_download_threads", 10)
9895

9996
# Extract warehouse ID from http_path
10097
self.warehouse_id = self._extract_warehouse_id(http_path)
10198

10299
# Initialize HTTP client
103-
self.http_client = CustomHttpClient(
100+
self.http_client = SeaHttpClient(
104101
server_hostname=server_hostname,
105102
port=port,
106103
http_path=http_path,
@@ -114,47 +111,38 @@ def _extract_warehouse_id(self, http_path: str) -> str:
114111
"""
115112
Extract the warehouse ID from the HTTP path.
116113
117-
The warehouse ID is expected to be the last segment of the path when the
118-
second-to-last segment is either 'warehouses' or 'endpoints'.
119-
120114
Args:
121115
http_path: The HTTP path from which to extract the warehouse ID
122116
123117
Returns:
124118
The extracted warehouse ID
125119
126120
Raises:
127-
Error: If the warehouse ID cannot be extracted from the path
121+
ValueError: If the warehouse ID cannot be extracted from the path
128122
"""
129-
path_parts = http_path.strip("/").split("/")
130-
warehouse_id = None
131123

132-
if len(path_parts) >= 3 and path_parts[-2] in ["warehouses", "endpoints"]:
133-
warehouse_id = path_parts[-1]
124+
warehouse_pattern = re.compile(r".*/warehouses/(.+)")
125+
endpoint_pattern = re.compile(r".*/endpoints/(.+)")
126+
127+
for pattern in [warehouse_pattern, endpoint_pattern]:
128+
match = pattern.match(http_path)
129+
if not match:
130+
continue
131+
warehouse_id = match.group(1)
134132
logger.debug(
135133
f"Extracted warehouse ID: {warehouse_id} from path: {http_path}"
136134
)
137-
138-
if not warehouse_id:
139-
error_message = (
140-
f"Could not extract warehouse ID from http_path: {http_path}. "
141-
f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
142-
f"/path/to/endpoints/{{warehouse_id}}"
143-
)
144-
logger.error(error_message)
145-
raise ValueError(error_message)
146-
147-
return warehouse_id
148-
149-
@property
150-
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
151-
"""Get the allowed local paths for staging operations."""
152-
return self._staging_allowed_local_path
153-
154-
@property
155-
def ssl_options(self) -> SSLOptions:
156-
"""Get the SSL options for this client."""
157-
return self._ssl_options
135+
return warehouse_id
136+
137+
# If no match found, raise error
138+
error_message = (
139+
f"Could not extract warehouse ID from http_path: {http_path}. "
140+
f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
141+
f"/path/to/endpoints/{{warehouse_id}}."
142+
f"Note: SEA only works for warehouses."
143+
)
144+
logger.error(error_message)
145+
raise ValueError(error_message)
158146

159147
@property
160148
def max_download_threads(self) -> int:
@@ -171,7 +159,9 @@ def open_session(
171159
Opens a new session with the Databricks SQL service using SEA.
172160
173161
Args:
174-
session_configuration: Optional dictionary of configuration parameters for the session
162+
session_configuration: Optional dictionary of configuration parameters for the session.
163+
Only specific parameters are supported as documented at:
164+
https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
175165
catalog: Optional catalog name to use as the initial catalog for the session
176166
schema: Optional schema name to use as the initial schema for the session
177167
@@ -182,28 +172,29 @@ def open_session(
182172
Error: If the session configuration is invalid
183173
OperationalError: If there's an error establishing the session
184174
"""
175+
185176
logger.debug(
186-
"SEADatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
177+
"SeaDatabricksClient.open_session(session_configuration=%s, catalog=%s, schema=%s)",
187178
session_configuration,
188179
catalog,
189180
schema,
190181
)
191182

192-
request = CreateSessionRequest(
183+
session_configuration = _filter_session_configuration(session_configuration)
184+
185+
request_data = CreateSessionRequest(
193186
warehouse_id=self.warehouse_id,
194187
session_confs=session_configuration,
195188
catalog=catalog,
196189
schema=schema,
197190
)
198191

199192
response = self.http_client._make_request(
200-
method="POST", path=self.SESSION_PATH, data=request.to_dict()
193+
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
201194
)
202195

203-
# Parse the response
204196
session_response = CreateSessionResponse.from_dict(response)
205197
session_id = session_response.session_id
206-
207198
if not session_id:
208199
raise ServerOperationError(
209200
"Failed to create session: No session ID returned",
@@ -226,14 +217,16 @@ def close_session(self, session_id: SessionId) -> None:
226217
ValueError: If the session ID is invalid
227218
OperationalError: If there's an error closing the session
228219
"""
229-
logger.debug("SEADatabricksClient.close_session(session_id=%s)", session_id)
220+
221+
logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)
230222

231223
if session_id.backend_type != BackendType.SEA:
232224
raise ValueError("Not a valid SEA session ID")
233225
sea_session_id = session_id.to_sea_session_id()
234226

235-
request = DeleteSessionRequest(
236-
warehouse_id=self.warehouse_id, session_id=sea_session_id
227+
request_data = DeleteSessionRequest(
228+
warehouse_id=self.warehouse_id,
229+
session_id=sea_session_id,
237230
)
238231

239232
self.http_client._make_request(

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,5 @@ class DeleteSessionRequest:
129129
session_id: str
130130

131131
def to_dict(self) -> Dict[str, str]:
132-
"""
133-
Convert the request to query parameters.
134-
135-
In the SEA API, only the warehouse_id is sent as a query parameter.
136-
The session_id is included in the URL path.
137-
138-
Returns:
139-
A dictionary containing the warehouse_id as a query parameter
140-
"""
132+
"""Convert the request to a dictionary for JSON serialization."""
141133
return {"warehouse_id": self.warehouse_id, "session_id": self.session_id}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Constants for the Statement Execution API (SEA) backend.
3+
"""
4+
5+
from typing import Dict
6+
7+
# from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
8+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = {
9+
"ANSI_MODE": "true",
10+
"ENABLE_PHOTON": "true",
11+
"LEGACY_TIME_PARSER_POLICY": "Exception",
12+
"MAX_FILE_PARTITION_BYTES": "128m",
13+
"READ_ONLY_EXTERNAL_METASTORE": "false",
14+
"STATEMENT_TIMEOUT": "0",
15+
"TIMEZONE": "UTC",
16+
"USE_CACHED_RESULT": "true",
17+
}

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
class CustomHttpClient:
13+
class SeaHttpClient:
1414
"""
1515
HTTP client for Statement Execution API (SEA).
1616
@@ -145,7 +145,7 @@ def _make_request(
145145
if response.content:
146146
result = response.json()
147147
# Log response content (but limit it for large responses)
148-
content_str = json.dumps(result, indent=2, sort_keys=True)
148+
content_str = json.dumps(result)
149149
if len(content_str) > 1000:
150150
logger.debug(
151151
f"Response content (truncated): {content_str[:1000]}..."

0 commit comments

Comments
 (0)