Skip to content

Commit a377ce7

Browse files
committed
Refractor
1 parent 42841f1 commit a377ce7

File tree

6 files changed

+124
-93
lines changed

6 files changed

+124
-93
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,7 @@
77
DatabricksOAuthProvider,
88
AzureServicePrincipalCredentialProvider,
99
)
10-
from databricks.sql.auth.common import AuthType
11-
12-
13-
class ClientContext:
14-
def __init__(
15-
self,
16-
hostname: str,
17-
access_token: Optional[str] = None,
18-
auth_type: Optional[str] = None,
19-
oauth_scopes: Optional[List[str]] = None,
20-
oauth_client_id: Optional[str] = None,
21-
oauth_client_secret: Optional[str] = None,
22-
azure_tenant_id: Optional[str] = None,
23-
azure_workspace_resource_id: Optional[str] = None,
24-
oauth_redirect_port_range: Optional[List[int]] = None,
25-
use_cert_as_auth: Optional[str] = None,
26-
tls_client_cert_file: Optional[str] = None,
27-
oauth_persistence=None,
28-
credentials_provider=None,
29-
):
30-
self.hostname = hostname
31-
self.access_token = access_token
32-
self.auth_type = auth_type
33-
self.oauth_scopes = oauth_scopes
34-
self.oauth_client_id = oauth_client_id
35-
self.oauth_client_secret = oauth_client_secret
36-
self.azure_tenant_id = azure_tenant_id
37-
self.azure_workspace_resource_id = azure_workspace_resource_id
38-
self.oauth_redirect_port_range = oauth_redirect_port_range
39-
self.use_cert_as_auth = use_cert_as_auth
40-
self.tls_client_cert_file = tls_client_cert_file
41-
self.oauth_persistence = oauth_persistence
42-
self.credentials_provider = credentials_provider
10+
from databricks.sql.auth.common import AuthType, ClientContext
4311

4412

4513
def get_auth_provider(cfg: ClientContext):
@@ -49,8 +17,8 @@ def get_auth_provider(cfg: ClientContext):
4917
return ExternalAuthProvider(
5018
AzureServicePrincipalCredentialProvider(
5119
cfg.hostname,
52-
cfg.oauth_client_id,
53-
cfg.oauth_client_secret,
20+
cfg.azure_client_id,
21+
cfg.azure_client_secret,
5422
cfg.azure_tenant_id,
5523
cfg.azure_workspace_resource_id,
5624
)
@@ -113,13 +81,10 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
11381

11482
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
11583
auth_type = kwargs.get("auth_type")
116-
client_id = kwargs.get("oauth_client_id")
117-
redirect_port_range = kwargs.get("oauth_redirect_port_range")
84+
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
85+
auth_type == AuthType.AZURE_OAUTH.value
86+
)
11887

119-
if auth_type != AuthType.AZURE_SP_M2M.value:
120-
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
121-
auth_type == AuthType.AZURE_OAUTH.value
122-
)
12388
if kwargs.get("username") or kwargs.get("password"):
12489
raise ValueError(
12590
"Username/password authentication is no longer supported. "
@@ -133,8 +98,9 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
13398
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
13499
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
135100
oauth_scopes=PYSQL_OAUTH_SCOPES,
136-
oauth_client_id=client_id,
137-
oauth_client_secret=kwargs.get("oauth_client_secret"),
101+
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
102+
azure_client_id=kwargs.get("azure_client_id"),
103+
azure_client_secret=kwargs.get("azure_client_secret"),
138104
azure_tenant_id=kwargs.get("azure_tenant_id"),
139105
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
140106
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]

src/databricks/sql/auth/authenticators.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
ClientCredentialsTokenSource,
99
)
1010
from databricks.sql.auth.endpoint import get_oauth_endpoints
11-
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id
12-
from databricks.sdk import WorkspaceClient
11+
from databricks.sql.auth.common import (
12+
AuthType,
13+
get_effective_azure_login_app_id,
14+
get_azure_tenant_id_from_host,
15+
)
1316

1417
# Private API: this is an evolving interface and it will change in the future.
1518
# Please must not depend on it in your applications.
1619
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence
1720

1821

19-
2022
class AuthProvider:
2123
def add_headers(self, request_headers: Dict[str, str]):
2224
pass
@@ -165,8 +167,8 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
165167
166168
Attributes:
167169
hostname (str): The Databricks workspace hostname.
168-
oauth_client_id (str): The Azure service principal's client ID.
169-
oauth_client_secret (str): The Azure service principal's client secret.
170+
azure_client_id (str): The Azure service principal's client ID.
171+
azure_client_secret (str): The Azure service principal's client secret.
170172
azure_tenant_id (str): The Azure AD tenant ID.
171173
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
172174
"""
@@ -184,56 +186,50 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
184186
def __init__(
185187
self,
186188
hostname,
187-
oauth_client_id,
188-
oauth_client_secret,
189-
azure_tenant_id,
189+
azure_client_id,
190+
azure_client_secret,
191+
azure_tenant_id=None,
190192
azure_workspace_resource_id=None,
191193
):
192-
self.workspace_api_client = WorkspaceClient(
193-
host=hostname,
194-
azure_workspace_resource_id=azure_workspace_resource_id,
195-
azure_tenant_id=azure_tenant_id,
196-
azure_client_id=oauth_client_id,
197-
azure_client_secret=oauth_client_secret,
198-
)
199194
self.hostname = hostname
200-
self.oauth_client_id = oauth_client_id
201-
self.oauth_client_secret = oauth_client_secret
202-
self.azure_tenant_id = azure_tenant_id
195+
self.azure_client_id = azure_client_id
196+
self.azure_client_secret = azure_client_secret
203197
self.azure_workspace_resource_id = azure_workspace_resource_id
198+
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
199+
hostname
200+
)
204201

205202
def auth_type(self) -> str:
206203
return AuthType.AZURE_SP_M2M.value
207204

208205
def get_token_source(self, resource: str) -> RefreshableTokenSource:
209206
return ClientCredentialsTokenSource(
210207
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
211-
oauth_client_id=self.oauth_client_id,
212-
oauth_client_secret=self.oauth_client_secret,
208+
client_id=self.azure_client_id,
209+
client_secret=self.azure_client_secret,
213210
extra_params={"resource": resource},
214211
)
215212

216213
def __call__(self, *args, **kwargs) -> HeaderFactory:
217-
# inner = self.get_token_source(
218-
# resource=get_effective_azure_login_app_id(self.hostname)
219-
# )
220-
# cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
214+
inner = self.get_token_source(
215+
resource=get_effective_azure_login_app_id(self.hostname)
216+
)
217+
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
221218

222219
def header_factory() -> Dict[str, str]:
223-
# inner_token = inner.get_token()
224-
# cloud_token = cloud.get_token()
220+
inner_token = inner.get_token()
221+
cloud_token = cloud.get_token()
225222

226-
# headers = {
227-
# HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
228-
# self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
229-
# }
223+
headers = {
224+
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
225+
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
226+
}
230227

231-
# if self.azure_workspace_resource_id:
232-
# headers[
233-
# self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
234-
# ] = self.azure_workspace_resource_id
228+
if self.azure_workspace_resource_id:
229+
headers[
230+
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
231+
] = self.azure_workspace_resource_id
235232

236-
# return headers
237-
return self.workspace_api_client.config.authenticate()
233+
return headers
238234

239235
return header_factory

src/databricks/sql/auth/common.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from enum import Enum
2+
import logging
3+
from typing import Optional, List
4+
from urllib.parse import urlparse
5+
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
6+
7+
logger = logging.getLogger(__name__)
28

39

410
class AuthType(Enum):
@@ -13,6 +19,40 @@ class AzureAppId(Enum):
1319
PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
1420

1521

22+
class ClientContext:
23+
def __init__(
24+
self,
25+
hostname: str,
26+
access_token: Optional[str] = None,
27+
auth_type: Optional[str] = None,
28+
oauth_scopes: Optional[List[str]] = None,
29+
oauth_client_id: Optional[str] = None,
30+
azure_client_id: Optional[str] = None,
31+
azure_client_secret: Optional[str] = None,
32+
azure_tenant_id: Optional[str] = None,
33+
azure_workspace_resource_id: Optional[str] = None,
34+
oauth_redirect_port_range: Optional[List[int]] = None,
35+
use_cert_as_auth: Optional[str] = None,
36+
tls_client_cert_file: Optional[str] = None,
37+
oauth_persistence=None,
38+
credentials_provider=None,
39+
):
40+
self.hostname = hostname
41+
self.access_token = access_token
42+
self.auth_type = auth_type
43+
self.oauth_scopes = oauth_scopes
44+
self.oauth_client_id = oauth_client_id
45+
self.azure_client_id = azure_client_id
46+
self.azure_client_secret = azure_client_secret
47+
self.azure_tenant_id = azure_tenant_id
48+
self.azure_workspace_resource_id = azure_workspace_resource_id
49+
self.oauth_redirect_port_range = oauth_redirect_port_range
50+
self.use_cert_as_auth = use_cert_as_auth
51+
self.tls_client_cert_file = tls_client_cert_file
52+
self.oauth_persistence = oauth_persistence
53+
self.credentials_provider = credentials_provider
54+
55+
1656
def get_effective_azure_login_app_id(hostname) -> str:
1757
"""
1858
Get the effective Azure login app ID for a given hostname.
@@ -27,3 +67,30 @@ def get_effective_azure_login_app_id(hostname) -> str:
2767

2868
# default databricks resource id
2969
return AzureAppId.PROD.value
70+
71+
72+
def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
73+
"""
74+
Load the Azure tenant ID from the Azure Databricks login page.
75+
"""
76+
77+
if http_client is None:
78+
http_client = DatabricksHttpClient.get_instance()
79+
80+
login_url = f"{host}/aad/auth"
81+
logger.debug("Loading tenant ID from %s", login_url)
82+
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
83+
if resp.status_code // 100 != 3:
84+
raise ValueError(
85+
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
86+
)
87+
entra_id_endpoint = resp.headers.get("Location")
88+
if entra_id_endpoint is None:
89+
raise ValueError(f"No Location header in response from {login_url}")
90+
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
91+
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
92+
url = urlparse(entra_id_endpoint)
93+
path_segments = url.path.split("/")
94+
if len(path_segments) < 2:
95+
raise ValueError(f"Invalid path in Location header: {url.path}")
96+
return path_segments[1]

src/databricks/sql/auth/oauth.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -311,19 +311,19 @@ class ClientCredentialsTokenSource(RefreshableTokenSource):
311311
312312
Attributes:
313313
token_url (str): The URL of the token endpoint.
314-
oauth_client_id (str): The client ID.
315-
oauth_client_secret (str): The client secret.
314+
client_id (str): The client ID.
315+
client_secret (str): The client secret.
316316
"""
317317

318318
def __init__(
319319
self,
320320
token_url,
321-
oauth_client_id,
322-
oauth_client_secret,
321+
client_id,
322+
client_secret,
323323
extra_params: dict = {},
324324
):
325-
self.oauth_client_id = oauth_client_id
326-
self.oauth_client_secret = oauth_client_secret
325+
self.client_id = client_id
326+
self.client_secret = client_secret
327327
self.token_url = token_url
328328
self.extra_params = extra_params
329329
self.token: Optional[Token] = None
@@ -341,8 +341,8 @@ def refresh(self) -> Token:
341341
data = urlencode(
342342
{
343343
"grant_type": "client_credentials",
344-
"client_id": self.oauth_client_id,
345-
"client_secret": self.oauth_client_secret,
344+
"client_id": self.client_id,
345+
"client_secret": self.client_secret,
346346
**self.extra_params,
347347
}
348348
)

src/databricks/sql/common/http.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ def get_instance(cls) -> "DatabricksHttpClient":
6262
if cls._instance is None:
6363
cls._instance = DatabricksHttpClient()
6464
return cls._instance
65-
65+
6666
@contextmanager
67-
def execute(self, method: HttpMethod, url: str, **kwargs) -> Generator[requests.Response, None, None]:
67+
def execute(
68+
self, method: HttpMethod, url: str, **kwargs
69+
) -> Generator[requests.Response, None, None]:
6870
response = None
6971
try:
7072
response = self.session.request(method.value, url, **kwargs)

tests/unit/test_auth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def status_response(response_status_code):
224224
def token_source(self):
225225
return ClientCredentialsTokenSource(
226226
token_url="https://token_url.com",
227-
oauth_client_id="client_id",
228-
oauth_client_secret="client_secret",
227+
client_id="client_id",
228+
client_secret="client_secret",
229229
)
230230

231231
def test_no_token_refresh__when_token_is_not_expired(
@@ -271,8 +271,8 @@ class TestAzureServicePrincipalCredentialProvider:
271271
def credential_provider(self):
272272
return AzureServicePrincipalCredentialProvider(
273273
hostname="hostname",
274-
oauth_client_id="client_id",
275-
oauth_client_secret="client_secret",
274+
azure_client_id="client_id",
275+
azure_client_secret="client_secret",
276276
azure_tenant_id="tenant_id",
277277
)
278278

0 commit comments

Comments
 (0)