@@ -69,7 +69,6 @@ def __init__(
6969 http_headers : List [Tuple [str , str ]],
7070 auth_provider ,
7171 ssl_options : SSLOptions ,
72- staging_allowed_local_path : Union [None , str , List [str ]] = None ,
7372 ** kwargs ,
7473 ):
7574 """
@@ -82,25 +81,23 @@ def __init__(
8281 http_headers: List of HTTP headers to include in requests
8382 auth_provider: Authentication provider
8483 ssl_options: SSL configuration options
85- staging_allowed_local_path: Allowed local paths for staging operations
8684 **kwargs: Additional keyword arguments
8785 """
86+
8887 logger .debug (
89- "SEADatabricksClient .__init__(server_hostname=%s, port=%s, http_path=%s)" ,
88+ "SeaDatabricksClient .__init__(server_hostname=%s, port=%s, http_path=%s)" ,
9089 server_hostname ,
9190 port ,
9291 http_path ,
9392 )
9493
95- self ._staging_allowed_local_path = staging_allowed_local_path
96- self ._ssl_options = ssl_options
9794 self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
9895
9996 # Extract warehouse ID from http_path
10097 self .warehouse_id = self ._extract_warehouse_id (http_path )
10198
10299 # Initialize HTTP client
103- self .http_client = CustomHttpClient (
100+ self .http_client = SeaHttpClient (
104101 server_hostname = server_hostname ,
105102 port = port ,
106103 http_path = http_path ,
@@ -114,47 +111,38 @@ def _extract_warehouse_id(self, http_path: str) -> str:
114111 """
115112 Extract the warehouse ID from the HTTP path.
116113
117- The warehouse ID is expected to be the last segment of the path when the
118- second-to-last segment is either 'warehouses' or 'endpoints'.
119-
120114 Args:
121115 http_path: The HTTP path from which to extract the warehouse ID
122116
123117 Returns:
124118 The extracted warehouse ID
125119
126120 Raises:
127- Error : If the warehouse ID cannot be extracted from the path
121+ ValueError : If the warehouse ID cannot be extracted from the path
128122 """
129- path_parts = http_path .strip ("/" ).split ("/" )
130- warehouse_id = None
131123
132- if len (path_parts ) >= 3 and path_parts [- 2 ] in ["warehouses" , "endpoints" ]:
133- warehouse_id = path_parts [- 1 ]
124+ warehouse_pattern = re .compile (r".*/warehouses/(.+)" )
125+ endpoint_pattern = re .compile (r".*/endpoints/(.+)" )
126+
127+ for pattern in [warehouse_pattern , endpoint_pattern ]:
128+ match = pattern .match (http_path )
129+ if not match :
130+ continue
131+ warehouse_id = match .group (1 )
134132 logger .debug (
135133 f"Extracted warehouse ID: { warehouse_id } from path: { http_path } "
136134 )
137-
138- if not warehouse_id :
139- error_message = (
140- f"Could not extract warehouse ID from http_path: { http_path } . "
141- f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
142- f"/path/to/endpoints/{{warehouse_id}}"
143- )
144- logger .error (error_message )
145- raise ValueError (error_message )
146-
147- return warehouse_id
148-
149- @property
150- def staging_allowed_local_path (self ) -> Union [None , str , List [str ]]:
151- """Get the allowed local paths for staging operations."""
152- return self ._staging_allowed_local_path
153-
154- @property
155- def ssl_options (self ) -> SSLOptions :
156- """Get the SSL options for this client."""
157- return self ._ssl_options
135+ return warehouse_id
136+
137+ # If no match found, raise error
138+ error_message = (
139+ f"Could not extract warehouse ID from http_path: { http_path } . "
140+ f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
141+ f"/path/to/endpoints/{{warehouse_id}}."
142+ f"Note: SEA only works for warehouses."
143+ )
144+ logger .error (error_message )
145+ raise ValueError (error_message )
158146
159147 @property
160148 def max_download_threads (self ) -> int :
@@ -171,7 +159,9 @@ def open_session(
171159 Opens a new session with the Databricks SQL service using SEA.
172160
173161 Args:
174- session_configuration: Optional dictionary of configuration parameters for the session
162+ session_configuration: Optional dictionary of configuration parameters for the session.
163+ Only specific parameters are supported as documented at:
164+ https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters
175165 catalog: Optional catalog name to use as the initial catalog for the session
176166 schema: Optional schema name to use as the initial schema for the session
177167
@@ -182,28 +172,29 @@ def open_session(
182172 Error: If the session configuration is invalid
183173 OperationalError: If there's an error establishing the session
184174 """
175+
185176 logger .debug (
186- "SEADatabricksClient .open_session(session_configuration=%s, catalog=%s, schema=%s)" ,
177+ "SeaDatabricksClient .open_session(session_configuration=%s, catalog=%s, schema=%s)" ,
187178 session_configuration ,
188179 catalog ,
189180 schema ,
190181 )
191182
192- request = CreateSessionRequest (
183+ session_configuration = _filter_session_configuration (session_configuration )
184+
185+ request_data = CreateSessionRequest (
193186 warehouse_id = self .warehouse_id ,
194187 session_confs = session_configuration ,
195188 catalog = catalog ,
196189 schema = schema ,
197190 )
198191
199192 response = self .http_client ._make_request (
200- method = "POST" , path = self .SESSION_PATH , data = request .to_dict ()
193+ method = "POST" , path = self .SESSION_PATH , data = request_data .to_dict ()
201194 )
202195
203- # Parse the response
204196 session_response = CreateSessionResponse .from_dict (response )
205197 session_id = session_response .session_id
206-
207198 if not session_id :
208199 raise ServerOperationError (
209200 "Failed to create session: No session ID returned" ,
@@ -226,14 +217,16 @@ def close_session(self, session_id: SessionId) -> None:
226217 ValueError: If the session ID is invalid
227218 OperationalError: If there's an error closing the session
228219 """
229- logger .debug ("SEADatabricksClient.close_session(session_id=%s)" , session_id )
220+
221+ logger .debug ("SeaDatabricksClient.close_session(session_id=%s)" , session_id )
230222
231223 if session_id .backend_type != BackendType .SEA :
232224 raise ValueError ("Not a valid SEA session ID" )
233225 sea_session_id = session_id .to_sea_session_id ()
234226
235- request = DeleteSessionRequest (
236- warehouse_id = self .warehouse_id , session_id = sea_session_id
227+ request_data = DeleteSessionRequest (
228+ warehouse_id = self .warehouse_id ,
229+ session_id = sea_session_id ,
237230 )
238231
239232 self .http_client ._make_request (
0 commit comments