11import abc
22import logging
3- from typing import Callable , Dict , List , Optional
3+ from typing import Callable , Dict , List
44from databricks .sql .common .http import HttpHeader
55from databricks .sql .auth .oauth import (
66 OAuthManager ,
99)
1010from databricks .sql .auth .endpoint import get_oauth_endpoints
1111from databricks .sql .auth .common import AuthType , get_effective_azure_login_app_id
12+ from databricks .sdk import WorkspaceClient
1213
1314# Private API: this is an evolving interface and it will change in the future.
1415# Please must not depend on it in your applications.
1516from databricks .sql .experimental .oauth_persistence import OAuthToken , OAuthPersistence
1617
17- logger = logging .getLogger (__name__ )
1818
1919
2020class AuthProvider :
@@ -189,6 +189,13 @@ def __init__(
189189 azure_tenant_id ,
190190 azure_workspace_resource_id = None ,
191191 ):
192+ self .workspace_api_client = WorkspaceClient (
193+ host = hostname ,
194+ azure_workspace_resource_id = azure_workspace_resource_id ,
195+ azure_tenant_id = azure_tenant_id ,
196+ azure_client_id = oauth_client_id ,
197+ azure_client_secret = oauth_client_secret ,
198+ )
192199 self .hostname = hostname
193200 self .oauth_client_id = oauth_client_id
194201 self .oauth_client_secret = oauth_client_secret
@@ -207,25 +214,26 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
207214 )
208215
209216 def __call__ (self , * args , ** kwargs ) -> HeaderFactory :
210- inner = self .get_token_source (
211- resource = get_effective_azure_login_app_id (self .hostname )
212- )
213- cloud = self .get_token_source (resource = self .AZURE_MANAGED_RESOURCE )
217+ # inner = self.get_token_source(
218+ # resource=get_effective_azure_login_app_id(self.hostname)
219+ # )
220+ # cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
214221
215222 def header_factory () -> Dict [str , str ]:
216- inner_token = inner .get_token ()
217- cloud_token = cloud .get_token ()
223+ # inner_token = inner.get_token()
224+ # cloud_token = cloud.get_token()
218225
219- headers = {
220- HttpHeader .AUTHORIZATION .value : f"{ inner_token .token_type } { inner_token .access_token } " ,
221- self .DATABRICKS_AZURE_SP_TOKEN_HEADER : cloud_token .access_token ,
222- }
226+ # headers = {
227+ # HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
228+ # self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
229+ # }
223230
224- if self .azure_workspace_resource_id :
225- headers [
226- self .DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
227- ] = self .azure_workspace_resource_id
231+ # if self.azure_workspace_resource_id:
232+ # headers[
233+ # self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
234+ # ] = self.azure_workspace_resource_id
228235
229- return headers
236+ # return headers
237+ return self .workspace_api_client .config .authenticate ()
230238
231239 return header_factory
0 commit comments