Skip to content

Commit 42841f1

Browse files
committed
testing sdk
1 parent ce4543c commit 42841f1

File tree

6 files changed

+70
-51
lines changed

6 files changed

+70
-51
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
116116
client_id = kwargs.get("oauth_client_id")
117117
redirect_port_range = kwargs.get("oauth_redirect_port_range")
118118

119-
if auth_type == AuthType.AZURE_SP_M2M.value:
120-
pass
121-
else:
119+
if auth_type != AuthType.AZURE_SP_M2M.value:
122120
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
123121
auth_type == AuthType.AZURE_OAUTH.value
124122
)
@@ -140,7 +138,7 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
140138
azure_tenant_id=kwargs.get("azure_tenant_id"),
141139
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
142140
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
143-
if client_id and kwargs.get("oauth_redirect_port")
141+
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
144142
else redirect_port_range,
145143
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
146144
credentials_provider=kwargs.get("credentials_provider"),

src/databricks/sql/auth/authenticators.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
import logging
3-
from typing import Callable, Dict, List, Optional
3+
from typing import Callable, Dict, List
44
from databricks.sql.common.http import HttpHeader
55
from databricks.sql.auth.oauth import (
66
OAuthManager,
@@ -9,12 +9,12 @@
99
)
1010
from databricks.sql.auth.endpoint import get_oauth_endpoints
1111
from databricks.sql.auth.common import AuthType, get_effective_azure_login_app_id
12+
from databricks.sdk import WorkspaceClient
1213

1314
# Private API: this is an evolving interface and it will change in the future.
1415
# Please must not depend on it in your applications.
1516
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence
1617

17-
logger = logging.getLogger(__name__)
1818

1919

2020
class AuthProvider:
@@ -189,6 +189,13 @@ def __init__(
189189
azure_tenant_id,
190190
azure_workspace_resource_id=None,
191191
):
192+
self.workspace_api_client = WorkspaceClient(
193+
host=hostname,
194+
azure_workspace_resource_id=azure_workspace_resource_id,
195+
azure_tenant_id=azure_tenant_id,
196+
azure_client_id=oauth_client_id,
197+
azure_client_secret=oauth_client_secret,
198+
)
192199
self.hostname = hostname
193200
self.oauth_client_id = oauth_client_id
194201
self.oauth_client_secret = oauth_client_secret
@@ -207,25 +214,26 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
207214
)
208215

209216
def __call__(self, *args, **kwargs) -> HeaderFactory:
210-
inner = self.get_token_source(
211-
resource=get_effective_azure_login_app_id(self.hostname)
212-
)
213-
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
217+
# inner = self.get_token_source(
218+
# resource=get_effective_azure_login_app_id(self.hostname)
219+
# )
220+
# cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
214221

215222
def header_factory() -> Dict[str, str]:
216-
inner_token = inner.get_token()
217-
cloud_token = cloud.get_token()
223+
# inner_token = inner.get_token()
224+
# cloud_token = cloud.get_token()
218225

219-
headers = {
220-
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
221-
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
222-
}
226+
# headers = {
227+
# HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
228+
# self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
229+
# }
223230

224-
if self.azure_workspace_resource_id:
225-
headers[
226-
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
227-
] = self.azure_workspace_resource_id
231+
# if self.azure_workspace_resource_id:
232+
# headers[
233+
# self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
234+
# ] = self.azure_workspace_resource_id
228235

229-
return headers
236+
# return headers
237+
return self.workspace_api_client.config.authenticate()
230238

231239
return header_factory

src/databricks/sql/auth/common.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from enum import Enum
2-
from typing import Optional
32

43

54
class AuthType(Enum):
@@ -8,21 +7,23 @@ class AuthType(Enum):
87
AZURE_SP_M2M = "azure-sp-m2m"
98

109

10+
class AzureAppId(Enum):
11+
DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
12+
STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
13+
PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
14+
15+
1116
def get_effective_azure_login_app_id(hostname) -> str:
1217
"""
1318
Get the effective Azure login app ID for a given hostname.
1419
This function determines the appropriate Azure login app ID based on the hostname.
1520
If the hostname does not match any of these domains, it returns the default Databricks resource ID.
1621
1722
"""
18-
azure_app_ids = {
19-
".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
20-
".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68",
21-
}
22-
23-
for domain, app_id in azure_app_ids.items():
23+
for azure_app_id in AzureAppId:
24+
domain, app_id = azure_app_id.value
2425
if domain in hostname:
2526
return app_id
2627

2728
# default databricks resource id
28-
return "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
29+
return AzureAppId.PROD.value

src/databricks/sql/auth/oauth.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, access_token: str, token_type: str, refresh_token: str):
3939
self.token_type = token_type
4040
self.refresh_token = refresh_token
4141

42-
def is_expired(self):
42+
def is_expired(self) -> bool:
4343
try:
4444
decoded_token = jwt.decode(
4545
self.access_token, options={"verify_signature": False}
@@ -50,7 +50,7 @@ def is_expired(self):
5050
return exp_time and (exp_time - buffer_time) <= current_time
5151
except Exception as e:
5252
logger.error("Failed to decode token: %s", e)
53-
return e
53+
raise e
5454

5555

5656
class RefreshableTokenSource(ABC):
@@ -347,18 +347,17 @@ def refresh(self) -> Token:
347347
}
348348
)
349349

350-
response = self._http_client.execute(
350+
with self._http_client.execute(
351351
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
352-
)
353-
354-
if response.status_code == 200:
355-
oauth_response = OAuthResponse(**response.json())
356-
return Token(
357-
oauth_response.access_token,
358-
oauth_response.token_type,
359-
oauth_response.refresh_token,
360-
)
361-
else:
362-
raise Exception(
363-
f"Failed to get token: {response.status_code} {response.text}"
364-
)
352+
) as response:
353+
if response.status_code == 200:
354+
oauth_response = OAuthResponse(**response.json())
355+
return Token(
356+
oauth_response.access_token,
357+
oauth_response.token_type,
358+
oauth_response.refresh_token,
359+
)
360+
else:
361+
raise Exception(
362+
f"Failed to get token: {response.status_code} {response.text}"
363+
)

src/databricks/sql/common/http.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
from enum import Enum
55
import threading
66
from dataclasses import dataclass
7+
from contextlib import contextmanager
8+
from typing import Generator
9+
import logging
10+
11+
logger = logging.getLogger(__name__)
12+
713

814
# Enums for HTTP Methods
915
class HttpMethod(str, Enum):
@@ -56,10 +62,19 @@ def get_instance(cls) -> "DatabricksHttpClient":
5662
if cls._instance is None:
5763
cls._instance = DatabricksHttpClient()
5864
return cls._instance
59-
60-
def execute(self, method: HttpMethod, url: str, **kwargs) -> requests.Response:
61-
with self.session.request(method.value, url, **kwargs) as response:
62-
return response
65+
66+
@contextmanager
67+
def execute(self, method: HttpMethod, url: str, **kwargs) -> Generator[requests.Response, None, None]:
68+
response = None
69+
try:
70+
response = self.session.request(method.value, url, **kwargs)
71+
yield response
72+
except Exception as e:
73+
logger.error(f"Error executing HTTP request in DatabricksHttpClient: {e}")
74+
raise e
75+
finally:
76+
if response is not None:
77+
response.close()
6378

6479
def close(self):
6580
self.session.close()

tests/unit/test_auth.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import unittest
22
import pytest
3-
from typing import Optional
43
from unittest.mock import patch, MagicMock
54
import jwt
65
from databricks.sql.auth.auth import (
@@ -10,7 +9,6 @@
109
AuthType,
1110
)
1211
import time
13-
from datetime import datetime, timedelta
1412
from databricks.sql.auth.auth import (
1513
get_python_sql_connector_auth_provider,
1614
PYSQL_OAUTH_CLIENT_ID,

0 commit comments

Comments
 (0)