Skip to content

Commit 1dd40a1

Browse files
review comments
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 2a1f719 commit 1dd40a1

File tree

10 files changed: +127 additions, -101 deletions

src/databricks/sql/auth/auth.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ def get_auth_provider(cfg: ClientContext, http_client):
5656
cfg.oauth_client_id,
5757
cfg.oauth_scopes,
5858
http_client,
59-
cfg.auth_type or "databricks-oauth",
59+
cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
6060
)
6161
else:
6262
raise RuntimeError("No valid authentication settings!")

src/databricks/sql/auth/common.py

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
import logging
33
from typing import Optional, List
44
from urllib.parse import urlparse
5+
from databricks.sql.auth.retry import DatabricksRetryPolicy
6+
from databricks.sql.common.http import HttpMethod
57

68
logger = logging.getLogger(__name__)
79

@@ -38,17 +40,17 @@ def __init__(
3840
# HTTP client configuration parameters
3941
ssl_options=None, # SSLOptions type
4042
socket_timeout: Optional[float] = None,
41-
retry_stop_after_attempts_count: Optional[int] = None,
42-
retry_delay_min: Optional[float] = None,
43-
retry_delay_max: Optional[float] = None,
44-
retry_stop_after_attempts_duration: Optional[float] = None,
45-
retry_delay_default: Optional[float] = None,
43+
retry_stop_after_attempts_count: int = 5,
44+
retry_delay_min: float = 1.0,
45+
retry_delay_max: float = 60.0,
46+
retry_stop_after_attempts_duration: float = 900.0,
47+
retry_delay_default: float = 5.0,
4648
retry_dangerous_codes: Optional[List[int]] = None,
4749
http_proxy: Optional[str] = None,
4850
proxy_username: Optional[str] = None,
4951
proxy_password: Optional[str] = None,
50-
pool_connections: Optional[int] = None,
51-
pool_maxsize: Optional[int] = None,
52+
pool_connections: int = 10,
53+
pool_maxsize: int = 20,
5254
user_agent: Optional[str] = None,
5355
):
5456
self.hostname = hostname
@@ -69,19 +71,17 @@ def __init__(
6971
# HTTP client configuration
7072
self.ssl_options = ssl_options
7173
self.socket_timeout = socket_timeout
72-
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
73-
self.retry_delay_min = retry_delay_min or 1.0
74-
self.retry_delay_max = retry_delay_max or 60.0
75-
self.retry_stop_after_attempts_duration = (
76-
retry_stop_after_attempts_duration or 900.0
77-
)
78-
self.retry_delay_default = retry_delay_default or 5.0
74+
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count
75+
self.retry_delay_min = retry_delay_min
76+
self.retry_delay_max = retry_delay_max
77+
self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration
78+
self.retry_delay_default = retry_delay_default
7979
self.retry_dangerous_codes = retry_dangerous_codes or []
8080
self.http_proxy = http_proxy
8181
self.proxy_username = proxy_username
8282
self.proxy_password = proxy_password
83-
self.pool_connections = pool_connections or 10
84-
self.pool_maxsize = pool_maxsize or 20
83+
self.pool_connections = pool_connections
84+
self.pool_maxsize = pool_maxsize
8585
self.user_agent = user_agent
8686

8787

@@ -113,7 +113,9 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
113113
login_url = f"{host}/aad/auth"
114114
logger.debug("Loading tenant ID from %s", login_url)
115115

116-
with http_client.request_context("GET", login_url, allow_redirects=False) as resp:
116+
with http_client.request_context(
117+
HttpMethod.GET, login_url, allow_redirects=False
118+
) as resp:
117119
if resp.status // 100 != 3:
118120
raise ValueError(
119121
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"

src/databricks/sql/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str):
8787
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
8888

8989
try:
90-
response = self.http_client.request("GET", url=known_config_url)
90+
response = self.http_client.request(HttpMethod.GET, url=known_config_url)
9191
# Convert urllib3 response to requests-like response for compatibility
9292
response.status_code = response.status
9393
response.json = lambda: json.loads(response.data.decode())
@@ -197,7 +197,7 @@ def __send_token_request(self, token_request_url, data):
197197
}
198198
# Use unified HTTP client
199199
response = self.http_client.request(
200-
"POST", url=token_request_url, body=data, headers=headers
200+
HttpMethod.POST, url=token_request_url, body=data, headers=headers
201201
)
202202
# Convert urllib3 response to dict for compatibility
203203
return json.loads(response.data.decode())

src/databricks/sql/client.py

Lines changed: 16 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
transform_paramstyle,
3232
ColumnTable,
3333
ColumnQueue,
34+
build_client_context,
3435
)
3536
from databricks.sql.parameters.native import (
3637
DbsqlParameterBase,
@@ -52,6 +53,7 @@
5253

5354
from databricks.sql.auth.common import ClientContext
5455
from databricks.sql.common.unified_http_client import UnifiedHttpClient
56+
from databricks.sql.common.http import HttpMethod
5557

5658
from databricks.sql.thrift_api.TCLIService.ttypes import (
5759
TOpenSessionResp,
@@ -254,14 +256,14 @@ def read(self) -> Optional[OAuthToken]:
254256
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
255257
)
256258

257-
client_context = self._build_client_context(server_hostname, **kwargs)
258-
http_client = UnifiedHttpClient(client_context)
259+
client_context = build_client_context(server_hostname, __version__, **kwargs)
260+
self.http_client = UnifiedHttpClient(client_context)
259261

260262
try:
261263
self.session = Session(
262264
server_hostname,
263265
http_path,
264-
http_client,
266+
self.http_client,
265267
http_headers,
266268
session_configuration,
267269
catalog,
@@ -350,50 +352,6 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
350352

351353
return value
352354

353-
def _build_client_context(self, server_hostname: str, **kwargs):
354-
"""Build ClientContext for HTTP client configuration."""
355-
from databricks.sql.auth.common import ClientContext
356-
from databricks.sql.types import SSLOptions
357-
358-
# Extract SSL options
359-
ssl_options = SSLOptions(
360-
tls_verify=not kwargs.get("_tls_no_verify", False),
361-
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
362-
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
363-
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
364-
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
365-
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
366-
)
367-
368-
# Build user agent
369-
user_agent_entry = kwargs.get("user_agent_entry", "")
370-
if user_agent_entry:
371-
user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})"
372-
else:
373-
user_agent = f"PyDatabricksSqlConnector/{__version__}"
374-
375-
return ClientContext(
376-
hostname=server_hostname,
377-
ssl_options=ssl_options,
378-
socket_timeout=kwargs.get("_socket_timeout"),
379-
retry_stop_after_attempts_count=kwargs.get(
380-
"_retry_stop_after_attempts_count"
381-
),
382-
retry_delay_min=kwargs.get("_retry_delay_min"),
383-
retry_delay_max=kwargs.get("_retry_delay_max"),
384-
retry_stop_after_attempts_duration=kwargs.get(
385-
"_retry_stop_after_attempts_duration"
386-
),
387-
retry_delay_default=kwargs.get("_retry_delay_default"),
388-
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
389-
http_proxy=kwargs.get("_http_proxy"),
390-
proxy_username=kwargs.get("_proxy_username"),
391-
proxy_password=kwargs.get("_proxy_password"),
392-
pool_connections=kwargs.get("_pool_connections"),
393-
pool_maxsize=kwargs.get("_pool_maxsize"),
394-
user_agent=user_agent,
395-
)
396-
397355
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
398356
def __enter__(self) -> "Connection":
399357
return self
@@ -447,7 +405,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp):
447405
@property
448406
def open(self) -> bool:
449407
"""Return whether the connection is open by checking if the session is open."""
450-
return hasattr(self, "session") and self.session.is_open
408+
return self.session.is_open
451409

452410
def cursor(
453411
self,
@@ -497,6 +455,10 @@ def _close(self, close_cursors=True) -> None:
497455

498456
TelemetryClientFactory.close(self.get_session_id_hex())
499457

458+
# Close HTTP client that was created by this connection
459+
if self.http_client:
460+
self.http_client.close()
461+
500462
def commit(self):
501463
"""No-op because Databricks does not support transactions"""
502464
pass
@@ -796,8 +758,8 @@ def _handle_staging_put(
796758
)
797759

798760
with open(local_file, "rb") as fh:
799-
r = self.connection.session.http_client.request(
800-
"PUT", presigned_url, body=fh.read(), headers=headers
761+
r = self.connection.http_client.request(
762+
HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers
801763
)
802764

803765
# fmt: off
@@ -837,8 +799,8 @@ def _handle_staging_get(
837799
session_id_hex=self.connection.get_session_id_hex(),
838800
)
839801

840-
r = self.connection.session.http_client.request(
841-
"GET", presigned_url, headers=headers
802+
r = self.connection.http_client.request(
803+
HttpMethod.GET, presigned_url, headers=headers
842804
)
843805

844806
# response.ok verifies the status code is not between 400-600.
@@ -860,8 +822,8 @@ def _handle_staging_remove(
860822
):
861823
"""Make an HTTP DELETE request to the presigned_url"""
862824

863-
r = self.connection.session.http_client.request(
864-
"DELETE", presigned_url, headers=headers
825+
r = self.connection.http_client.request(
826+
HttpMethod.DELETE, presigned_url, headers=headers
865827
)
866828

867829
if r.status >= 400:

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from databricks.sql.types import SSLOptions
1111
from databricks.sql.telemetry.latency_logger import log_latency
1212
from databricks.sql.telemetry.models.event import StatementType
13+
from databricks.sql.common.unified_http_client import UnifiedHttpClient
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -79,9 +80,10 @@ def run(self) -> DownloadedFile:
7980
"""
8081

8182
logger.debug(
82-
"ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
83-
self.chunk_id, self.link.startRowOffset, self.link.rowCount
84-
)
83+
"ResultSetDownloadHandler: starting file download, chunk id %s, offset %s, row count %s",
84+
self.chunk_id,
85+
self.link.startRowOffset,
86+
self.link.rowCount,
8587
)
8688

8789
# Check if link is already expired or is expiring
@@ -92,7 +94,7 @@ def run(self) -> DownloadedFile:
9294
start_time = time.time()
9395

9496
with self._http_client.request_context(
95-
method="GET",
97+
method=HttpMethod.GET,
9698
url=self.link.fileLink,
9799
timeout=self.settings.download_timeout,
98100
headers=self.link.httpHeaders,
@@ -116,15 +118,15 @@ def run(self) -> DownloadedFile:
116118
# The size of the downloaded file should match the size specified from TSparkArrowResultLink
117119
if len(decompressed_data) != self.link.bytesNum:
118120
logger.debug(
119-
"ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format(
120-
len(decompressed_data), self.link.bytesNum
121-
)
121+
"ResultSetDownloadHandler: downloaded file size %s does not match the expected value %s",
122+
len(decompressed_data),
123+
self.link.bytesNum,
122124
)
123125

124126
logger.debug(
125-
"ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format(
126-
self.link.startRowOffset, self.link.rowCount
127-
)
127+
"ResultSetDownloadHandler: successfully downloaded file, offset %s, row count %s",
128+
self.link.startRowOffset,
129+
self.link.rowCount,
128130
)
129131

130132
return DownloadedFile(

src/databricks/sql/common/feature_flag.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from typing import Dict, Optional, List, Any, TYPE_CHECKING
77

8+
from databricks.sql.common.http import HttpMethod
9+
810
if TYPE_CHECKING:
911
from databricks.sql.client import Connection
1012

@@ -111,7 +113,7 @@ def _refresh_flags(self):
111113
headers["User-Agent"] = self._connection.session.useragent_header
112114

113115
response = self._http_client.request(
114-
"GET", self._feature_flag_endpoint, headers=headers, timeout=30
116+
HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30
115117
)
116118

117119
if response.status == 200:

src/databricks/sql/common/unified_http_client.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import ssl
33
import urllib.parse
44
from contextlib import contextmanager
5-
from typing import Dict, Any, Optional, Generator, Union
5+
from typing import Dict, Any, Optional, Generator
66

77
import urllib3
88
from urllib3 import PoolManager, ProxyManager
@@ -11,6 +11,7 @@
1111

1212
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
1313
from databricks.sql.exc import RequestError
14+
from databricks.sql.common.http import HttpMethod
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -135,13 +136,17 @@ def _prepare_retry_policy(self):
135136

136137
@contextmanager
137138
def request_context(
138-
self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
139+
self,
140+
method: HttpMethod,
141+
url: str,
142+
headers: Optional[Dict[str, str]] = None,
143+
**kwargs,
139144
) -> Generator[urllib3.HTTPResponse, None, None]:
140145
"""
141146
Context manager for making HTTP requests with proper resource cleanup.
142147
143148
Args:
144-
method: HTTP method (GET, POST, PUT, DELETE)
149+
method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE)
145150
url: URL to request
146151
headers: Optional headers dict
147152
**kwargs: Additional arguments passed to urllib3 request
@@ -160,7 +165,7 @@ def request_context(
160165

161166
try:
162167
response = self._pool_manager.request(
163-
method=method, url=url, headers=request_headers, **kwargs
168+
method=method.value, url=url, headers=request_headers, **kwargs
164169
)
165170
yield response
166171
except MaxRetryError as e:
@@ -174,22 +179,27 @@ def request_context(
174179
response.close()
175180

176181
def request(
177-
self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
182+
self,
183+
method: HttpMethod,
184+
url: str,
185+
headers: Optional[Dict[str, str]] = None,
186+
**kwargs,
178187
) -> urllib3.HTTPResponse:
179188
"""
180189
Make an HTTP request.
181190
182191
Args:
183-
method: HTTP method (GET, POST, PUT, DELETE, etc.)
192+
method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, etc.)
184193
url: URL to request
185194
headers: Optional headers dict
186195
**kwargs: Additional arguments passed to urllib3 request
187196
188197
Returns:
189-
urllib3.HTTPResponse: The HTTP response object with data pre-loaded
198+
urllib3.HTTPResponse: The HTTP response object with data and metadata pre-loaded
190199
"""
191200
with self.request_context(method, url, headers=headers, **kwargs) as response:
192201
# Read the response data to ensure it's available after context exit
202+
# Note: status and headers remain accessible after close(), only data needs caching
193203
response._body = response.data
194204
return response
195205

src/databricks/sql/session.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,3 @@ def close(self) -> None:
193193
logger.error("Attempt to close session raised a local exception: %s", e)
194194

195195
self.is_open = False
196-
197-
# Close HTTP client if it exists
198-
if hasattr(self, "http_client") and self.http_client:
199-
self.http_client.close()

0 commit comments

Comments
 (0)