Skip to content

Commit 08e4662

Browse files
Merge branch 'main' into enhance-ci
2 parents 467a9ac + 576eafc commit 08e4662

28 files changed

+2213
-389
lines changed

.github/CODEOWNERS

Lines changed: 0 additions & 5 deletions
This file was deleted.

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Release History
22

3+
# 4.0.5 (2025-06-24)
4+
- Fix: Reverted change in cursor close handling which led to errors impacting users (databricks/databricks-sql-python#613 by @madhav-db)
5+
36
# 4.0.4 (2025-06-16)
47

58
- Update thrift client library after cleaning up unused fields and structs (databricks/databricks-sql-python#553 by @vikrantpuppala)

poetry.lock

Lines changed: 39 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "databricks-sql-connector"
3-
version = "4.0.4"
3+
version = "4.0.5"
44
description = "Databricks SQL Connector for Python"
55
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
66
license = "Apache-2.0"
@@ -20,16 +20,18 @@ requests = "^2.18.1"
2020
oauthlib = "^3.1.0"
2121
openpyxl = "^3.0.10"
2222
urllib3 = ">=1.26"
23+
python-dateutil = "^2.8.0"
2324
pyarrow = [
2425
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
2526
{ version = ">=18.0.0", python = ">=3.13", optional=true }
2627
]
27-
python-dateutil = "^2.8.0"
28+
pyjwt = "^2.0.0"
29+
2830

2931
[tool.poetry.extras]
3032
pyarrow = ["pyarrow"]
3133

32-
[tool.poetry.dev-dependencies]
34+
[tool.poetry.group.dev.dependencies]
3335
pytest = "^7.1.2"
3436
mypy = "^1.10.1"
3537
pylint = ">=2.12.0"

src/databricks/sql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __repr__(self):
6868
DATE = DBAPITypeObject("date")
6969
ROWID = DBAPITypeObject()
7070

71-
__version__ = "4.0.4"
71+
__version__ = "4.0.5"
7272
USER_AGENT_NAME = "PyDatabricksSqlConnector"
7373

7474
# These two functions are pyhive legacy

src/databricks/sql/auth/auth.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,29 @@
1-
from enum import Enum
21
from typing import Optional, List
32

43
from databricks.sql.auth.authenticators import (
54
AuthProvider,
65
AccessTokenAuthProvider,
76
ExternalAuthProvider,
87
DatabricksOAuthProvider,
8+
AzureServicePrincipalCredentialProvider,
99
)
10-
11-
12-
class AuthType(Enum):
13-
DATABRICKS_OAUTH = "databricks-oauth"
14-
AZURE_OAUTH = "azure-oauth"
15-
# other supported types (access_token) can be inferred
16-
# we can add more types as needed later
17-
18-
19-
class ClientContext:
20-
def __init__(
21-
self,
22-
hostname: str,
23-
access_token: Optional[str] = None,
24-
auth_type: Optional[str] = None,
25-
oauth_scopes: Optional[List[str]] = None,
26-
oauth_client_id: Optional[str] = None,
27-
oauth_redirect_port_range: Optional[List[int]] = None,
28-
use_cert_as_auth: Optional[str] = None,
29-
tls_client_cert_file: Optional[str] = None,
30-
oauth_persistence=None,
31-
credentials_provider=None,
32-
):
33-
self.hostname = hostname
34-
self.access_token = access_token
35-
self.auth_type = auth_type
36-
self.oauth_scopes = oauth_scopes
37-
self.oauth_client_id = oauth_client_id
38-
self.oauth_redirect_port_range = oauth_redirect_port_range
39-
self.use_cert_as_auth = use_cert_as_auth
40-
self.tls_client_cert_file = tls_client_cert_file
41-
self.oauth_persistence = oauth_persistence
42-
self.credentials_provider = credentials_provider
10+
from databricks.sql.auth.common import AuthType, ClientContext
4311

4412

4513
def get_auth_provider(cfg: ClientContext):
4614
if cfg.credentials_provider:
4715
return ExternalAuthProvider(cfg.credentials_provider)
48-
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
16+
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17+
return ExternalAuthProvider(
18+
AzureServicePrincipalCredentialProvider(
19+
cfg.hostname,
20+
cfg.azure_client_id,
21+
cfg.azure_client_secret,
22+
cfg.azure_tenant_id,
23+
cfg.azure_workspace_resource_id,
24+
)
25+
)
26+
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
4927
assert cfg.oauth_redirect_port_range is not None
5028
assert cfg.oauth_client_id is not None
5129
assert cfg.oauth_scopes is not None
@@ -102,10 +80,13 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
10280

10381

10482
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
83+
# TODO: unify all the auth mechanisms with the Python SDK
84+
10585
auth_type = kwargs.get("auth_type")
10686
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
10787
auth_type == AuthType.AZURE_OAUTH.value
10888
)
89+
10990
if kwargs.get("username") or kwargs.get("password"):
11091
raise ValueError(
11192
"Username/password authentication is no longer supported. "
@@ -120,6 +101,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
120101
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121102
oauth_scopes=PYSQL_OAUTH_SCOPES,
122103
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
104+
azure_client_id=kwargs.get("azure_client_id"),
105+
azure_client_secret=kwargs.get("azure_client_secret"),
106+
azure_tenant_id=kwargs.get("azure_tenant_id"),
107+
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123108
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124109
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
125110
else redirect_port_range,

src/databricks/sql/auth/authenticators.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import abc
2-
import base64
32
import logging
43
from typing import Callable, Dict, List
5-
6-
from databricks.sql.auth.oauth import OAuthManager
7-
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
4+
from databricks.sql.common.http import HttpHeader
5+
from databricks.sql.auth.oauth import (
6+
OAuthManager,
7+
RefreshableTokenSource,
8+
ClientCredentialsTokenSource,
9+
)
10+
from databricks.sql.auth.endpoint import get_oauth_endpoints
11+
from databricks.sql.auth.common import (
12+
AuthType,
13+
get_effective_azure_login_app_id,
14+
get_azure_tenant_id_from_host,
15+
)
816

917
# Private API: this is an evolving interface and it will change in the future.
1018
# Please must not depend on it in your applications.
@@ -146,3 +154,82 @@ def add_headers(self, request_headers: Dict[str, str]):
146154
headers = self._header_factory()
147155
for k, v in headers.items():
148156
request_headers[k] = v
157+
158+
159+
class AzureServicePrincipalCredentialProvider(CredentialsProvider):
160+
"""
161+
A credential provider for Azure Service Principal authentication with Databricks.
162+
163+
This class implements the CredentialsProvider protocol to authenticate requests
164+
to Databricks REST APIs using Azure Active Directory (AAD) service principal
165+
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
166+
from Azure AD and automatically refreshes them when they expire.
167+
168+
Attributes:
169+
hostname (str): The Databricks workspace hostname.
170+
azure_client_id (str): The Azure service principal's client ID.
171+
azure_client_secret (str): The Azure service principal's client secret.
172+
azure_tenant_id (str, optional): The Azure AD tenant ID. When not provided, it is obtained from the hostname via get_azure_tenant_id_from_host.
173+
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
174+
"""
175+
176+
AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
177+
AZURE_TOKEN_ENDPOINT = "oauth2/token"
178+
179+
AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"
180+
181+
DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
182+
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
183+
"X-Databricks-Azure-Workspace-Resource-Id"
184+
)
185+
186+
def __init__(
187+
self,
188+
hostname,
189+
azure_client_id,
190+
azure_client_secret,
191+
azure_tenant_id=None,
192+
azure_workspace_resource_id=None,
193+
):
194+
self.hostname = hostname
195+
self.azure_client_id = azure_client_id
196+
self.azure_client_secret = azure_client_secret
197+
self.azure_workspace_resource_id = azure_workspace_resource_id
198+
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
199+
hostname
200+
)
201+
202+
def auth_type(self) -> str:
203+
return AuthType.AZURE_SP_M2M.value
204+
205+
def get_token_source(self, resource: str) -> RefreshableTokenSource:
206+
return ClientCredentialsTokenSource(
207+
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
208+
client_id=self.azure_client_id,
209+
client_secret=self.azure_client_secret,
210+
extra_params={"resource": resource},
211+
)
212+
213+
def __call__(self, *args, **kwargs) -> HeaderFactory:
214+
inner = self.get_token_source(
215+
resource=get_effective_azure_login_app_id(self.hostname)
216+
)
217+
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
218+
219+
def header_factory() -> Dict[str, str]:
220+
inner_token = inner.get_token()
221+
cloud_token = cloud.get_token()
222+
223+
headers = {
224+
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
225+
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
226+
}
227+
228+
if self.azure_workspace_resource_id:
229+
headers[
230+
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
231+
] = self.azure_workspace_resource_id
232+
233+
return headers
234+
235+
return header_factory

0 commit comments

Comments
 (0)