Skip to content

Commit 3b8e1ad

Browse files
Revert "cleaner HTTP client using requests.sessions"
This reverts commit b7a4677.
1 parent b7a4677 commit 3b8e1ad

File tree

1 file changed

+191
-97
lines changed

1 file changed

+191
-97
lines changed
Lines changed: 191 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import json
22
import logging
33
import ssl
4+
import urllib.parse
45
import urllib.request
56
from typing import Dict, Any, Optional, List, Tuple, Union
67
from urllib.parse import urljoin
78

8-
import requests
9-
from requests.adapters import HTTPAdapter
10-
from requests.exceptions import RequestException, HTTPError, ConnectionError
9+
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
10+
from urllib3.util import make_headers
1111
from urllib3.exceptions import MaxRetryError
1212

1313
from databricks.sql.auth.authenticators import AuthProvider
@@ -23,46 +23,20 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26-
class SSLContextAdapter(HTTPAdapter):
    """
    An HTTP adapter that uses a custom SSLContext to handle advanced SSL settings,
    including client certificate key passwords.

    The context is built once from the supplied SSLOptions and injected into
    every pool manager this adapter creates.
    """

    def __init__(self, ssl_options: SSLOptions, **kwargs):
        """
        Args:
            ssl_options: TLS settings; the fields read are tls_verify,
                tls_trusted_ca_file, tls_client_cert_file,
                tls_client_cert_key_file and tls_client_cert_key_password.
            **kwargs: Forwarded unchanged to HTTPAdapter.__init__.
        """
        # Build the context before calling the base constructor so that it is
        # already available when init_poolmanager() is invoked.
        self.ssl_context = self._create_ssl_context(ssl_options)
        super().__init__(**kwargs)

    def _create_ssl_context(self, ssl_options: SSLOptions) -> ssl.SSLContext:
        """
        Build a client-side SSLContext based on the provided SSLOptions.

        Returns:
            An SSLContext that verifies the server (certificate and hostname)
            unless ssl_options.tls_verify is falsy.
        """
        # Purpose.SERVER_AUTH is the purpose for *client* sockets that
        # authenticate a server: it enables CERT_REQUIRED and hostname checks
        # by default. Purpose.CLIENT_AUTH yields a server-side context with
        # verification off, which would silently ignore tls_verify=True.
        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        if not ssl_options.tls_verify:
            # check_hostname must be cleared before verify_mode can be CERT_NONE.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        elif ssl_options.tls_trusted_ca_file:
            context.load_verify_locations(cafile=ssl_options.tls_trusted_ca_file)
        if ssl_options.tls_client_cert_file:
            context.load_cert_chain(
                certfile=ssl_options.tls_client_cert_file,
                keyfile=ssl_options.tls_client_cert_key_file,
                password=ssl_options.tls_client_cert_key_password,
            )
        return context

    def init_poolmanager(self, *args, **kwargs):
        """Create the pool manager, forcing it to use this adapter's SSLContext."""
        kwargs["ssl_context"] = self.ssl_context
        return super().init_poolmanager(*args, **kwargs)
57-
58-
5926
class SeaHttpClient:
6027
"""
61-
HTTP client for Statement Execution API (SEA), using the requests library.
28+
HTTP client for Statement Execution API (SEA).
29+
30+
This client uses urllib3 for robust HTTP communication with retry policies
31+
and connection pooling, similar to the Thrift HTTP client but simplified.
6232
"""
6333

6434
retry_policy: Union[DatabricksRetryPolicy, int]
65-
_session: requests.Session
35+
_pool: Optional[Union[HTTPConnectionPool, HTTPSConnectionPool]]
36+
proxy_uri: Optional[str]
37+
realhost: Optional[str]
38+
realport: Optional[int]
39+
proxy_auth: Optional[Dict[str, str]]
6640

6741
def __init__(
6842
self,
@@ -74,16 +48,39 @@ def __init__(
7448
ssl_options: SSLOptions,
7549
**kwargs,
7650
):
51+
"""
52+
Initialize the SEA HTTP client.
53+
54+
Args:
55+
server_hostname: Hostname of the Databricks server
56+
port: Port number for the connection
57+
http_path: HTTP path for the connection
58+
http_headers: List of HTTP headers to include in requests
59+
auth_provider: Authentication provider
60+
ssl_options: SSL configuration options
61+
**kwargs: Additional keyword arguments including retry policy settings
62+
"""
63+
7764
self.server_hostname = server_hostname
7865
self.port = port or 443
66+
self.http_path = http_path
7967
self.auth_provider = auth_provider
8068
self.ssl_options = ssl_options
81-
self.scheme = "https"
82-
self.base_url = f"{self.scheme}://{server_hostname}:{self.port}"
83-
self._session = requests.Session()
69+
70+
# Build base URL
71+
self.base_url = f"https://{server_hostname}:{self.port}"
72+
73+
# Parse URL for proxy handling
74+
parsed_url = urllib.parse.urlparse(self.base_url)
75+
self.scheme = parsed_url.scheme
76+
self.host = parsed_url.hostname
77+
self.port = parsed_url.port or (443 if self.scheme == "https" else 80)
78+
79+
# Setup headers
8480
self.headers: Dict[str, str] = dict(http_headers)
8581
self.headers.update({"Content-Type": "application/json"})
86-
self._session.headers.update(self.headers)
82+
83+
# Extract retry policy settings
8784
self._retry_delay_min = kwargs.get("_retry_delay_min", 1.0)
8885
self._retry_delay_max = kwargs.get("_retry_delay_max", 60.0)
8986
self._retry_stop_after_attempts_count = kwargs.get(
@@ -94,36 +91,23 @@ def __init__(
9491
)
9592
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
9693
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
94+
95+
# Connection pooling settings
9796
self.max_connections = kwargs.get("max_connections", 10)
98-
self._configure_proxies()
99-
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
100-
self._configure_retries_and_ssl(**kwargs)
10197

102-
def _configure_proxies(self):
103-
try:
104-
proxy = urllib.request.getproxies().get(self.scheme)
105-
except (KeyError, AttributeError):
106-
proxy = None
107-
else:
108-
if self.server_hostname and urllib.request.proxy_bypass(
109-
self.server_hostname
110-
):
111-
proxy = None
112-
if proxy:
113-
self._session.proxies = {"http": proxy, "https": proxy}
98+
# Setup retry policy
99+
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
114100

115-
def _configure_retries_and_ssl(self, **kwargs):
116101
if self.enable_v3_retries:
117102
urllib3_kwargs = {"allowed_methods": ["GET", "POST", "DELETE"]}
118103
_max_redirects = kwargs.get("_retry_max_redirects")
119104
if _max_redirects:
120105
if _max_redirects > self._retry_stop_after_attempts_count:
121106
logger.warning(
122-
"_retry_max_redirects > _retry_stop_after_attempts_count "
123-
"so it will have no effect!"
107+
"_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!"
124108
)
125109
urllib3_kwargs["redirect"] = _max_redirects
126-
self._session.max_redirects = _max_redirects
110+
127111
self.retry_policy = DatabricksRetryPolicy(
128112
delay_min=self._retry_delay_min,
129113
delay_max=self._retry_delay_max,
@@ -133,34 +117,104 @@ def _configure_retries_and_ssl(self, **kwargs):
133117
force_dangerous_codes=self.force_dangerous_codes,
134118
urllib3_kwargs=urllib3_kwargs,
135119
)
136-
retry_strategy = self.retry_policy
137120
else:
121+
# Legacy behavior - no automatic retries
138122
logger.warning(
139123
"Legacy retry behavior is enabled for this connection."
140124
" This behaviour is not supported for the SEA backend."
141125
)
142126
self.retry_policy = 0
143-
retry_strategy = 0
144-
adapter = SSLContextAdapter(
145-
ssl_options=self.ssl_options,
146-
pool_connections=self.max_connections,
147-
max_retries=retry_strategy,
148-
)
149-
self._session.mount("https://", adapter)
150-
self._session.mount("http://", adapter)
127+
128+
# Handle proxy settings
129+
try:
130+
proxy = urllib.request.getproxies().get(self.scheme)
131+
except (KeyError, AttributeError):
132+
proxy = None
133+
else:
134+
if self.host and urllib.request.proxy_bypass(self.host):
135+
proxy = None
136+
137+
if proxy:
138+
parsed_proxy = urllib.parse.urlparse(proxy)
139+
self.realhost = self.host
140+
self.realport = self.port
141+
self.proxy_uri = proxy
142+
self.host = parsed_proxy.hostname
143+
self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80)
144+
self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy)
145+
else:
146+
self.realhost = None
147+
self.realport = None
148+
self.proxy_auth = None
149+
self.proxy_uri = None
150+
151+
# Initialize connection pool
152+
self._pool = None
153+
self._open()
154+
155+
def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]:
156+
"""Create basic auth headers for proxy if credentials are provided."""
157+
if proxy_parsed is None or not proxy_parsed.username:
158+
return None
159+
ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}"
160+
return make_headers(proxy_basic_auth=ap)
161+
162+
def _open(self):
    """
    Create the urllib3 connection pool and store it in ``self._pool``.

    The pool is sized by ``self.max_connections``. For any scheme other than
    plain "http", the TLS settings are taken from ``self.ssl_options``. When
    ``using_proxy()`` is true the pool is obtained from a ProxyManager built
    for ``self.proxy_uri`` and targets ``self.realhost``/``self.realport``;
    otherwise it connects directly to ``self.host``/``self.port``.
    """
    options = {"maxsize": self.max_connections}

    plain_http = self.scheme == "http"
    if not plain_http:
        tls = self.ssl_options
        # Certificate verification is required unless explicitly disabled.
        options["cert_reqs"] = ssl.CERT_REQUIRED if tls.tls_verify else ssl.CERT_NONE
        options["ca_certs"] = tls.tls_trusted_ca_file
        options["cert_file"] = tls.tls_client_cert_file
        options["key_file"] = tls.tls_client_cert_key_file
        options["key_password"] = tls.tls_client_cert_key_password

    if not self.using_proxy():
        # Direct connection: pick the pool type that matches the scheme.
        direct_pool_type = HTTPConnectionPool if plain_http else HTTPSConnectionPool
        self._pool = direct_pool_type(self.host, self.port, **options)
        return

    # Proxied connection: the manager hands out a pool for the real target.
    manager = ProxyManager(
        self.proxy_uri,
        num_pools=1,
        proxy_headers=self.proxy_auth,
    )
    self._pool = manager.connection_from_host(
        host=self.realhost,
        port=self.realport,
        scheme=self.scheme,
        pool_kwargs=options,
    )
151196

152197
def close(self):
    """
    Close the connection pool.

    Does nothing when no pool has been created. The pool held in
    ``self._pool`` is an HTTPConnectionPool/HTTPSConnectionPool, whose
    teardown method is ``close()``; ``clear()`` exists only on PoolManager
    and would raise AttributeError here.
    """
    if self._pool:
        self._pool.close()
201+
202+
def using_proxy(self) -> bool:
    """Report whether requests go through a proxy, i.e. ``self.realhost`` is set."""
    proxied_target = self.realhost
    return proxied_target is not None
154205

155206
def set_retry_command_type(self, command_type: CommandType):
    """
    Record ``command_type`` on the retry policy.

    Has no effect when ``self.retry_policy`` is not a DatabricksRetryPolicy
    (for example, when it is the integer 0).
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.command_type = command_type
158210

159211
def start_retry_timer(self):
    """
    Start the retry policy's timer.

    Has no effect when ``self.retry_policy`` is not a DatabricksRetryPolicy.
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.start_retry_timer()
162215

163216
def _get_auth_headers(self) -> Dict[str, str]:
217+
"""Get authentication headers from the auth provider."""
164218
headers: Dict[str, str] = {}
165219
self.auth_provider.add_headers(headers)
166220
return headers
@@ -171,51 +225,91 @@ def _make_request(
171225
path: str,
172226
data: Optional[Dict[str, Any]] = None,
173227
) -> Dict[str, Any]:
174-
full_url = urljoin(self.base_url, path)
175-
auth_headers = self._get_auth_headers()
228+
"""
229+
Make an HTTP request to the SEA endpoint.
230+
231+
Args:
232+
method: HTTP method (GET, POST, DELETE)
233+
path: API endpoint path
234+
data: Request payload data
235+
236+
Returns:
237+
Dict[str, Any]: Response data parsed from JSON
238+
239+
Raises:
240+
RequestError: If the request fails after retries
241+
"""
242+
243+
# Prepare headers
244+
headers = {**self.headers, **self._get_auth_headers()}
245+
246+
# Prepare request body
247+
body = json.dumps(data).encode("utf-8") if data else b""
248+
if body:
249+
headers["Content-Length"] = str(len(body))
250+
251+
# Set command type for retry policy
176252
command_type = self._get_command_type_from_path(path, method)
177253
self.set_retry_command_type(command_type)
178254
self.start_retry_timer()
179-
logger.debug(f"Making {method} request to {full_url}")
255+
256+
logger.debug(f"Making {method} request to {path}")
257+
258+
# When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy
259+
# When disabled, we let exceptions bubble up (similar to Thrift backend approach)
260+
if self._pool is None:
261+
raise RequestError("Connection pool not initialized", None)
262+
180263
try:
181-
with self._session.request(
264+
response = self._pool.request(
182265
method=method.upper(),
183-
url=full_url,
184-
json=data,
185-
headers=auth_headers,
186-
) as response:
187-
logger.debug(f"Response status: {response.status_code}")
188-
response.raise_for_status()
189-
return response.json()
190-
except requests.exceptions.ConnectionError as e:
191-
# Check if the first argument of the ConnectionError is a MaxRetryError
192-
if e.args and isinstance(e.args[0], MaxRetryError):
193-
# We want to raise the original MaxRetryError, not the wrapper
194-
original_error = e.args[0]
195-
logger.error(
196-
f"SEA HTTP request failed with MaxRetryError: {original_error}"
197-
)
198-
raise original_error
199-
else:
200-
logger.error(f"SEA HTTP request failed with ConnectionError: {e}")
201-
raise RequestError("Error during request to server.", None, None, e)
202-
except RequestException as e:
203-
error_message = f"Error during request to server: {e}"
266+
url=path,
267+
body=body,
268+
headers=headers,
269+
preload_content=False,
270+
retries=self.retry_policy,
271+
)
272+
except MaxRetryError as e:
273+
# urllib3 MaxRetryError should bubble up for redirect tests to catch
274+
logger.error(f"SEA HTTP request failed with MaxRetryError: {e}")
275+
raise
276+
except Exception as e:
277+
logger.error(f"SEA HTTP request failed with exception: {e}")
278+
error_message = f"Error during request to server. {e}"
279+
# Construct RequestError with proper 3-argument format (message, context, error)
204280
raise RequestError(error_message, None, None, e)
205281

282+
logger.debug(f"Response status: {response.status}")
283+
284+
# Handle successful responses
285+
if 200 <= response.status < 300:
286+
return response.json()
287+
288+
error_message = f"SEA HTTP request failed with status {response.status}"
289+
290+
raise RequestError(error_message, None)
291+
206292
def _get_command_type_from_path(self, path: str, method: str) -> CommandType:
    """
    Map an API path and HTTP method to a CommandType.

    Matching is case-insensitive on both arguments.

    Statement paths (containing "/statements"):
        POST to a path ending in "/statements" -> EXECUTE_STATEMENT
        any path containing "/cancel"          -> OTHER
        DELETE                                 -> CLOSE_OPERATION
        GET                                    -> GET_OPERATION_STATUS
    Session paths (containing "/sessions"):
        DELETE                                 -> CLOSE_SESSION
    Everything else                            -> OTHER
    """
    route = path.lower()
    verb = method.upper()

    if "/statements" in route:
        if verb == "POST" and route.endswith("/statements"):
            return CommandType.EXECUTE_STATEMENT
        if "/cancel" in route:
            # Cancel operation
            return CommandType.OTHER
        statement_commands = {
            "DELETE": CommandType.CLOSE_OPERATION,
            "GET": CommandType.GET_OPERATION_STATUS,
        }
        return statement_commands.get(verb, CommandType.OTHER)

    if "/sessions" in route and verb == "DELETE":
        return CommandType.CLOSE_SESSION

    return CommandType.OTHER

0 commit comments

Comments
 (0)