Skip to content

Commit 5040569

Browse files
committed
initial commit
1 parent 85d0cd9 commit 5040569

File tree

1 file changed

+18
-40
lines changed

1 file changed

+18
-40
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
oauth_persistence=None,
3636
credentials_provider=None,
3737
identity_federation_client_id: Optional[str] = None,
38+
use_token_federation: bool = False,
3839
):
3940
self.hostname = hostname
4041
self.access_token = access_token
@@ -47,6 +48,7 @@ def __init__(
4748
self.oauth_persistence = oauth_persistence
4849
self.credentials_provider = credentials_provider
4950
self.identity_federation_client_id = identity_federation_client_id
51+
self.use_token_federation = use_token_federation
5052

5153

5254
def get_auth_provider(cfg: ClientContext):
@@ -71,64 +73,32 @@ def get_auth_provider(cfg: ClientContext):
7173
Raises:
7274
RuntimeError: If no valid authentication settings are provided
7375
"""
74-
# If credentials_provider is explicitly provided
76+
from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider
7577
if cfg.credentials_provider:
76-
# If token federation is enabled and credentials provider is provided,
77-
# wrap the credentials provider with DatabricksTokenFederationProvider
78-
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value:
79-
from databricks.sql.auth.token_federation import (
80-
DatabricksTokenFederationProvider,
81-
)
82-
83-
federation_provider = DatabricksTokenFederationProvider(
84-
cfg.credentials_provider,
85-
cfg.hostname,
86-
cfg.identity_federation_client_id,
87-
)
88-
return ExternalAuthProvider(federation_provider)
89-
90-
# If not token federation, just use the credentials provider directly
91-
return ExternalAuthProvider(cfg.credentials_provider)
92-
93-
# If we don't have a credentials provider but have token federation auth type with access token
94-
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token:
95-
# Create a simple credentials provider and wrap it with token federation provider
96-
from databricks.sql.auth.token_federation import (
97-
DatabricksTokenFederationProvider,
98-
SimpleCredentialsProvider,
99-
)
100-
101-
simple_provider = SimpleCredentialsProvider(cfg.access_token)
102-
federation_provider = DatabricksTokenFederationProvider(
103-
simple_provider, cfg.hostname, cfg.identity_federation_client_id
104-
)
105-
return ExternalAuthProvider(federation_provider)
106-
107-
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
78+
base_provider = ExternalAuthProvider(cfg.credentials_provider)
79+
elif cfg.access_token is not None:
80+
base_provider = AccessTokenAuthProvider(cfg.access_token)
81+
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
10882
assert cfg.oauth_redirect_port_range is not None
10983
assert cfg.oauth_client_id is not None
11084
assert cfg.oauth_scopes is not None
111-
112-
return DatabricksOAuthProvider(
85+
base_provider = DatabricksOAuthProvider(
11386
cfg.hostname,
11487
cfg.oauth_persistence,
11588
cfg.oauth_redirect_port_range,
11689
cfg.oauth_client_id,
11790
cfg.oauth_scopes,
11891
cfg.auth_type,
11992
)
120-
elif cfg.access_token is not None:
121-
return AccessTokenAuthProvider(cfg.access_token)
12293
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
123-
# no op authenticator. authentication is performed using ssl certificate outside of headers
124-
return AuthProvider()
94+
base_provider = AuthProvider()
12595
else:
12696
if (
12797
cfg.oauth_redirect_port_range is not None
12898
and cfg.oauth_client_id is not None
12999
and cfg.oauth_scopes is not None
130100
):
131-
return DatabricksOAuthProvider(
101+
base_provider = DatabricksOAuthProvider(
132102
cfg.hostname,
133103
cfg.oauth_persistence,
134104
cfg.oauth_redirect_port_range,
@@ -138,6 +108,13 @@ def get_auth_provider(cfg: ClientContext):
138108
else:
139109
raise RuntimeError("No valid authentication settings!")
140110

111+
if getattr(cfg, "use_token_federation", False):
112+
base_provider = DatabricksTokenFederationProvider(
113+
base_provider, cfg.hostname, cfg.identity_federation_client_id
114+
)
115+
116+
return base_provider
117+
141118

142119
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
143120
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
@@ -206,5 +183,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
206183
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
207184
credentials_provider=kwargs.get("credentials_provider"),
208185
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
186+
use_token_federation=kwargs.get("use_token_federation", False),
209187
)
210188
return get_auth_provider(cfg)

0 commit comments

Comments
 (0)