Skip to content

Commit 63a7f82

Browse files
committed
address comments
1 parent ae5ee50 commit 63a7f82

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/databricks/sql/auth/token_federation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ def __init__(
5656

5757
def add_headers(self, request_headers: Dict[str, str]):
5858
"""Add authentication headers to the request."""
59+
60+
if self._cached_token and not self._cached_token.is_expired():
61+
request_headers["Authorization"] = f"{self._cached_token.token_type} {self._cached_token.access_token}"
62+
return
63+
64+
# Get the external headers first to check if we need token federation
65+
self._external_headers = {}
66+
self.external_provider.add_headers(self._external_headers)
67+
68+
# If no Authorization header from external provider, pass through all headers
69+
if "Authorization" not in self._external_headers:
70+
request_headers.update(self._external_headers)
71+
return
72+
5973
token = self._get_token()
6074
request_headers["Authorization"] = f"{token.token_type} {token.access_token}"
6175

@@ -65,11 +79,7 @@ def _get_token(self) -> Token:
6579
if self._cached_token and not self._cached_token.is_expired():
6680
return self._cached_token
6781

68-
# Get the external token
69-
self._external_headers = {}
70-
self.external_provider.add_headers(self._external_headers)
71-
72-
# Extract token from Authorization header
82+
# Extract token from already-fetched headers
7383
auth_header = self._external_headers.get("Authorization", "")
7484
token_type, access_token = self._extract_token_from_header(auth_header)
7585

tests/unit/test_auth.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
164164
kwargs = {"credentials_provider": MyProvider()}
165165
mock_http_client = MagicMock()
166166
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
167-
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
167+
168+
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
169+
self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider")
168170

169171
headers = {}
170172
auth_provider.add_headers(headers)
@@ -199,8 +201,11 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
199201
hostname = "foo.cloud.databricks.com"
200202
mock_http_client = MagicMock()
201203
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
202-
self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
203-
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
204+
205+
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
206+
self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider")
207+
208+
self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
204209

205210

206211
class TestClientCredentialsTokenSource:

0 commit comments

Comments
 (0)