From 6bb13c552b38496c21c998008d5f32e79f8c8116 Mon Sep 17 00:00:00 2001
From: David
Date: Sun, 16 Nov 2025 19:18:02 -0500
Subject: [PATCH] Add service principal authentication support

Signed-off-by: David
---
 .gitignore                                   |   4 +-
 CHANGELOG.md                                 |   6 +-
 README.md                                    |  29 +-
 pyproject.toml                               |   1 +
 src/databricks/__init__.py                   |   1 +
 .../sqlalchemy/_service_principal.py         | 143 ++++++++++
 src/databricks/sqlalchemy/base.py            |  72 ++++-
 test.env.example                             |   8 +
 .../test_local/e2e/test_service_principal.py | 250 ++++++++++++++++++
 tests/unit/test_service_principal.py         | 122 +++++++++
 10 files changed, 630 insertions(+), 6 deletions(-)
 create mode 100644 src/databricks/sqlalchemy/_service_principal.py
 create mode 100644 test.env.example
 create mode 100644 tests/test_local/e2e/test_service_principal.py
 create mode 100644 tests/unit/test_service_principal.py

diff --git a/.gitignore b/.gitignore
index 4c7e499..f109094 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,6 +158,8 @@ celerybeat.pid
 
 # Environments
 .env
+.env.local
+test.env
 .venv
 env/
 venv/
@@ -206,4 +208,4 @@ poetry.toml
 
 # LSP config files
 pyrightconfig.json
-# End of https://www.toptal.com/developers/gitignore/api/python,macos
\ No newline at end of file
+# End of https://www.toptal.com/developers/gitignore/api/python,macos
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6d0b2f4..55c55db 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Release History
 
+# Unreleased
+
+- Feature: Added first-class Databricks service principal (machine-to-machine OAuth) authentication support for SQLAlchemy connections, working across AWS, Azure, and GCP workspaces (databricks/databricks-sqlalchemy#29)
+
 # 2.0.8 (2025-09-08)
 
 - Feature: Added support for variant datatype (databricks/databricks-sqlalchemy#42 by @msrathore-db)
@@ -19,4 +23,4 @@
 # 2.0.4 (2025-01-27)
 
 - All the SQLAlchemy features from `databricks-sql-connector>=4.0.0` have been moved to this `databricks-sqlalchemy` library
-- Support for SQLAlchemy v2 dialect is provided
\ No newline at end of file
+- Support for SQLAlchemy v2 dialect is provided
diff --git a/README.md b/README.md
index 4c442ad..8b55a3f 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Every SQLAlchemy application that connects to a database needs to use an [Engine
 
 1. Host
 2. HTTP Path for a compute resource
-3. API access token
+3. API access token, or service principal connection parameters
 4. Initial catalog for the connection
 5. Initial schema for the connection
 
@@ -46,6 +46,61 @@ engine = create_engine(
 )
 ```
 
+### Service principal authentication
+
+Workspaces that prohibit Personal Access Tokens can now use Databricks service principals (see the [Databricks documentation](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m) for how to create one). Supply the service principal credentials directly in the Databricks SQLAlchemy URL and set `authentication=service_principal`.
+
+```python
+import os
+from sqlalchemy import create_engine
+
+client_id = os.getenv("DATABRICKS_SP_CLIENT_ID")
+client_secret = os.getenv("DATABRICKS_SP_CLIENT_SECRET")
+host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
+http_path = os.getenv("DATABRICKS_HTTP_PATH")
+catalog = os.getenv("DATABRICKS_CATALOG")
+schema = os.getenv("DATABRICKS_SCHEMA")
+
+engine = create_engine(
+    "databricks://"
+    f"{client_id}:{client_secret}"
+    f"@{host}?http_path={http_path}&catalog={catalog}&schema={schema}"
+    "&authentication=service_principal"
+)
+```
+
+`client_id` and `client_secret` are read from the username and password components of the URL. If you prefer to keep the username as `token`, you can pass them in the query string via `client_id` and `client_secret`. By default the dialect requests the Databricks `sql` OAuth scope from the workspace's `/oidc` endpoint. You can override the scopes by providing `sp_scopes` (comma-separated) in the query string if you have custom scopes configured.
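+
+For example, keeping `token` as the username and moving the credentials into the query string instead (the hostname, HTTP path, catalog, and schema values below are illustrative):
+
+```python
+import os
+from sqlalchemy import create_engine
+
+client_id = os.getenv("DATABRICKS_SP_CLIENT_ID")
+client_secret = os.getenv("DATABRICKS_SP_CLIENT_SECRET")
+
+# `token` stays in the username slot; the dialect falls back to the
+# `client_id`/`client_secret` query parameters. `sp_scopes` overrides the
+# default `sql` scope (comma-separated; URL-encode values if needed).
+engine = create_engine(
+    "databricks://token:@example.cloud.databricks.com"
+    "?http_path=/sql/1.0/warehouses/abc123&catalog=main&schema=default"
+    "&authentication=service_principal"
+    f"&client_id={client_id}&client_secret={client_secret}"
+    "&sp_scopes=sql"
+)
+```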
+
+For local development, copy `test.env.example` to `test.env`, populate it with your workspace and service principal values, and keep `test.env` untracked (it's listed in `.gitignore`). `pytest` automatically reads this file because it is referenced in `pyproject.toml` via `env_files`.
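+
+The configuration in question looks roughly like the following sketch (assuming the `pytest-dotenv` plugin, which reads `env_files` from the pytest settings):
+
+```toml
+[tool.pytest.ini_options]
+env_files = ["test.env"]
+```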
+
 ## Types
 
 The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks.
diff --git a/pyproject.toml b/pyproject.toml
index c531857..84207e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ python = "^3.8.0"
 databricks_sql_connector = { version = ">=4.0.0"}
 pyarrow = { version = ">=14.0.1"}
 sqlalchemy = { version = ">=2.0.21" }
+requests = { version = ">=2.31.0,<3.0.0" }
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.1.2"
diff --git a/src/databricks/__init__.py b/src/databricks/__init__.py
index e69de29..8db66d3 100644
--- a/src/databricks/__init__.py
+++ b/src/databricks/__init__.py
@@ -0,0 +1 @@
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/src/databricks/sqlalchemy/_service_principal.py b/src/databricks/sqlalchemy/_service_principal.py
new file mode 100644
index 0000000..9d05ae3
--- /dev/null
+++ b/src/databricks/sqlalchemy/_service_principal.py
@@ -0,0 +1,143 @@
+import threading
+import time
+from typing import Callable, Dict, Iterable, List, Optional
+
+import requests
+from databricks.sql.auth.authenticators import CredentialsProvider
+from databricks.sql.auth.endpoint import get_oauth_endpoints
+
+
+class ServicePrincipalConfigurationError(ValueError):
+    """Raised when the service principal configuration is incomplete."""
+
+
+class ServicePrincipalAuthenticationError(RuntimeError):
+    """Raised when fetching an OAuth token fails."""
+
+
+def _normalize_hostname(hostname: str) -> str:
+    # Return the scheme-prefixed form regardless of whether the caller
+    # already included the scheme, so downstream URLs are consistent.
+    if not hostname.startswith("https://"):
+        hostname = f"https://{hostname}"
+    return hostname.rstrip("/")
+
+
+class ServicePrincipalCredentialsProvider(CredentialsProvider):
+    """CredentialsProvider that performs the Databricks OAuth client credentials flow."""
+
+    DEFAULT_SCOPES = ("sql",)
+
+    def __init__(
+        self,
+        server_hostname: str,
+        client_id: str,
+        client_secret: str,
+        *,
+        scopes: Optional[Iterable[str]] = None,
+        refresh_margin: int = 60,
+        request_timeout: int = 10,
+    ):
+        if not server_hostname:
+            raise ServicePrincipalConfigurationError("server_hostname is required")
+        if not client_id:
+            raise ServicePrincipalConfigurationError("client_id is required")
+        if not client_secret:
+            raise ServicePrincipalConfigurationError("client_secret is required")
+
+        self._hostname = _normalize_hostname(server_hostname)
+        oauth_endpoints = get_oauth_endpoints(self._hostname, use_azure_auth=False)
+        if not oauth_endpoints:
+            raise ServicePrincipalConfigurationError(
+                f"Unable to determine OAuth endpoints for host {server_hostname}"
+            )
+
+        scope_tuple = tuple(scopes) if scopes else self.DEFAULT_SCOPES
+        mapped_scopes = oauth_endpoints.get_scopes_mapping(list(scope_tuple))
+
+        self._client_id = client_id
+        self._client_secret = client_secret
+        self._scopes: List[str] = mapped_scopes
+        self._refresh_margin = refresh_margin
+        self._request_timeout = request_timeout
+        self._access_token: Optional[str] = None
+        self._expires_at: float = 0
+        self._lock = threading.Lock()
+        self._token_endpoint = self._discover_token_endpoint(oauth_endpoints)
+
+    def auth_type(self) -> str:
+        return "databricks-service-principal"
+
+    def __call__(self) -> Callable[[], Dict[str, str]]:
+        def header_factory() -> Dict[str, str]:
+            access_token = self._get_token()
+            return {"Authorization": f"Bearer {access_token}"}
+
+        return header_factory
+
+    def _discover_token_endpoint(self, oauth_endpoints) -> str:
+        openid_config_url = oauth_endpoints.get_openid_config_url(self._hostname)
+        try:
+            response = requests.get(openid_config_url, timeout=self._request_timeout)
+            response.raise_for_status()
+            config = response.json()
+        except Exception as exc:
+            raise ServicePrincipalAuthenticationError(
+                "Failed to load Databricks OAuth configuration"
+            ) from exc
+
+        token_endpoint = config.get("token_endpoint")
+        if not token_endpoint:
+            raise ServicePrincipalAuthenticationError(
+                "OAuth configuration did not include a token endpoint"
+            )
+        return token_endpoint
+
+    def _needs_refresh(self) -> bool:
+        if not self._access_token:
+            return True
+        now = time.time()
+        return now >= (self._expires_at - self._refresh_margin)
+
+    def _get_token(self) -> str:
+        with self._lock:
+            if self._needs_refresh():
+                self._refresh_token()
+            assert self._access_token
+            return self._access_token
+
+    def _refresh_token(self) -> None:
+        # Client credentials grant: exchange the id/secret for a bearer token.
+        payload = {
+            "grant_type": "client_credentials",
+            "client_id": self._client_id,
+            "client_secret": self._client_secret,
+            "scope": " ".join(self._scopes),
+        }
+
+        response = requests.post(
+            self._token_endpoint, data=payload, timeout=self._request_timeout
+        )
+        try:
+            response.raise_for_status()
+        except Exception as exc:
+            raise ServicePrincipalAuthenticationError(
+                "Failed to retrieve OAuth token for service principal"
+            ) from exc
+
+        try:
+            parsed = response.json()
+            access_token = parsed["access_token"]
+        except Exception as exc:  # pragma: no cover - defensive
+            raise ServicePrincipalAuthenticationError(
+                "OAuth response did not include an access token"
+            ) from exc
+
+        expires_in_raw = parsed.get("expires_in", 3600)
+        try:
+            expires_in = int(expires_in_raw)
+        except (TypeError, ValueError):
+            expires_in = 3600
+
+        self._access_token = access_token
+        self._expires_at = time.time() + max(expires_in, self._refresh_margin + 1)
diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py
index 9148de7..ed0aa12 100644
--- a/src/databricks/sqlalchemy/base.py
+++ b/src/databricks/sqlalchemy/base.py
@@ -13,6 +13,10 @@
     get_comment_from_dte_output,
     parse_column_info_from_tgetcolumnsresponse,
 )
+from databricks.sqlalchemy._service_principal import (
+    ServicePrincipalConfigurationError,
+    ServicePrincipalCredentialsProvider,
+)
 
 import sqlalchemy
 from sqlalchemy import DDL, event
@@ -24,7 +28,7 @@
     ReflectedTableComment,
 )
 from sqlalchemy.engine.reflection import ReflectionDefaults
-from sqlalchemy.exc import DatabaseError, SQLAlchemyError
+from sqlalchemy.exc import ArgumentError, DatabaseError, SQLAlchemyError
 
 try:
     import alembic
@@ -45,6 +49,14 @@ class DatabricksImpl(DefaultImpl):
 class DatabricksDialect(default.DefaultDialect):
     """This dialect implements only those methods required to pass our e2e tests"""
 
+    _SERVICE_PRINCIPAL_ALIASES = {
+        "serviceprincipal",
+        "service_principal",
+        "service-principal",
+        "serviceprincipal-auth",
+        "sp",
+    }
+
     # See sqlalchemy.engine.interfaces for descriptions of each of these properties
     name: str = "databricks"
     driver: str = "databricks"
@@ -105,15 +117,21 @@ def create_connect_args(self, url):
         # TODO: can schema be provided after HOST?
         # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/***
 
-        kwargs = {
+        credentials_provider = self._build_service_principal_provider(url)
+
+        kwargs: Dict[str, Any] = {
             "server_hostname": url.host,
-            "access_token": url.password,
             "http_path": url.query.get("http_path"),
             "catalog": url.query.get("catalog"),
             "schema": url.query.get("schema"),
             "use_inline_params": False,
         }
 
+        if credentials_provider:
+            kwargs["credentials_provider"] = credentials_provider
+        else:
+            kwargs["access_token"] = url.password
+
         self.schema = kwargs["schema"]
         self.catalog = kwargs["catalog"]
 
@@ -121,6 +139,54 @@ def create_connect_args(self, url):
 
         return [], kwargs
 
+    def _build_service_principal_provider(
+        self, url
+    ) -> Optional[ServicePrincipalCredentialsProvider]:
+        auth_value = (
+            url.query.get("authentication")
+            or url.query.get("auth")
+            or url.query.get("auth_type")
+        )
+        username_hint = (url.username or "").lower()
+        is_service_principal = False
+
+        if auth_value and auth_value.lower() in self._SERVICE_PRINCIPAL_ALIASES:
+            is_service_principal = True
+        elif username_hint in self._SERVICE_PRINCIPAL_ALIASES:
+            is_service_principal = True
+
+        if not is_service_principal:
+            return None
+
+        client_id = url.query.get("client_id") or url.username
+        client_secret = url.password or url.query.get("client_secret")
+        if not client_id:
+            raise ArgumentError("Service principal connections require a client_id")
+        if not client_secret:
+            raise ArgumentError("Service principal connections require a client_secret")
+
+        scopes_raw = (
+            url.query.get("sp_scopes")
+            or url.query.get("sp_scope")
+            or url.query.get("scope")
+        )
+        scopes = (
+            [scope.strip() for scope in scopes_raw.split(",") if scope.strip()]
+            if scopes_raw
+            else None
+        )
+        try:
+            provider = ServicePrincipalCredentialsProvider(
+                server_hostname=url.host or "",
+                client_id=client_id,
+                client_secret=client_secret,
+                scopes=scopes,
+            )
+        except ServicePrincipalConfigurationError as exc:
+            raise ArgumentError(str(exc)) from exc
+
+        return provider
+
     def get_columns(
         self, connection, table_name, schema=None, **kwargs
     ) -> List[ReflectedColumn]:
diff --git a/test.env.example b/test.env.example
new file mode 100644
index 0000000..322885e
--- /dev/null
+++ b/test.env.example
@@ -0,0 +1,8 @@
+# Copy this file to `test.env` (which is .gitignored) and fill in the values
+# with real credentials before running local tests.
+DATABRICKS_SERVER_HOSTNAME=
+DATABRICKS_HTTP_PATH=
+DATABRICKS_CATALOG=
+DATABRICKS_SCHEMA=
+DATABRICKS_SP_CLIENT_ID=
+DATABRICKS_SP_CLIENT_SECRET=
diff --git a/tests/test_local/e2e/test_service_principal.py b/tests/test_local/e2e/test_service_principal.py
new file mode 100644
index 0000000..50ec8eb
--- /dev/null
+++ b/tests/test_local/e2e/test_service_principal.py
@@ -0,0 +1,250 @@
+import os
+import time
+import uuid
+
+import pytest
+from sqlalchemy import (
+    BigInteger,
+    Identity,
+    String,
+    create_engine,
+    func,
+    inspect,
+    insert,
+    select,
+    text,
+)
+from sqlalchemy.dialects import registry
+from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
+
+from databricks.sqlalchemy.base import DatabricksDialect
+
+registry.register("databricks", "databricks.sqlalchemy", "DatabricksDialect")
+
+REQUIRED_ENV = [
+    "DATABRICKS_SERVER_HOSTNAME",
+    "DATABRICKS_HTTP_PATH",
+    "DATABRICKS_CATALOG",
+    "DATABRICKS_SCHEMA",
+    "DATABRICKS_SP_CLIENT_ID",
+    "DATABRICKS_SP_CLIENT_SECRET",
+]
+
+
+def _load_env():
+    env = {key: os.getenv(key) for key in REQUIRED_ENV}
+    missing = [key for key, value in env.items() if not value]
+    if missing:
+        pytest.skip(
+            f"Service principal env vars missing: {', '.join(missing)}. "
+            "Populate test.env with workspace and service principal credentials."
+        )
+    return env
+
+
+def _build_url(env, extra_params: str = "") -> str:
+    base = (
+        "databricks://"
+        f"{env['DATABRICKS_SP_CLIENT_ID']}:{env['DATABRICKS_SP_CLIENT_SECRET']}"
+        f"@{env['DATABRICKS_SERVER_HOSTNAME']}"
+        f"?http_path={env['DATABRICKS_HTTP_PATH']}"
+        f"&catalog={env['DATABRICKS_CATALOG']}"
+        f"&schema={env['DATABRICKS_SCHEMA']}"
+        "&authentication=service_principal"
+    )
+    if extra_params:
+        base += "&" + extra_params.lstrip("&")
+    return base
+
+
+def _fully_qualified(env, table_name: str) -> str:
+    return (
+        f"`{env['DATABRICKS_CATALOG']}`."
+        f"`{env['DATABRICKS_SCHEMA']}`."
+        f"`{table_name}`"
+    )
+
+
+def _random_table_name(prefix: str) -> str:
+    return f"{prefix}_{uuid.uuid4().hex[:8]}"
+
+
+@pytest.fixture(scope="session")
+def sp_env():
+    return _load_env()
+
+
+@pytest.fixture
+def sp_engine_factory(sp_env):
+    engines = []
+
+    def factory(extra_params: str = ""):
+        engine = create_engine(_build_url(sp_env, extra_params))
+        engines.append(engine)
+        return engine
+
+    yield factory
+
+    for engine in engines:
+        engine.dispose()
+
+
+def _drop_table(engine, fq_name: str):
+    with engine.begin() as conn:
+        conn.execute(text(f"DROP TABLE IF EXISTS {fq_name}"))
+
+
+def _create_identity_table(engine, fq_name: str):
+    with engine.begin() as conn:
+        conn.execute(
+            text(
+                f"""
+                CREATE TABLE {fq_name} (
+                    id BIGINT GENERATED ALWAYS AS IDENTITY,
+                    value STRING
+                )
+                """
+            )
+        )
+
+
+def test_sp_basic_crud(sp_env, sp_engine_factory):
+    engine = sp_engine_factory()
+    table_name = _random_table_name("sp_basic")
+    fq_name = _fully_qualified(sp_env, table_name)
+
+    try:
+        _drop_table(engine, fq_name)
+        _create_identity_table(engine, fq_name)
+        with engine.begin() as conn:
+            conn.execute(
+                text(f"INSERT INTO {fq_name} (value) VALUES (:value)"),
+                {"value": "hello"},
+            )
+            rows = conn.execute(text(f"SELECT COUNT(*) FROM {fq_name}")).scalar_one()
+            assert rows == 1
+
+            conn.execute(
+                text(f"DELETE FROM {fq_name} WHERE value = :value"),
+                {"value": "hello"},
+            )
+            rows = conn.execute(text(f"SELECT COUNT(*) FROM {fq_name}")).scalar_one()
+            assert rows == 0
+    finally:
+        _drop_table(engine, fq_name)
+
+
+def test_sp_scope_override(sp_env, sp_engine_factory):
+    engine = sp_engine_factory("sp_scopes=sql")
+    try:
+        with engine.connect() as conn:
+            result = conn.execute(text("SELECT 1")).scalar_one()
+            assert result == 1
+    finally:
+        engine.dispose()
+
+
+def test_sp_reflection(sp_env, sp_engine_factory):
+    engine = sp_engine_factory()
+    table_name = _random_table_name("sp_reflect")
+    fq_name = _fully_qualified(sp_env, table_name)
+
+    try:
+        _drop_table(engine, fq_name)
+        _create_identity_table(engine, fq_name)
+        inspector = inspect(engine)
+        columns = inspector.get_columns(table_name, schema=sp_env["DATABRICKS_SCHEMA"])
+        id_column = next(col for col in columns if col["name"] == "id")
+        assert id_column["type"].__class__.__name__.lower().startswith("bigint")
+    finally:
+        _drop_table(engine, fq_name)
+
+
+def test_sp_token_refresh_allows_followup_insert(
+    sp_env, sp_engine_factory, monkeypatch
+):
+    captured = {}
+    original = DatabricksDialect._build_service_principal_provider
+
+    def wrapper(self, url):
+        provider = original(self, url)
+        captured["provider"] = provider
+        return provider
+
+    monkeypatch.setattr(DatabricksDialect, "_build_service_principal_provider", wrapper)
+    engine = sp_engine_factory()
+    table_name = _random_table_name("sp_refresh")
+    fq_name = _fully_qualified(sp_env, table_name)
+
+    try:
+        _drop_table(engine, fq_name)
+        _create_identity_table(engine, fq_name)
+        with engine.begin() as conn:
+            conn.execute(
+                text(f"INSERT INTO {fq_name} (value) VALUES (:value)"),
+                {"value": "first"},
+            )
+
+        provider = captured.get("provider")
+        assert provider is not None, "credentials provider was not captured"
+        provider._expires_at = time.time() - 10
+
+        with engine.begin() as conn:
+            conn.execute(
+                text(f"INSERT INTO {fq_name} (value) VALUES (:value)"),
+                {"value": "second"},
+            )
+            rows = conn.execute(text(f"SELECT COUNT(*) FROM {fq_name}")).scalar_one()
+            assert rows == 2
+    finally:
+        _drop_table(engine, fq_name)
+
+
+def test_sp_orm_identity_roundtrip(sp_env, sp_engine_factory):
+    engine = sp_engine_factory()
+    table_name = _random_table_name("sp_orm")
+
+    class Base(DeclarativeBase):
+        pass
+
+    schema = sp_env["DATABRICKS_SCHEMA"]
+
+    class OrmRecord(Base):
+        __tablename__ = table_name
+        __table_args__ = {"schema": schema}
+
+        id: Mapped[int] = mapped_column(
+            BigInteger, Identity(always=True), primary_key=True
+        )
+        value: Mapped[str] = mapped_column(String(100))
+
+    try:
+        Base.metadata.drop_all(engine)
+        Base.metadata.create_all(engine)
+
+        with Session(engine) as session:
+            session.execute(insert(OrmRecord).values(value="orm-test"))
+            session.commit()
+
+        with Session(engine) as session:
+            inserted_id = session.scalar(
+                select(OrmRecord.id)
+                .where(OrmRecord.value == "orm-test")
+                .order_by(OrmRecord.id.desc())
+                .limit(1)
+            )
+            assert inserted_id is not None
+
+            fetched = session.get(OrmRecord, inserted_id)
+            assert fetched is not None
+            assert fetched.value == "orm-test"
+
+            session.delete(fetched)
+            session.commit()
+
+            remaining = session.scalar(select(func.count()).select_from(OrmRecord))
+            assert remaining == 0
+    finally:
+        Base.metadata.drop_all(engine)
diff --git a/tests/unit/test_service_principal.py b/tests/unit/test_service_principal.py
new file mode 100644
index 0000000..4335591
--- /dev/null
+++ b/tests/unit/test_service_principal.py
@@ -0,0 +1,122 @@
+from typing import Dict
+
+import pytest
+from sqlalchemy.engine import make_url
+from sqlalchemy.exc import ArgumentError
+
+import databricks.sqlalchemy._service_principal as sp
+from databricks.sqlalchemy.base import DatabricksDialect
+from databricks.sqlalchemy._service_principal import (
+    ServicePrincipalConfigurationError,
+    ServicePrincipalCredentialsProvider,
+)
+
+
+class DummyResponse:
+    def __init__(self, payload: Dict[str, object], status_code: int = 200):
+        self._payload = payload
+        self.status_code = status_code
+
+    def raise_for_status(self):
+        if not 200 <= self.status_code < 300:
+            raise RuntimeError("request failed")
+        return None
+
+    def json(self):
+        return self._payload
+
+
+def test_service_principal_provider_refreshes_tokens(monkeypatch):
+    responses = iter(
+        [
+            {"access_token": "token-one", "expires_in": 5},
+            {"access_token": "token-two", "expires_in": 5},
+        ]
+    )
+    call_count = {"value": 0}
+    token_endpoint = "https://dbc.cloud.databricks.com/oidc/oauth2/v2.0/token"
+
+    def fake_get(url, timeout):
+        assert "oidc" in url
+        return DummyResponse({"token_endpoint": token_endpoint})
+
+    def fake_post(url, data, timeout):
+        assert url == token_endpoint
+        call_count["value"] += 1
+        assert data["client_id"] == "client-id"
+        return DummyResponse(next(responses))
+
+    current_time = {"value": 0}
+
+    def fake_time():
+        return current_time["value"]
+
+    monkeypatch.setattr(sp.requests, "post", fake_post)
+    monkeypatch.setattr(sp.requests, "get", fake_get)
+    monkeypatch.setattr(sp.time, "time", fake_time)
+
+    provider = sp.ServicePrincipalCredentialsProvider(
+        "dbc.cloud.databricks.com",
+        "client-id",
+        "secret",
+        refresh_margin=0,
+    )
+    header_factory = provider()
+    assert header_factory()["Authorization"] == "Bearer token-one"
+    assert call_count["value"] == 1
+
+    current_time["value"] = 2
+    assert header_factory()["Authorization"] == "Bearer token-one"
+    assert call_count["value"] == 1
+
+    current_time["value"] = 6
+    assert header_factory()["Authorization"] == "Bearer token-two"
+    assert call_count["value"] == 2
+
+
+def test_service_principal_provider_requires_hostname():
+    with pytest.raises(ServicePrincipalConfigurationError):
+        sp.ServicePrincipalCredentialsProvider(
+            server_hostname="",
+            client_id="client",
+            client_secret="secret",
+        )
+
+
+def test_create_connect_args_with_service_principal(monkeypatch):
+    token_endpoint = "https://dbc.cloud.databricks.com/oidc/oauth2/v2.0/token"
+
+    def fake_get(url, timeout):
+        return DummyResponse({"token_endpoint": token_endpoint})
+
+    def fake_post(url, data, timeout):
+        return DummyResponse({"access_token": "token", "expires_in": 3600})
+
+    monkeypatch.setattr(sp.requests, "get", fake_get)
+    monkeypatch.setattr(sp.requests, "post", fake_post)
+
+    dialect = DatabricksDialect()
+    url = make_url(
+        "databricks://client-id:client-secret@acme.cloud.databricks.com"
+        "?http_path=/sql/1&catalog=main&schema=test"
+        "&authentication=service_principal"
+    )
+
+    _, kwargs = dialect.create_connect_args(url)
+
+    assert "access_token" not in kwargs
+    assert isinstance(
+        kwargs.get("credentials_provider"), ServicePrincipalCredentialsProvider
+    )
+    assert kwargs["server_hostname"] == "acme.cloud.databricks.com"
+
+
+def test_create_connect_args_requires_secret():
+    dialect = DatabricksDialect()
+    url = make_url(
+        "databricks://client-id@acme.cloud.databricks.com"
+        "?http_path=/sql/1&catalog=main&schema=test&authentication=service_principal"
+    )
+
+    with pytest.raises(ArgumentError, match="client_secret"):
+        dialect.create_connect_args(url)