88from databricks .sql .backend .databricks_client import DatabricksClient
99from databricks .sql .backend .types import SessionId , CommandId , CommandState , BackendType
1010from databricks .sql .exc import Error , NotSupportedError
11- from databricks .sql .sea . http_client import SEAHttpClient
11+ from databricks .sql .backend . utils . http_client import CustomHttpClient
1212from databricks .sql .thrift_api .TCLIService import ttypes
1313from databricks .sql .types import SSLOptions
1414
1515logger = logging .getLogger (__name__ )
1616
1717
18- class SEADatabricksClient (DatabricksClient ):
18+ class SeaDatabricksClient (DatabricksClient ):
1919 """
2020 Statement Execution API (SEA) implementation of the DatabricksClient interface.
2121
@@ -67,28 +67,10 @@ def __init__(
6767 self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
6868
6969 # Extract warehouse ID from http_path
70- # Format could be either:
71- # - /sql/1.0/endpoints/{warehouse_id}
72- # - /sql/1.0/warehouses/{warehouse_id}
73- path_parts = http_path .strip ("/" ).split ("/" )
74- self .warehouse_id = None
75-
76- if len (path_parts ) >= 3 :
77- if path_parts [- 2 ] in ["endpoints" , "warehouses" ]:
78- self .warehouse_id = path_parts [- 1 ]
79- logger .debug (
80- f"Extracted warehouse ID: { self .warehouse_id } from path: { http_path } "
81- )
82-
83- if not self .warehouse_id :
84- logger .warning (
85- "Could not extract warehouse ID from http_path: %s. "
86- "Session creation may fail if warehouse ID is required." ,
87- http_path ,
88- )
70+ self .warehouse_id = self ._extract_warehouse_id (http_path )
8971
9072 # Initialize HTTP client
91- self .http_client = SEAHttpClient (
73+ self .http_client = CustomHttpClient (
9274 server_hostname = server_hostname ,
9375 port = port ,
9476 http_path = http_path ,
@@ -98,6 +80,41 @@ def __init__(
9880 ** kwargs ,
9981 )
10082
83+ def _extract_warehouse_id (self , http_path : str ) -> str :
84+ """
85+ Extract the warehouse ID from the HTTP path.
86+
87+ The warehouse ID is expected to be the last segment of the path when the
88+ second-to-last segment is either 'warehouses' or 'endpoints'.
89+ This matches the JDBC implementation which supports both formats.
90+
91+ Args:
92+ http_path: The HTTP path from which to extract the warehouse ID
93+
94+ Returns:
95+ The extracted warehouse ID
96+
97+ Raises:
98+ Error: If the warehouse ID cannot be extracted from the path
99+ """
100+ path_parts = http_path .strip ("/" ).split ("/" )
101+ warehouse_id = None
102+
103+ if len (path_parts ) >= 3 and path_parts [- 2 ] in ["warehouses" , "endpoints" ]:
104+ warehouse_id = path_parts [- 1 ]
105+ logger .debug (f"Extracted warehouse ID: { warehouse_id } from path: { http_path } " )
106+
107+ if not warehouse_id :
108+ error_message = (
109+ f"Could not extract warehouse ID from http_path: { http_path } . "
110+ f"Expected format: /path/to/warehouses/{{warehouse_id}} or "
111+ f"/path/to/endpoints/{{warehouse_id}}"
112+ )
113+ logger .error (error_message )
114+ raise ValueError (error_message )
115+
116+ return warehouse_id
117+
101118 @property
102119 def staging_allowed_local_path (self ) -> Union [None , str , List [str ]]:
103120 """Get the allowed local paths for staging operations."""
@@ -115,7 +132,7 @@ def max_download_threads(self) -> int:
115132
116133 def open_session (
117134 self ,
118- session_configuration : Optional [Dict [str , Any ]],
135+ session_configuration : Optional [Dict [str , str ]],
119136 catalog : Optional [str ],
120137 schema : Optional [str ],
121138 ) -> SessionId :
@@ -141,36 +158,23 @@ def open_session(
141158 schema ,
142159 )
143160
144- # Prepare request payload
145- request_data : Dict [str , Any ] = {}
146-
147- if self .warehouse_id :
148- request_data ["warehouse_id" ] = self .warehouse_id
149-
161+ request_data : Dict [str , Any ] = {"warehouse_id" : self .warehouse_id }
150162 if session_configuration :
151- # The SEA API expects "session_confs" as the key for session configuration
152163 request_data ["session_confs" ] = session_configuration
153-
154164 if catalog :
155165 request_data ["catalog" ] = catalog
156-
157166 if schema :
158167 request_data ["schema" ] = schema
159168
160- # Make API request
161169 response = self .http_client ._make_request (
162170 method = "POST" , path = self .SESSION_PATH , data = request_data
163171 )
164172
165- # Extract session ID from response
166173 session_id = response .get ("session_id" )
167174 if not session_id :
168175 raise Error ("Failed to create session: No session ID returned" )
169176
170- # Create and return SessionId object
171- return SessionId .from_sea_session_id (
172- session_id , {"warehouse_id" : self .warehouse_id }
173- )
177+ return SessionId .from_sea_session_id (session_id )
174178
175179 def close_session (self , session_id : SessionId ) -> None :
176180 """
@@ -185,16 +189,11 @@ def close_session(self, session_id: SessionId) -> None:
185189 """
186190 logger .debug ("SEADatabricksClient.close_session(session_id=%s)" , session_id )
187191
188- # Validate session ID
189192 if session_id .backend_type != BackendType .SEA :
190193 raise ValueError ("Not a valid SEA session ID" )
191-
192194 sea_session_id = session_id .to_sea_session_id ()
193195
194- # Make API request with warehouse_id as a query parameter
195- request_data = {}
196- if self .warehouse_id :
197- request_data ["warehouse_id" ] = self .warehouse_id
196+ request_data = {"warehouse_id" : self .warehouse_id }
198197
199198 self .http_client ._make_request (
200199 method = "DELETE" ,
0 commit comments