Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from databricks.sql.auth.common import AuthType, ClientContext


def get_auth_provider(cfg: ClientContext):
def get_auth_provider(cfg: ClientContext, http_client):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
Expand All @@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_client_id,
cfg.oauth_scopes,
cfg.auth_type,
http_client=http_client,
)
elif cfg.access_token is not None:
return AccessTokenAuthProvider(cfg.access_token)
Expand All @@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext):
cfg.oauth_redirect_port_range,
cfg.oauth_client_id,
cfg.oauth_scopes,
http_client=http_client,
)
else:
raise RuntimeError("No valid authentication settings!")
Expand All @@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
)


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
# TODO : unify all the auth mechanisms with the Python SDK

auth_type = kwargs.get("auth_type")
Expand Down Expand Up @@ -111,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)
return get_auth_provider(cfg, http_client)
2 changes: 2 additions & 0 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
redirect_port_range: List[int],
client_id: str,
scopes: List[str],
http_client,
auth_type: str = "databricks-oauth",
):
try:
Expand All @@ -79,6 +80,7 @@ def __init__(
port_range=redirect_port_range,
client_id=client_id,
idp_endpoint=idp_endpoint,
http_client=http_client,
)
self._hostname = hostname
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
Expand Down
63 changes: 47 additions & 16 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
from typing import Optional, List
from urllib.parse import urlparse
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -36,6 +35,21 @@ def __init__(
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
# HTTP client configuration parameters
ssl_options=None, # SSLOptions type
socket_timeout: Optional[float] = None,
retry_stop_after_attempts_count: Optional[int] = None,
retry_delay_min: Optional[float] = None,
retry_delay_max: Optional[float] = None,
retry_stop_after_attempts_duration: Optional[float] = None,
retry_delay_default: Optional[float] = None,
retry_dangerous_codes: Optional[List[int]] = None,
http_proxy: Optional[str] = None,
proxy_username: Optional[str] = None,
proxy_password: Optional[str] = None,
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
user_agent: Optional[str] = None,
):
self.hostname = hostname
self.access_token = access_token
Expand All @@ -52,6 +66,24 @@ def __init__(
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider

# HTTP client configuration
self.ssl_options = ssl_options
self.socket_timeout = socket_timeout
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
self.retry_delay_min = retry_delay_min or 1.0
self.retry_delay_max = retry_delay_max or 60.0
self.retry_stop_after_attempts_duration = (
retry_stop_after_attempts_duration or 900.0
)
self.retry_delay_default = retry_delay_default or 5.0
self.retry_dangerous_codes = retry_dangerous_codes or []
self.http_proxy = http_proxy
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
self.user_agent = user_agent


def get_effective_azure_login_app_id(hostname) -> str:
"""
Expand All @@ -69,7 +101,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
return AzureAppId.PROD.value[1]


def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
"""
Load the Azure tenant ID from the Azure Databricks login page.

Expand All @@ -78,23 +110,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
the Azure login page, and the tenant ID is extracted from the redirect URL.
"""

if http_client is None:
http_client = DatabricksHttpClient.get_instance()

login_url = f"{host}/aad/auth"
logger.debug("Loading tenant ID from %s", login_url)
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
if resp.status_code // 100 != 3:

with http_client.request_context("GET", login_url, allow_redirects=False) as resp:
if resp.status // 100 != 3:
raise ValueError(
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
)
entra_id_endpoint = resp.headers.get("Location")
entra_id_endpoint = dict(resp.headers).get("Location")
if entra_id_endpoint is None:
raise ValueError(f"No Location header in response from {login_url}")
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]

# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]
42 changes: 16 additions & 26 deletions src/databricks/sql/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from typing import List, Optional

import oauthlib.oauth2
import requests
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from requests.exceptions import RequestException
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
from databricks.sql.common.http import HttpMethod, HttpHeader
from databricks.sql.common.http import OAuthResponse
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
from databricks.sql.auth.endpoint import OAuthEndpointCollection
Expand Down Expand Up @@ -63,33 +61,19 @@ def refresh(self) -> Token:
pass


class IgnoreNetrcAuth(requests.auth.AuthBase):
"""This auth method is a no-op.

We use it to force requestslib to not use .netrc to write auth headers
when making .post() requests to the oauth token endpoints, since these
don't require authentication.

In cases where .netrc is outdated or corrupt, these requests will fail.

See issue #121
"""

def __call__(self, r):
return r


class OAuthManager:
def __init__(
self,
port_range: List[int],
client_id: str,
idp_endpoint: OAuthEndpointCollection,
http_client,
):
self.port_range = port_range
self.client_id = client_id
self.redirect_port = None
self.idp_endpoint = idp_endpoint
self.http_client = http_client

@staticmethod
def __token_urlsafe(nbytes=32):
Expand All @@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str):
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)

try:
response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
except RequestException as e:
response = self.http_client.request("GET", url=known_config_url)
# Convert urllib3 response to requests-like response for compatibility
response.status_code = response.status
response.json = lambda: json.loads(response.data.decode())
except Exception as e:
logger.error(
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
Expand All @@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str):
raise RuntimeError(msg)
try:
return response.json()
except requests.exceptions.JSONDecodeError as e:
except Exception as e:
logger.error(
f"Unable to decode OAuth configuration from {known_config_url}.\n"
"Verify it is a valid workspace URL and that OAuth is "
Expand Down Expand Up @@ -209,10 +196,12 @@ def __send_token_request(token_request_url, data):
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
response = requests.post(
url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
# Use unified HTTP client
response = self.http_client.request(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

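How is a static method using the instance's `http_client`? `__send_token_request` takes no `self` parameter, so `self.http_client` is not in scope here.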

"POST", url=token_request_url, body=data, headers=headers
)
return response.json()
# Convert urllib3 response to dict for compatibility
return json.loads(response.data.decode())

def __send_refresh_token_request(self, hostname, refresh_token):
oauth_config = self.__fetch_well_known_config(hostname)
Expand Down Expand Up @@ -320,14 +309,15 @@ def __init__(
token_url,
client_id,
client_secret,
http_client,
extra_params: dict = {},
):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.extra_params = extra_params
self.token: Optional[Token] = None
self._http_client = DatabricksHttpClient.get_instance()
self._http_client = http_client

def get_token(self) -> Token:
if self.token is None or self.token.is_expired():
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/sea/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def build_queue(
max_download_threads: int,
sea_client: SeaDatabricksClient,
lz4_compressed: bool,
http_client,
) -> ResultSetQueue:
"""
Factory method to build a result set queue for SEA backend.
Expand Down Expand Up @@ -94,6 +95,7 @@ def build_queue(
total_chunk_count=manifest.total_chunk_count,
lz4_compressed=lz4_compressed,
description=description,
http_client=http_client,
)
raise ProgrammingError("Invalid result format")

Expand Down Expand Up @@ -309,6 +311,7 @@ def __init__(
sea_client: SeaDatabricksClient,
statement_id: str,
total_chunk_count: int,
http_client,
lz4_compressed: bool = False,
description: List[Tuple] = [],
):
Expand Down Expand Up @@ -337,6 +340,7 @@ def __init__(
# TODO: fix these arguments when telemetry is implemented in SEA
session_id_hex=None,
chunk_id=0,
http_client=http_client,
)

logger.debug(
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
max_download_threads=sea_client.max_download_threads,
sea_client=sea_client,
lz4_compressed=execute_response.lz4_compressed,
http_client=connection.session.http_client,
)

# Call parent constructor with common attributes
Expand Down
10 changes: 5 additions & 5 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
http_headers,
auth_provider: AuthProvider,
ssl_options: SSLOptions,
http_client=None,
**kwargs,
):
# Internal arguments in **kwargs:
Expand Down Expand Up @@ -145,10 +146,8 @@ def __init__(
# Number of threads for handling cloud fetch downloads. Defaults to 10

logger.debug(
"ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)",
server_hostname,
port,
http_path,
"ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)"
% (server_hostname, port, http_path)
)

port = port or 443
Expand Down Expand Up @@ -177,8 +176,8 @@ def __init__(
self._max_download_threads = kwargs.get("max_download_threads", 10)

self._ssl_options = ssl_options

self._auth_provider = auth_provider
self._http_client = http_client

# Connector version 3 retry approach
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
Expand Down Expand Up @@ -1292,6 +1291,7 @@ def fetch_results(
session_id_hex=self._session_id_hex,
statement_id=command_id.to_hex_guid(),
chunk_id=chunk_id,
http_client=self._http_client,
)

return (
Expand Down
Loading
Loading