Skip to content

Commit 65e89ce

Browse files
simplify change
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 5895b57 commit 65e89ce

File tree

4 files changed: +135 additions, -54 deletions (lines changed)

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 37 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -62,10 +62,10 @@ def __init__(
6262
self.path = parsed.path
6363
if parsed.query:
6464
self.path += "?%s" % parsed.query
65-
65+
6666
# Handle proxy settings using shared utility
6767
proxy_uri, proxy_auth = detect_and_parse_proxy(self.scheme, self.host)
68-
68+
6969
if proxy_uri:
7070
parsed_proxy = urllib.parse.urlparse(proxy_uri)
7171
# realhost and realport are the host and port of the actual request
@@ -77,7 +77,7 @@ def __init__(
7777
self.port = parsed_proxy.port
7878
self.proxy_auth = proxy_auth
7979
else:
80-
self.realhost = self.realport = self.proxy_auth = None
80+
self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None
8181

8282
self.max_connections = max_connections
8383

@@ -104,17 +104,40 @@ def startRetryTimer(self):
104104
self.retry_policy and self.retry_policy.start_retry_timer()
105105

106106
def open(self):
107-
"""Initialize the connection pool using shared utility."""
108-
self.__pool = create_connection_pool(
109-
scheme=self.scheme,
110-
host=self.realhost if self.using_proxy() else self.host,
111-
port=self.realport if self.using_proxy() else self.port,
112-
ssl_options=self._ssl_options,
113-
proxy_uri=self.proxy_uri,
114-
proxy_headers=self.proxy_auth,
115-
retry_policy=self.retry_policy,
116-
max_connections=self.max_connections,
117-
)
107+
108+
# self.__pool replaces the self.__http used by the original THttpClient
109+
_pool_kwargs = {"maxsize": self.max_connections}
110+
111+
if self.scheme == "http":
112+
pool_class = HTTPConnectionPool
113+
elif self.scheme == "https":
114+
pool_class = HTTPSConnectionPool
115+
_pool_kwargs.update(
116+
{
117+
"cert_reqs": ssl.CERT_REQUIRED
118+
if self._ssl_options.tls_verify
119+
else ssl.CERT_NONE,
120+
"ca_certs": self._ssl_options.tls_trusted_ca_file,
121+
"cert_file": self._ssl_options.tls_client_cert_file,
122+
"key_file": self._ssl_options.tls_client_cert_key_file,
123+
"key_password": self._ssl_options.tls_client_cert_key_password,
124+
}
125+
)
126+
127+
if self.using_proxy():
128+
proxy_manager = ProxyManager(
129+
self.proxy_uri,
130+
num_pools=1,
131+
proxy_headers=self.proxy_auth,
132+
)
133+
self.__pool = proxy_manager.connection_from_host(
134+
host=self.realhost,
135+
port=self.realport,
136+
scheme=self.scheme,
137+
pool_kwargs=_pool_kwargs,
138+
)
139+
else:
140+
self.__pool = pool_class(self.host, self.port, **_pool_kwargs)
118141

119142
def close(self):
120143
self.__resp and self.__resp.drain_conn()

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
)
1818
from databricks.sql.common.http_utils import (
1919
detect_and_parse_proxy,
20-
create_retry_policy_from_kwargs,
2120
create_connection_pool,
2221
)
2322

@@ -128,7 +127,7 @@ def __init__(
128127

129128
# Handle proxy settings using shared utility
130129
proxy_uri, proxy_auth = detect_and_parse_proxy(self.scheme, self.host)
131-
130+
132131
if proxy_uri:
133132
parsed_proxy = urllib.parse.urlparse(proxy_uri)
134133
self.realhost = self.host
@@ -138,24 +137,46 @@ def __init__(
138137
self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80)
139138
self.proxy_auth = proxy_auth
140139
else:
141-
self.realhost = self.realport = self.proxy_auth = None
140+
self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None
142141

143142
# Initialize connection pool
144143
self._pool = None
145144
self._open()
146145

147146
def _open(self):
148-
"""Initialize the connection pool using shared utility."""
149-
self._pool = create_connection_pool(
150-
scheme=self.scheme,
151-
host=self.realhost if self.using_proxy() else self.host,
152-
port=self.realport if self.using_proxy() else self.port,
153-
ssl_options=self.ssl_options,
154-
proxy_uri=self.proxy_uri,
155-
proxy_headers=self.proxy_auth,
156-
retry_policy=self.retry_policy,
157-
max_connections=self.max_connections,
158-
)
147+
"""Initialize the connection pool."""
148+
pool_kwargs = {"maxsize": self.max_connections}
149+
150+
if self.scheme == "http":
151+
pool_class = HTTPConnectionPool
152+
else: # https
153+
pool_class = HTTPSConnectionPool
154+
pool_kwargs.update(
155+
{
156+
"cert_reqs": ssl.CERT_REQUIRED
157+
if self.ssl_options.tls_verify
158+
else ssl.CERT_NONE,
159+
"ca_certs": self.ssl_options.tls_trusted_ca_file,
160+
"cert_file": self.ssl_options.tls_client_cert_file,
161+
"key_file": self.ssl_options.tls_client_cert_key_file,
162+
"key_password": self.ssl_options.tls_client_cert_key_password,
163+
}
164+
)
165+
166+
if self.using_proxy():
167+
proxy_manager = ProxyManager(
168+
self.proxy_uri,
169+
num_pools=1,
170+
proxy_headers=self.proxy_auth,
171+
)
172+
self._pool = proxy_manager.connection_from_host(
173+
host=self.proxy_host,
174+
port=self.proxy_port,
175+
scheme=self.scheme,
176+
pool_kwargs=pool_kwargs,
177+
)
178+
else:
179+
self._pool = pool_class(self.host, self.port, **pool_kwargs)
159180

160181
def close(self):
161182
"""Close the connection pool."""
@@ -164,7 +185,7 @@ def close(self):
164185

165186
def using_proxy(self) -> bool:
166187
"""Check if proxy is being used."""
167-
return self.proxy_host is not None
188+
return self.realhost is not None
168189

169190
def set_retry_command_type(self, command_type: CommandType):
170191
"""Set the command type for retry policy decision making."""

src/databricks/sql/common/unified_http_client.py

Lines changed: 60 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,37 @@ def __init__(self, client_context):
4343

4444
def _setup_pool_manager(self):
4545
"""Set up the urllib3 PoolManager with configuration from ClientContext."""
46+
47+
# SSL context setup
48+
ssl_context = None
49+
if self.config.ssl_options:
50+
ssl_context = ssl.create_default_context()
51+
52+
# Configure SSL verification
53+
if not self.config.ssl_options.tls_verify:
54+
ssl_context.check_hostname = False
55+
ssl_context.verify_mode = ssl.CERT_NONE
56+
elif not self.config.ssl_options.tls_verify_hostname:
57+
ssl_context.check_hostname = False
58+
ssl_context.verify_mode = ssl.CERT_REQUIRED
59+
60+
# Load custom CA file if specified
61+
if self.config.ssl_options.tls_trusted_ca_file:
62+
ssl_context.load_verify_locations(
63+
self.config.ssl_options.tls_trusted_ca_file
64+
)
65+
66+
# Load client certificate if specified
67+
if (
68+
self.config.ssl_options.tls_client_cert_file
69+
and self.config.ssl_options.tls_client_cert_key_file
70+
):
71+
ssl_context.load_cert_chain(
72+
self.config.ssl_options.tls_client_cert_file,
73+
self.config.ssl_options.tls_client_cert_key_file,
74+
self.config.ssl_options.tls_client_cert_key_password,
75+
)
76+
4677
# Create retry policy
4778
self._retry_policy = DatabricksRetryPolicy(
4879
delay_min=self.config.retry_delay_min,
@@ -58,15 +89,29 @@ def _setup_pool_manager(self):
5889
self._retry_policy._command_type = None
5990
self._retry_policy._retry_start_time = None
6091

61-
6292
parsed_url = urllib.parse.urlparse(self.config.hostname)
6393
self.scheme = parsed_url.scheme
6494
self.host = parsed_url.hostname
6595
self.port = parsed_url.port
6696

97+
# Common pool manager kwargs
98+
pool_kwargs = {
99+
"num_pools": self.config.pool_connections,
100+
"maxsize": self.config.pool_maxsize,
101+
"retries": self._retry_policy,
102+
"timeout": urllib3.Timeout(
103+
connect=self.config.socket_timeout, read=self.config.socket_timeout
104+
)
105+
if self.config.socket_timeout
106+
else None,
107+
"ssl_context": ssl_context,
108+
}
109+
67110
# Detect proxy using shared utility
68-
proxy_uri, proxy_auth = detect_and_parse_proxy(self.scheme, self.config.hostname)
69-
111+
proxy_uri, proxy_auth = detect_and_parse_proxy(
112+
self.scheme, self.config.hostname
113+
)
114+
70115
if proxy_uri:
71116
parsed_proxy = urllib.parse.urlparse(proxy_uri)
72117
# realhost and realport are the host and port of the actual request
@@ -78,27 +123,15 @@ def _setup_pool_manager(self):
78123
self.port = parsed_proxy.port
79124
self.proxy_auth = proxy_auth
80125
else:
81-
self.realhost = self.realport = self.proxy_auth = None
82-
83-
# Create pool
84-
additional_kwargs = {}
85-
if self.config.socket_timeout:
86-
additional_kwargs["timeout"] = urllib3.Timeout(
87-
connect=self.config.socket_timeout,
88-
read=self.config.socket_timeout
126+
self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None
127+
128+
# Create proxy or regular pool manager
129+
if self.using_proxy():
130+
self._pool_manager = ProxyManager(
131+
self.config.http_proxy, proxy_headers=proxy_auth, **pool_kwargs
89132
)
90-
91-
self._pool_manager = create_connection_pool(
92-
scheme=self.scheme,
93-
host=self.realhost if self.using_proxy() else self.host,
94-
port=self.realport if self.using_proxy() else self.port,
95-
ssl_options=self.config.ssl_options,
96-
proxy_uri=proxy_uri,
97-
proxy_headers=proxy_auth,
98-
retry_policy=self._retry_policy,
99-
max_connections=self.config.pool_maxsize,
100-
**additional_kwargs
101-
)
133+
else:
134+
self._pool_manager = PoolManager(**pool_kwargs)
102135

103136
def _prepare_headers(
104137
self, headers: Optional[Dict[str, str]] = None
@@ -193,6 +226,10 @@ def request(
193226
response._body = response.data
194227
return response
195228

229+
def using_proxy(self) -> bool:
230+
"""Check if proxy is being used."""
231+
return self.realhost is not None
232+
196233
def close(self):
197234
"""Close the underlying connection pools."""
198235
if self._pool_manager:

tests/unit/test_thrift_backend.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,14 +206,14 @@ def test_headers_are_set(self, t_http_client_class):
206206

207207
def test_proxy_headers_are_set(self):
208208

209-
from databricks.sql.auth.thrift_http_client import THttpClient
209+
from databricks.sql.common.http_utils import create_basic_proxy_auth_headers
210210
from urllib.parse import urlparse
211211

212212
fake_proxy_spec = "https://someuser:somepassword@8.8.8.8:12340"
213213
parsed_proxy = urlparse(fake_proxy_spec)
214214

215215
try:
216-
result = THttpClient.basic_proxy_auth_headers(parsed_proxy)
216+
result = create_basic_proxy_auth_headers(parsed_proxy)
217217
except TypeError as e:
218218
assert False
219219

0 commit comments

Comments
 (0)