Skip to content

Commit ab3410f

Browse files
unify ssl proxy
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent d3df719 commit ab3410f

File tree

3 files changed

+81
-169
lines changed

3 files changed

+81
-169
lines changed

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from urllib3.util import make_headers
1616
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
1717
from databricks.sql.types import SSLOptions
18+
from databricks.sql.common.http_utils import (
19+
detect_and_parse_proxy,
20+
create_connection_pool,
21+
)
1822

1923
logger = logging.getLogger(__name__)
2024

@@ -58,25 +62,22 @@ def __init__(
5862
self.path = parsed.path
5963
if parsed.query:
6064
self.path += "?%s" % parsed.query
61-
try:
62-
proxy = urllib.request.getproxies()[self.scheme]
63-
except KeyError:
64-
proxy = None
65-
else:
66-
if urllib.request.proxy_bypass(self.host):
67-
proxy = None
68-
if proxy:
69-
parsed = urllib.parse.urlparse(proxy)
65+
66+
# Handle proxy settings using shared utility
67+
proxy_uri, proxy_auth = detect_and_parse_proxy(self.scheme, self.host)
68+
69+
if proxy_uri:
70+
parsed_proxy = urllib.parse.urlparse(proxy_uri)
7071

7172
# realhost and realport are the host and port of the actual request
7273
self.realhost = self.host
7374
self.realport = self.port
7475

7576
# this is passed to ProxyManager
76-
self.proxy_uri: str = proxy
77-
self.host = parsed.hostname
78-
self.port = parsed.port
79-
self.proxy_auth = self.basic_proxy_auth_headers(parsed)
77+
self.proxy_uri: str = proxy_uri
78+
self.host = parsed_proxy.hostname
79+
self.port = parsed_proxy.port
80+
self.proxy_auth = proxy_auth
8081
else:
8182
self.realhost = self.realport = self.proxy_auth = None
8283

@@ -105,40 +106,17 @@ def startRetryTimer(self):
105106
self.retry_policy and self.retry_policy.start_retry_timer()
106107

107108
def open(self):
108-
109-
# self.__pool replaces the self.__http used by the original THttpClient
110-
_pool_kwargs = {"maxsize": self.max_connections}
111-
112-
if self.scheme == "http":
113-
pool_class = HTTPConnectionPool
114-
elif self.scheme == "https":
115-
pool_class = HTTPSConnectionPool
116-
_pool_kwargs.update(
117-
{
118-
"cert_reqs": ssl.CERT_REQUIRED
119-
if self._ssl_options.tls_verify
120-
else ssl.CERT_NONE,
121-
"ca_certs": self._ssl_options.tls_trusted_ca_file,
122-
"cert_file": self._ssl_options.tls_client_cert_file,
123-
"key_file": self._ssl_options.tls_client_cert_key_file,
124-
"key_password": self._ssl_options.tls_client_cert_key_password,
125-
}
126-
)
127-
128-
if self.using_proxy():
129-
proxy_manager = ProxyManager(
130-
self.proxy_uri,
131-
num_pools=1,
132-
proxy_headers=self.proxy_auth,
133-
)
134-
self.__pool = proxy_manager.connection_from_host(
135-
host=self.realhost,
136-
port=self.realport,
137-
scheme=self.scheme,
138-
pool_kwargs=_pool_kwargs,
139-
)
140-
else:
141-
self.__pool = pool_class(self.host, self.port, **_pool_kwargs)
109+
"""Initialize the connection pool using shared utility."""
110+
self.__pool = create_connection_pool(
111+
scheme=self.scheme,
112+
host=self.realhost if self.using_proxy() else self.host,
113+
port=self.realport if self.using_proxy() else self.port,
114+
ssl_options=self._ssl_options,
115+
proxy_uri=getattr(self, 'proxy_uri', None),
116+
proxy_headers=self.proxy_auth,
117+
retry_policy=self.retry_policy,
118+
max_connections=self.max_connections,
119+
)
142120

143121
def close(self):
144122
self.__resp and self.__resp.drain_conn()
@@ -204,15 +182,9 @@ def flush(self):
204182
)
205183
)
206184

207-
@staticmethod
208-
def basic_proxy_auth_headers(proxy):
209-
if proxy is None or not proxy.username:
210-
return None
211-
ap = "%s:%s" % (
212-
urllib.parse.unquote(proxy.username),
213-
urllib.parse.unquote(proxy.password),
214-
)
215-
return make_headers(proxy_basic_auth=ap)
185+
def using_proxy(self) -> bool:
186+
"""Check if proxy is being used."""
187+
return self.realhost is not None
216188

217189
def set_retry_command_type(self, value: CommandType):
218190
"""Pass the provided CommandType to the retry policy"""

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 23 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
from databricks.sql.exc import (
1616
RequestError,
1717
)
18+
from databricks.sql.common.http_utils import (
19+
detect_and_parse_proxy,
20+
create_retry_policy_from_kwargs,
21+
create_connection_pool,
22+
)
1823

1924
logger = logging.getLogger(__name__)
2025

@@ -121,27 +126,17 @@ def __init__(
121126
)
122127
self.retry_policy = 0
123128

124-
# Handle proxy settings
125-
try:
126-
# returns a dictionary of scheme -> proxy server URL mappings.
127-
# https://docs.python.org/3/library/urllib.request.html#urllib.request.getproxies
128-
proxy = urllib.request.getproxies().get(self.scheme)
129-
except (KeyError, AttributeError):
130-
# No proxy found or getproxies() failed - disable proxy
131-
proxy = None
132-
else:
133-
# Proxy found, but check if this host should bypass proxy
134-
if self.host and urllib.request.proxy_bypass(self.host):
135-
proxy = None # Host bypasses proxy per system rules
136-
137-
if proxy:
138-
parsed_proxy = urllib.parse.urlparse(proxy)
129+
# Handle proxy settings using shared utility
130+
proxy_uri, proxy_auth = detect_and_parse_proxy(self.scheme, self.host)
131+
132+
if proxy_uri:
133+
parsed_proxy = urllib.parse.urlparse(proxy_uri)
139134
self.proxy_host = self.host
140135
self.proxy_port = self.port
141-
self.proxy_uri = proxy
136+
self.proxy_uri = proxy_uri
142137
self.host = parsed_proxy.hostname
143138
self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80)
144-
self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy)
139+
self.proxy_auth = proxy_auth
145140
else:
146141
self.proxy_host = None
147142
self.proxy_port = None
@@ -152,47 +147,18 @@ def __init__(
152147
self._pool = None
153148
self._open()
154149

155-
def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]:
156-
"""Create basic auth headers for proxy if credentials are provided."""
157-
if proxy_parsed is None or not proxy_parsed.username:
158-
return None
159-
ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}"
160-
return make_headers(proxy_basic_auth=ap)
161-
162150
def _open(self):
163-
"""Initialize the connection pool."""
164-
pool_kwargs = {"maxsize": self.max_connections}
165-
166-
if self.scheme == "http":
167-
pool_class = HTTPConnectionPool
168-
else: # https
169-
pool_class = HTTPSConnectionPool
170-
pool_kwargs.update(
171-
{
172-
"cert_reqs": ssl.CERT_REQUIRED
173-
if self.ssl_options.tls_verify
174-
else ssl.CERT_NONE,
175-
"ca_certs": self.ssl_options.tls_trusted_ca_file,
176-
"cert_file": self.ssl_options.tls_client_cert_file,
177-
"key_file": self.ssl_options.tls_client_cert_key_file,
178-
"key_password": self.ssl_options.tls_client_cert_key_password,
179-
}
180-
)
181-
182-
if self.using_proxy():
183-
proxy_manager = ProxyManager(
184-
self.proxy_uri,
185-
num_pools=1,
186-
proxy_headers=self.proxy_auth,
187-
)
188-
self._pool = proxy_manager.connection_from_host(
189-
host=self.proxy_host,
190-
port=self.proxy_port,
191-
scheme=self.scheme,
192-
pool_kwargs=pool_kwargs,
193-
)
194-
else:
195-
self._pool = pool_class(self.host, self.port, **pool_kwargs)
151+
"""Initialize the connection pool using shared utility."""
152+
self._pool = create_connection_pool(
153+
scheme=self.scheme,
154+
host=self.proxy_host if self.using_proxy() else self.host,
155+
port=self.proxy_port if self.using_proxy() else self.port,
156+
ssl_options=self.ssl_options,
157+
proxy_uri=self.proxy_uri,
158+
proxy_headers=self.proxy_auth,
159+
retry_policy=self.retry_policy,
160+
max_connections=self.max_connections,
161+
)
196162

197163
def close(self):
198164
"""Close the connection pool."""

src/databricks/sql/common/unified_http_client.py

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
1313
from databricks.sql.exc import RequestError
1414
from databricks.sql.common.http import HttpMethod
15+
from databricks.sql.common.http_utils import (
16+
detect_and_parse_proxy,
17+
create_retry_policy_from_kwargs,
18+
create_connection_pool,
19+
)
1520

1621
logger = logging.getLogger(__name__)
1722

@@ -39,37 +44,6 @@ def __init__(self, client_context):
3944

4045
def _setup_pool_manager(self):
4146
"""Set up the urllib3 PoolManager with configuration from ClientContext."""
42-
43-
# SSL context setup
44-
ssl_context = None
45-
if self.config.ssl_options:
46-
ssl_context = ssl.create_default_context()
47-
48-
# Configure SSL verification
49-
if not self.config.ssl_options.tls_verify:
50-
ssl_context.check_hostname = False
51-
ssl_context.verify_mode = ssl.CERT_NONE
52-
elif not self.config.ssl_options.tls_verify_hostname:
53-
ssl_context.check_hostname = False
54-
ssl_context.verify_mode = ssl.CERT_REQUIRED
55-
56-
# Load custom CA file if specified
57-
if self.config.ssl_options.tls_trusted_ca_file:
58-
ssl_context.load_verify_locations(
59-
self.config.ssl_options.tls_trusted_ca_file
60-
)
61-
62-
# Load client certificate if specified
63-
if (
64-
self.config.ssl_options.tls_client_cert_file
65-
and self.config.ssl_options.tls_client_cert_key_file
66-
):
67-
ssl_context.load_cert_chain(
68-
self.config.ssl_options.tls_client_cert_file,
69-
self.config.ssl_options.tls_client_cert_key_file,
70-
self.config.ssl_options.tls_client_cert_key_password,
71-
)
72-
7347
# Create retry policy
7448
self._retry_policy = DatabricksRetryPolicy(
7549
delay_min=self.config.retry_delay_min,
@@ -85,32 +59,32 @@ def _setup_pool_manager(self):
8559
self._retry_policy._command_type = None
8660
self._retry_policy._retry_start_time = None
8761

88-
# Common pool manager kwargs
89-
pool_kwargs = {
90-
"num_pools": self.config.pool_connections,
91-
"maxsize": self.config.pool_maxsize,
92-
"retries": self._retry_policy,
93-
"timeout": urllib3.Timeout(
94-
connect=self.config.socket_timeout, read=self.config.socket_timeout
62+
63+
parsed_url = urllib.parse.urlparse(self.config.hostname)
64+
self.scheme = parsed_url.scheme
65+
# Detect proxy using shared utility
66+
proxy_uri, proxy_headers = detect_and_parse_proxy(self.scheme, self.config.hostname)
67+
68+
69+
# Create pool
70+
additional_kwargs = {}
71+
if self.config.socket_timeout:
72+
additional_kwargs["timeout"] = urllib3.Timeout(
73+
connect=self.config.socket_timeout,
74+
read=self.config.socket_timeout
9575
)
96-
if self.config.socket_timeout
97-
else None,
98-
"ssl_context": ssl_context,
99-
}
100-
101-
# Create proxy or regular pool manager
102-
if self.config.http_proxy:
103-
proxy_headers = None
104-
if self.config.proxy_username and self.config.proxy_password:
105-
proxy_headers = make_headers(
106-
proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}"
107-
)
108-
109-
self._pool_manager = ProxyManager(
110-
self.config.http_proxy, proxy_headers=proxy_headers, **pool_kwargs
111-
)
112-
else:
113-
self._pool_manager = PoolManager(**pool_kwargs)
76+
77+
self._pool_manager = create_connection_pool(
78+
scheme=self.scheme,
79+
host=self.config.hostname,
80+
port=443,
81+
ssl_options=self.config.ssl_options,
82+
proxy_uri=proxy_uri,
83+
proxy_headers=proxy_headers,
84+
retry_policy=self._retry_policy,
85+
max_connections=self.config.pool_maxsize,
86+
**additional_kwargs
87+
)
11488

11589
def _prepare_headers(
11690
self, headers: Optional[Dict[str, str]] = None

0 commit comments

Comments
 (0)