@@ -35,6 +35,7 @@ def __init__(
3535 oauth_persistence = None ,
3636 credentials_provider = None ,
3737 identity_federation_client_id : Optional [str ] = None ,
38+ use_token_federation : bool = False ,
3839 ):
3940 self .hostname = hostname
4041 self .access_token = access_token
@@ -47,6 +48,7 @@ def __init__(
4748 self .oauth_persistence = oauth_persistence
4849 self .credentials_provider = credentials_provider
4950 self .identity_federation_client_id = identity_federation_client_id
51+ self .use_token_federation = use_token_federation
5052
5153
5254def get_auth_provider (cfg : ClientContext ):
@@ -71,64 +73,32 @@ def get_auth_provider(cfg: ClientContext):
7173 Raises:
7274 RuntimeError: If no valid authentication settings are provided
7375 """
74- # If credentials_provider is explicitly provided
76+ from databricks . sql . auth . token_federation import DatabricksTokenFederationProvider
7577 if cfg .credentials_provider :
76- # If token federation is enabled and credentials provider is provided,
77- # wrap the credentials provider with DatabricksTokenFederationProvider
78- if cfg .auth_type == AuthType .TOKEN_FEDERATION .value :
79- from databricks .sql .auth .token_federation import (
80- DatabricksTokenFederationProvider ,
81- )
82-
83- federation_provider = DatabricksTokenFederationProvider (
84- cfg .credentials_provider ,
85- cfg .hostname ,
86- cfg .identity_federation_client_id ,
87- )
88- return ExternalAuthProvider (federation_provider )
89-
90- # If not token federation, just use the credentials provider directly
91- return ExternalAuthProvider (cfg .credentials_provider )
92-
93- # If we don't have a credentials provider but have token federation auth type with access token
94- if cfg .auth_type == AuthType .TOKEN_FEDERATION .value and cfg .access_token :
95- # Create a simple credentials provider and wrap it with token federation provider
96- from databricks .sql .auth .token_federation import (
97- DatabricksTokenFederationProvider ,
98- SimpleCredentialsProvider ,
99- )
100-
101- simple_provider = SimpleCredentialsProvider (cfg .access_token )
102- federation_provider = DatabricksTokenFederationProvider (
103- simple_provider , cfg .hostname , cfg .identity_federation_client_id
104- )
105- return ExternalAuthProvider (federation_provider )
106-
107- if cfg .auth_type in [AuthType .DATABRICKS_OAUTH .value , AuthType .AZURE_OAUTH .value ]:
78+ base_provider = ExternalAuthProvider (cfg .credentials_provider )
79+ elif cfg .access_token is not None :
80+ base_provider = AccessTokenAuthProvider (cfg .access_token )
81+ elif cfg .auth_type in [AuthType .DATABRICKS_OAUTH .value , AuthType .AZURE_OAUTH .value ]:
10882 assert cfg .oauth_redirect_port_range is not None
10983 assert cfg .oauth_client_id is not None
11084 assert cfg .oauth_scopes is not None
111-
112- return DatabricksOAuthProvider (
85+ base_provider = DatabricksOAuthProvider (
11386 cfg .hostname ,
11487 cfg .oauth_persistence ,
11588 cfg .oauth_redirect_port_range ,
11689 cfg .oauth_client_id ,
11790 cfg .oauth_scopes ,
11891 cfg .auth_type ,
11992 )
120- elif cfg .access_token is not None :
121- return AccessTokenAuthProvider (cfg .access_token )
12293 elif cfg .use_cert_as_auth and cfg .tls_client_cert_file :
123- # no op authenticator. authentication is performed using ssl certificate outside of headers
124- return AuthProvider ()
94+ base_provider = AuthProvider ()
12595 else :
12696 if (
12797 cfg .oauth_redirect_port_range is not None
12898 and cfg .oauth_client_id is not None
12999 and cfg .oauth_scopes is not None
130100 ):
131- return DatabricksOAuthProvider (
101+ base_provider = DatabricksOAuthProvider (
132102 cfg .hostname ,
133103 cfg .oauth_persistence ,
134104 cfg .oauth_redirect_port_range ,
@@ -138,6 +108,13 @@ def get_auth_provider(cfg: ClientContext):
138108 else :
139109 raise RuntimeError ("No valid authentication settings!" )
140110
111+ if getattr (cfg , "use_token_federation" , False ):
112+ base_provider = DatabricksTokenFederationProvider (
113+ base_provider , cfg .hostname , cfg .identity_federation_client_id
114+ )
115+
116+ return base_provider
117+
141118
142119PYSQL_OAUTH_SCOPES = ["sql" , "offline_access" ]
143120PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
@@ -206,5 +183,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
206183 oauth_persistence = kwargs .get ("experimental_oauth_persistence" ),
207184 credentials_provider = kwargs .get ("credentials_provider" ),
208185 identity_federation_client_id = kwargs .get ("identity_federation_client_id" ),
186+ use_token_federation = kwargs .get ("use_token_federation" , False ),
209187 )
210188 return get_auth_provider (cfg )
0 commit comments