Skip to content

Commit 11bc165

Browse files
preliminary (robust) SEA HTTP Client
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent b99d0c4 commit 11bc165

File tree

1 file changed

+237
-98
lines changed

1 file changed

+237
-98
lines changed
Lines changed: 237 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import json
22
import logging
3-
import requests
4-
from typing import Callable, Dict, Any, Optional, List, Tuple
3+
import ssl
4+
import urllib.parse
5+
import urllib.request
6+
from typing import Dict, Any, Optional, List, Tuple, Union
57
from urllib.parse import urljoin
68

9+
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
10+
from urllib3.exceptions import HTTPError, MaxRetryError
11+
from urllib3.util import make_headers
12+
713
from databricks.sql.auth.authenticators import AuthProvider
14+
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
815
from databricks.sql.types import SSLOptions
16+
from databricks.sql.exc import RequestError
917

1018
logger = logging.getLogger(__name__)
1119

@@ -14,8 +22,8 @@ class SeaHttpClient:
1422
"""
1523
HTTP client for Statement Execution API (SEA).
1624
17-
This client handles the HTTP communication with the SEA endpoints,
18-
including authentication, request formatting, and response parsing.
25+
This client uses urllib3 for robust HTTP communication with retry policies
26+
and connection pooling, similar to the Thrift HTTP client but simplified.
1927
"""
2028

2129
def __init__(
@@ -38,66 +46,143 @@ def __init__(
3846
http_headers: List of HTTP headers to include in requests
3947
auth_provider: Authentication provider
4048
ssl_options: SSL configuration options
41-
**kwargs: Additional keyword arguments
49+
**kwargs: Additional keyword arguments including retry policy settings
4250
"""
4351

4452
self.server_hostname = server_hostname
45-
self.port = port
53+
self.port = port or 443
4654
self.http_path = http_path
4755
self.auth_provider = auth_provider
4856
self.ssl_options = ssl_options
4957

50-
self.base_url = f"https://{server_hostname}:{port}"
58+
# Build base URL
59+
self.base_url = f"https://{server_hostname}:{self.port}"
60+
61+
# Parse URL for proxy handling
62+
parsed_url = urllib.parse.urlparse(self.base_url)
63+
self.scheme = parsed_url.scheme
64+
self.host = parsed_url.hostname
65+
self.port = parsed_url.port or (443 if self.scheme == "https" else 80)
5166

67+
# Setup headers
5268
self.headers: Dict[str, str] = dict(http_headers)
5369
self.headers.update({"Content-Type": "application/json"})
5470

55-
self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30)
56-
57-
# Create a session for connection pooling
58-
self.session = requests.Session()
71+
# Extract retry policy settings
72+
self._retry_delay_min = kwargs.get("_retry_delay_min", 1.0)
73+
self._retry_delay_max = kwargs.get("_retry_delay_max", 60.0)
74+
self._retry_stop_after_attempts_count = kwargs.get(
75+
"_retry_stop_after_attempts_count", 30
76+
)
77+
self._retry_stop_after_attempts_duration = kwargs.get(
78+
"_retry_stop_after_attempts_duration", 900.0
79+
)
80+
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
81+
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
82+
83+
# Connection pooling settings
84+
self.max_connections = kwargs.get("max_connections", 10)
85+
86+
# Setup retry policy
87+
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
88+
89+
if self.enable_v3_retries:
90+
self.retry_policy = DatabricksRetryPolicy(
91+
delay_min=self._retry_delay_min,
92+
delay_max=self._retry_delay_max,
93+
stop_after_attempts_count=self._retry_stop_after_attempts_count,
94+
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
95+
delay_default=self._retry_delay_default,
96+
force_dangerous_codes=self.force_dangerous_codes,
97+
urllib3_kwargs={"allowed_methods": ["GET", "POST", "DELETE"]},
98+
)
99+
else:
100+
# Legacy behavior - no automatic retries
101+
self.retry_policy = 0
59102

60-
# Configure SSL verification
61-
if ssl_options.tls_verify:
62-
self.session.verify = ssl_options.tls_trusted_ca_file or True
103+
# Handle proxy settings
104+
try:
105+
proxy = urllib.request.getproxies().get(self.scheme)
106+
except (KeyError, AttributeError):
107+
proxy = None
108+
else:
109+
if urllib.request.proxy_bypass(self.host):
110+
proxy = None
111+
112+
if proxy:
113+
parsed_proxy = urllib.parse.urlparse(proxy)
114+
self.realhost = self.host
115+
self.realport = self.port
116+
self.proxy_uri = proxy
117+
self.host = parsed_proxy.hostname
118+
self.port = parsed_proxy.port
119+
self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy)
63120
else:
64-
self.session.verify = False
65-
66-
# Configure client certificates if provided
67-
if ssl_options.tls_client_cert_file:
68-
client_cert = ssl_options.tls_client_cert_file
69-
client_key = ssl_options.tls_client_cert_key_file
70-
client_key_password = ssl_options.tls_client_cert_key_password
71-
72-
if client_key:
73-
self.session.cert = (client_cert, client_key)
74-
else:
75-
self.session.cert = client_cert
76-
77-
if client_key_password:
78-
# Note: requests doesn't directly support key passwords
79-
# This would require more complex handling with libraries like pyOpenSSL
80-
logger.warning(
81-
"Client key password provided but not supported by requests library"
82-
)
121+
self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None
122+
123+
# Initialize connection pool
124+
self._pool = None
125+
self._open()
126+
127+
def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]:
128+
"""Create basic auth headers for proxy if credentials are provided."""
129+
if proxy_parsed is None or not proxy_parsed.username:
130+
return None
131+
ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}"
132+
return make_headers(proxy_basic_auth=ap)
133+
134+
def _open(self):
    """
    Build the urllib3 connection pool and store it on ``self._pool``.

    Plain-HTTP pools only receive the ``maxsize`` setting; HTTPS pools also
    receive the certificate settings read from ``self.ssl_options``. When
    ``self.proxy_uri`` is set, the pool is obtained from a ``ProxyManager``
    and targets ``self.realhost``/``self.realport``; otherwise it is created
    directly for ``self.host``/``self.port``.
    """
    uses_plain_http = self.scheme == "http"

    pool_kwargs = {"maxsize": self.max_connections}
    if not uses_plain_http:
        tls = self.ssl_options
        # Certificate verification is only enforced when tls_verify is truthy.
        pool_kwargs["cert_reqs"] = (
            ssl.CERT_REQUIRED if tls.tls_verify else ssl.CERT_NONE
        )
        pool_kwargs["ca_certs"] = tls.tls_trusted_ca_file
        pool_kwargs["cert_file"] = tls.tls_client_cert_file
        pool_kwargs["key_file"] = tls.tls_client_cert_key_file
        pool_kwargs["key_password"] = tls.tls_client_cert_key_password

    if not self.proxy_uri:
        # Direct connection: pick the pool class that matches the scheme.
        pool_class = HTTPConnectionPool if uses_plain_http else HTTPSConnectionPool
        self._pool = pool_class(self.host, self.port, **pool_kwargs)
        return

    # Proxied connection: the manager hands out a pool for the real target.
    proxy_manager = ProxyManager(
        self.proxy_uri,
        num_pools=1,
        proxy_headers=self.proxy_auth,
    )
    self._pool = proxy_manager.connection_from_host(
        host=self.realhost,
        port=self.realport,
        scheme=self.scheme,
        pool_kwargs=pool_kwargs,
    )
164+
165+
def close(self):
    """
    Close the connection pool, if one has been created.

    ``_open`` always stores a urllib3 connection pool (created directly or
    obtained from ``ProxyManager.connection_from_host``). Connection pools
    are shut down with ``close()``; ``clear()`` is a ``PoolManager`` method
    and is not available on a connection pool.
    """
    if self._pool:
        self._pool.close()
169+
170+
def set_retry_command_type(self, command_type: CommandType):
    """
    Record the command type on the retry policy.

    Does nothing when ``self.retry_policy`` is not a ``DatabricksRetryPolicy``
    (for example the integer ``0`` used when v3 retries are disabled).

    Args:
        command_type: The command type to store on the policy.
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.command_type = command_type
174+
175+
def start_retry_timer(self):
    """
    Ask the retry policy to start its retry timer.

    Does nothing when ``self.retry_policy`` is not a ``DatabricksRetryPolicy``
    (for example the integer ``0`` used when v3 retries are disabled).
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.start_retry_timer()
83179

84180
def _get_auth_headers(self) -> Dict[str, str]:
85181
"""Get authentication headers from the auth provider."""
86182
headers: Dict[str, str] = {}
87183
self.auth_provider.add_headers(headers)
88184
return headers
89185

90-
def _get_call(self, method: str) -> Callable:
91-
"""Get the appropriate HTTP method function."""
92-
method = method.upper()
93-
if method == "GET":
94-
return self.session.get
95-
if method == "POST":
96-
return self.session.post
97-
if method == "DELETE":
98-
return self.session.delete
99-
raise ValueError(f"Unsupported HTTP method: {method}")
100-
101186
def _make_request(
102187
self,
103188
method: str,
@@ -118,69 +203,123 @@ def _make_request(
118203
Dict[str, Any]: Response data parsed from JSON
119204
120205
Raises:
121-
RequestError: If the request fails
206+
RequestError: If the request fails after retries
122207
"""
123208

124-
url = urljoin(self.base_url, path)
125-
headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()}
209+
# Build full URL
210+
if path.startswith("/"):
211+
url = path
212+
else:
213+
url = f"/{path.lstrip('/')}"
214+
215+
# Prepare headers
216+
headers = {**self.headers, **self._get_auth_headers()}
126217

127-
logger.debug(f"making {method} request to {url}")
218+
# Prepare request body
219+
body = json.dumps(data).encode("utf-8") if data else b""
220+
if body:
221+
headers["Content-Length"] = str(len(body))
222+
223+
# Set command type for retry policy
224+
command_type = self._get_command_type_from_path(path, method)
225+
self.set_retry_command_type(command_type)
226+
self.start_retry_timer()
227+
228+
logger.debug(f"Making {method} request to {url}")
128229

129230
try:
130-
call = self._get_call(method)
131-
response = call(
231+
response = self._pool.request(
232+
method=method.upper(),
132233
url=url,
234+
body=body,
133235
headers=headers,
134-
json=data,
135-
params=params,
236+
preload_content=True,
237+
retries=self.retry_policy,
136238
)
137239

138-
# Check for HTTP errors
139-
response.raise_for_status()
140-
141-
# Log response details
142-
logger.debug(f"Response status: {response.status_code}")
143-
144-
# Parse JSON response
145-
if response.content:
146-
result = response.json()
147-
# Log response content (but limit it for large responses)
148-
content_str = json.dumps(result)
149-
if len(content_str) > 1000:
150-
logger.debug(
151-
f"Response content (truncated): {content_str[:1000]}..."
152-
)
240+
logger.debug(f"Response status: {response.status}")
241+
242+
# Handle successful responses
243+
if 200 <= response.status < 300:
244+
if response.data:
245+
try:
246+
result = json.loads(response.data.decode("utf-8"))
247+
logger.debug("Successfully parsed JSON response")
248+
return result
249+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
250+
logger.error(f"Failed to parse JSON response: {e}")
251+
raise RequestError(f"Invalid JSON response: {e}", e)
252+
return {}
253+
254+
# Handle error responses
255+
error_message = f"SEA HTTP request failed with status {response.status}"
256+
257+
try:
258+
if response.data:
259+
error_details = json.loads(response.data.decode("utf-8"))
260+
if isinstance(error_details, dict) and "message" in error_details:
261+
error_message = f"{error_message}: {error_details['message']}"
262+
logger.error(f"Request failed: {error_details}")
263+
except (json.JSONDecodeError, UnicodeDecodeError):
264+
# Log raw response if we can't parse JSON
265+
content = response.data.decode("utf-8", errors="replace") if response.data else ""
266+
logger.error(f"Request failed with non-JSON response: {content}")
267+
268+
raise RequestError(error_message, None)
269+
270+
except MaxRetryError as e:
271+
# Extract the most recent error from the retry history
272+
error_message = f"SEA request failed after retries: {str(e)}"
273+
274+
if hasattr(e, "reason") and e.reason:
275+
if hasattr(e.reason, "response"):
276+
# Extract status code and body from the final failed response
277+
response = e.reason.response
278+
error_message = f"SEA request failed after retries (status {response.status})"
279+
try:
280+
if response.data:
281+
error_details = json.loads(response.data.decode("utf-8"))
282+
if isinstance(error_details, dict) and "message" in error_details:
283+
error_message = f"{error_message}: {error_details['message']}"
284+
except (json.JSONDecodeError, UnicodeDecodeError):
285+
pass
153286
else:
154-
logger.debug(f"Response content: {content_str}")
155-
return result
156-
return {}
157-
158-
except requests.exceptions.RequestException as e:
159-
# Handle request errors and extract details from response if available
160-
error_message = f"SEA HTTP request failed: {str(e)}"
161-
162-
if hasattr(e, "response") and e.response is not None:
163-
status_code = e.response.status_code
164-
try:
165-
error_details = e.response.json()
166-
error_message = (
167-
f"{error_message}: {error_details.get('message', '')}"
168-
)
169-
logger.error(
170-
f"Request failed (status {status_code}): {error_details}"
171-
)
172-
except (ValueError, KeyError):
173-
# If we can't parse JSON, log raw content
174-
content = (
175-
e.response.content.decode("utf-8", errors="replace")
176-
if isinstance(e.response.content, bytes)
177-
else str(e.response.content)
178-
)
179-
logger.error(f"Request failed (status {status_code}): {content}")
180-
else:
181-
logger.error(error_message)
182-
183-
# Re-raise as a RequestError
184-
from databricks.sql.exc import RequestError
287+
error_message = f"SEA request failed after retries: {str(e.reason)}"
288+
289+
logger.error(error_message)
290+
raise RequestError(error_message, e)
185291

292+
except HTTPError as e:
293+
error_message = f"SEA HTTP error: {str(e)}"
294+
logger.error(error_message)
186295
raise RequestError(error_message, e)
296+
297+
except Exception as e:
298+
error_message = f"Unexpected error in SEA request: {str(e)}"
299+
logger.error(error_message)
300+
raise RequestError(error_message, e)
301+
302+
def _get_command_type_from_path(self, path: str, method: str) -> CommandType:
    """
    Classify a SEA request as a CommandType from its path and HTTP method.

    The path is matched case-insensitively and the method is upper-cased
    before comparison.

    Args:
        path: API path of the request.
        method: HTTP method of the request.

    Returns:
        EXECUTE_STATEMENT for a POST whose path ends with "/statements";
        OTHER for any other statements path containing "/cancel";
        CLOSE_OPERATION / GET_OPERATION_STATUS for DELETE / GET on a
        statements path; CLOSE_SESSION for DELETE on a sessions path;
        OTHER in every remaining case.
    """
    normalized_path = path.lower()
    verb = method.upper()

    if "/statements" not in normalized_path:
        # Only session deletion gets a dedicated type outside statements.
        closes_session = "/sessions" in normalized_path and verb == "DELETE"
        return CommandType.CLOSE_SESSION if closes_session else CommandType.OTHER

    if verb == "POST" and normalized_path.endswith("/statements"):
        return CommandType.EXECUTE_STATEMENT
    if "/cancel" in normalized_path:
        return CommandType.OTHER  # Cancel operation
    if verb == "DELETE":
        return CommandType.CLOSE_OPERATION
    if verb == "GET":
        return CommandType.GET_OPERATION_STATUS
    return CommandType.OTHER

0 commit comments

Comments
 (0)