Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ celerybeat.pid

# Environments
.env
.env.local
test.env
.venv
env/
venv/
Expand Down Expand Up @@ -206,4 +208,4 @@ poetry.toml
# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python,macos
# End of https://www.toptal.com/developers/gitignore/api/python,macos
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release History

# Unreleased

- Feature: Added first-class Databricks service principal (machine-to-machine OAuth) authentication support for SQLAlchemy connections, working across AWS, Azure, and GCP workspaces (databricks/databricks-sqlalchemy#29)

# 2.0.8 (2025-09-08)

- Feature: Added support for variant datatype (databricks/databricks-sqlalchemy#42 by @msrathore-db)
Expand All @@ -19,4 +23,4 @@
# 2.0.4 (2025-01-27)

- All the SQLAlchemy features from `databricks-sql-connector>=4.0.0` have been moved to this `databricks-sqlalchemy` library
- Support for SQLAlchemy v2 dialect is provided
- Support for SQLAlchemy v2 dialect is provided
29 changes: 28 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Every SQLAlchemy application that connects to a database needs to use an [Engine

1. Host
2. HTTP Path for a compute resource
3. API access token
3. API access token, or service principal connection parameters
4. Initial catalog for the connection
5. Initial schema for the connection

Expand All @@ -46,6 +46,33 @@ engine = create_engine(
)
```

### Service principal authentication

Workspaces that prohibit Personal Access Tokens can now use Databricks service principals (see the [Databricks documentation](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m) for how to create one). Supply the service principal credentials directly in the Databricks SQLAlchemy URL and set `authentication=service_principal`.

```python
import os
from sqlalchemy import create_engine

client_id = os.getenv("DATABRICKS_SP_CLIENT_ID")
client_secret = os.getenv("DATABRICKS_SP_CLIENT_SECRET")
host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
http_path = os.getenv("DATABRICKS_HTTP_PATH")
catalog = os.getenv("DATABRICKS_CATALOG")
schema = os.getenv("DATABRICKS_SCHEMA")

engine = create_engine(
"databricks://"
f"{client_id}:{client_secret}"
f"@{host}?http_path={http_path}&catalog={catalog}&schema={schema}"
"&authentication=service_principal"
)
```

`client_id` and `client_secret` are read from the username and password components of the URL. If you prefer to keep the username as `token`, you can pass them in the query string via `client_id` and `client_secret`. By default the dialect requests the Databricks `sql` OAuth scope from the workspace's `/oidc` endpoint. You can override the scopes by providing `sp_scopes` (comma separated) in the query string if you have custom scopes configured.

For local development, copy `test.env.example` to `test.env`, populate it with your workspace and service principal values, and keep `test.env` untracked (it's listed in `.gitignore`). `pytest` automatically reads this file because it is referenced in `pyproject.toml` via `env_files`.

## Types

The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ python = "^3.8.0"
databricks_sql_connector = { version = ">=4.0.0"}
pyarrow = { version = ">=14.0.1"}
sqlalchemy = { version = ">=2.0.21" }
requests = { version = ">=2.31.0,<3.0.0" }

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
Expand Down
1 change: 1 addition & 0 deletions src/databricks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
143 changes: 143 additions & 0 deletions src/databricks/sqlalchemy/_service_principal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import threading
import time
from typing import Callable, Dict, Iterable, List, Optional

import requests
from databricks.sql.auth.authenticators import CredentialsProvider
from databricks.sql.auth.endpoint import get_oauth_endpoints


class ServicePrincipalConfigurationError(ValueError):
    """Raised when the service principal configuration is incomplete.

    Subclasses ``ValueError`` so callers can treat it as an ordinary
    argument-validation failure.
    """


class ServicePrincipalAuthenticationError(RuntimeError):
    """Raised when fetching an OAuth token (or the OAuth configuration) fails."""


def _normalize_hostname(hostname: str) -> str:
maybe_scheme = "" if hostname.startswith("https://") else "https://"
trimmed = (
hostname[len("https://") :] if hostname.startswith("https://") else hostname
)
return f"{maybe_scheme}{trimmed}".rstrip("/")


class ServicePrincipalCredentialsProvider(CredentialsProvider):
    """CredentialsProvider implementing the Databricks OAuth client-credentials flow.

    Exchanges a service principal's client id/secret for a bearer token at the
    workspace's OIDC token endpoint and exposes it through the header-factory
    protocol that ``databricks-sql-connector`` expects. Tokens are cached and
    refreshed shortly before expiry; refreshes are serialized with a lock so a
    single provider instance may be shared across threads.
    """

    # OAuth scopes requested from the workspace when the caller supplies none.
    DEFAULT_SCOPES = ("sql",)

    def __init__(
        self,
        server_hostname: str,
        client_id: str,
        client_secret: str,
        *,
        scopes: Optional[Iterable[str]] = None,
        refresh_margin: int = 60,
        request_timeout: int = 10,
    ):
        """Validate the configuration and discover the workspace token endpoint.

        :param server_hostname: workspace host, with or without ``https://``.
        :param client_id: service principal application (client) id.
        :param client_secret: service principal OAuth secret.
        :param scopes: optional scope names; defaults to ``DEFAULT_SCOPES``.
        :param refresh_margin: seconds before expiry at which to refresh.
        :param request_timeout: timeout in seconds for the HTTP requests.
        :raises ServicePrincipalConfigurationError: when a required argument is
            missing or the OAuth endpoints cannot be determined for the host.
        :raises ServicePrincipalAuthenticationError: when endpoint discovery fails.

        NOTE: endpoint discovery performs a network request, so constructing
        this object requires connectivity to the workspace.
        """
        if not server_hostname:
            raise ServicePrincipalConfigurationError("server_hostname is required")
        if not client_id:
            raise ServicePrincipalConfigurationError("client_id is required")
        if not client_secret:
            raise ServicePrincipalConfigurationError("client_secret is required")

        self._hostname = _normalize_hostname(server_hostname)
        oauth_endpoints = get_oauth_endpoints(self._hostname, use_azure_auth=False)
        if not oauth_endpoints:
            raise ServicePrincipalConfigurationError(
                f"Unable to determine OAuth endpoints for host {server_hostname}"
            )

        scope_tuple = tuple(scopes) if scopes else self.DEFAULT_SCOPES
        mapped_scopes = oauth_endpoints.get_scopes_mapping(list(scope_tuple))

        self._client_id = client_id
        self._client_secret = client_secret
        self._scopes: List[str] = mapped_scopes
        self._refresh_margin = refresh_margin
        self._request_timeout = request_timeout
        self._access_token: Optional[str] = None
        self._expires_at: float = 0
        self._lock = threading.Lock()
        self._token_endpoint = self._discover_token_endpoint(oauth_endpoints)

    def auth_type(self) -> str:
        """Identifier reported to the connector for this auth mechanism."""
        return "databricks-service-principal"

    def __call__(self) -> Callable[[], Dict[str, str]]:
        """Return a factory that produces ``Authorization`` headers with a live token."""

        def header_factory() -> Dict[str, str]:
            access_token = self._get_token()
            return {"Authorization": f"Bearer {access_token}"}

        return header_factory

    def _discover_token_endpoint(self, oauth_endpoints) -> str:
        """Fetch the OIDC discovery document and return its ``token_endpoint``."""
        openid_config_url = oauth_endpoints.get_openid_config_url(self._hostname)
        try:
            response = requests.get(openid_config_url, timeout=self._request_timeout)
            response.raise_for_status()
            config = response.json()
        except Exception as exc:
            raise ServicePrincipalAuthenticationError(
                "Failed to load Databricks OAuth configuration"
            ) from exc

        token_endpoint = config.get("token_endpoint")
        if not token_endpoint:
            raise ServicePrincipalAuthenticationError(
                "OAuth configuration did not include a token endpoint"
            )
        return token_endpoint

    def _needs_refresh(self) -> bool:
        """True when no token is cached or the token is inside the refresh margin."""
        if not self._access_token:
            return True
        now = time.time()
        return now >= (self._expires_at - self._refresh_margin)

    def _get_token(self) -> str:
        """Return a valid access token, refreshing under the lock when needed."""
        with self._lock:
            if self._needs_refresh():
                self._refresh_token()
            assert self._access_token
            return self._access_token

    def _refresh_token(self) -> None:
        """Perform the client-credentials token request; caller holds ``self._lock``."""
        payload = {
            "grant_type": "client_credentials",
            "client_id": self._client_id,
            "client_secret": self._client_secret,
            "scope": " ".join(self._scopes),
        }

        # Wrap the whole request so connection-level failures surface as
        # ServicePrincipalAuthenticationError, matching _discover_token_endpoint.
        # (Previously only raise_for_status() was wrapped, so a requests
        # ConnectionError escaped as a raw third-party exception.)
        try:
            response = requests.post(
                self._token_endpoint, data=payload, timeout=self._request_timeout
            )
            response.raise_for_status()
        except Exception as exc:
            raise ServicePrincipalAuthenticationError(
                "Failed to retrieve OAuth token for service principal"
            ) from exc

        try:
            parsed = response.json()
            access_token = parsed["access_token"]
        except Exception as exc:  # pragma: no cover - defensive
            raise ServicePrincipalAuthenticationError(
                "OAuth response did not include an access token"
            ) from exc

        expires_in_raw = parsed.get("expires_in", 3600)
        try:
            expires_in = int(expires_in_raw)
        except (TypeError, ValueError):
            expires_in = 3600

        self._access_token = access_token
        # Never schedule expiry inside the refresh margin, or every call would
        # immediately trigger another token request.
        self._expires_at = time.time() + max(expires_in, self._refresh_margin + 1)
72 changes: 69 additions & 3 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
get_comment_from_dte_output,
parse_column_info_from_tgetcolumnsresponse,
)
from databricks.sqlalchemy._service_principal import (
ServicePrincipalConfigurationError,
ServicePrincipalCredentialsProvider,
)

import sqlalchemy
from sqlalchemy import DDL, event
Expand All @@ -24,7 +28,7 @@
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
from sqlalchemy.exc import ArgumentError, DatabaseError, SQLAlchemyError

try:
import alembic
Expand All @@ -45,6 +49,14 @@ class DatabricksImpl(DefaultImpl):
class DatabricksDialect(default.DefaultDialect):
"""This dialect implements only those methods required to pass our e2e tests"""

_SERVICE_PRINCIPAL_ALIASES = {
"serviceprincipal",
"service_principal",
"service-principal",
"serviceprincipal-auth",
"sp",
}

# See sqlalchemy.engine.interfaces for descriptions of each of these properties
name: str = "databricks"
driver: str = "databricks"
Expand Down Expand Up @@ -105,22 +117,76 @@ def create_connect_args(self, url):
# TODO: can schema be provided after HOST?
# Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/***

kwargs = {
credentials_provider = self._build_service_principal_provider(url)

kwargs: Dict[str, Any] = {
"server_hostname": url.host,
"access_token": url.password,
"http_path": url.query.get("http_path"),
"catalog": url.query.get("catalog"),
"schema": url.query.get("schema"),
"use_inline_params": False,
}

if credentials_provider:
kwargs["credentials_provider"] = credentials_provider
else:
kwargs["access_token"] = url.password

self.schema = kwargs["schema"]
self.catalog = kwargs["catalog"]

self._force_paramstyle_to_native_mode()

return [], kwargs

def _build_service_principal_provider(
self, url
) -> Optional[ServicePrincipalCredentialsProvider]:
auth_value = (
url.query.get("authentication")
or url.query.get("auth")
or url.query.get("auth_type")
)
username_hint = (url.username or "").lower() if url.username else ""
is_service_principal = False

if auth_value and auth_value.lower() in self._SERVICE_PRINCIPAL_ALIASES:
is_service_principal = True
elif username_hint in self._SERVICE_PRINCIPAL_ALIASES:
is_service_principal = True

if not is_service_principal:
return None

client_id = url.query.get("client_id") or url.username
client_secret = url.password or url.query.get("client_secret")
if not client_id:
raise ArgumentError("Service principal connections require a client_id")
if not client_secret:
raise ArgumentError("Service principal connections require a client_secret")

scopes_raw = (
url.query.get("sp_scopes")
or url.query.get("sp_scope")
or url.query.get("scope")
)
scopes = (
[scope.strip() for scope in scopes_raw.split(",") if scope.strip()]
if scopes_raw
else None
)
try:
provider = ServicePrincipalCredentialsProvider(
server_hostname=url.host or "",
client_id=client_id,
client_secret=client_secret,
scopes=scopes,
)
except ServicePrincipalConfigurationError as exc:
raise ArgumentError(str(exc)) from exc

return provider

def get_columns(
self, connection, table_name, schema=None, **kwargs
) -> List[ReflectedColumn]:
Expand Down
8 changes: 8 additions & 0 deletions test.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copy this file to `test.env` (which is .gitignored) and fill in the values
# with real credentials before running local tests.
DATABRICKS_SERVER_HOSTNAME=<workspace-hostname>
DATABRICKS_HTTP_PATH=<sql-http-path>
DATABRICKS_CATALOG=<catalog>
DATABRICKS_SCHEMA=<schema>
DATABRICKS_SP_CLIENT_ID=<service-principal-client-id>
DATABRICKS_SP_CLIENT_SECRET=<service-principal-secret>
Loading
Loading