1+ #!/usr/bin/env python3
2+
3+ """
4+ Unit tests for the JDBC-style token refresh in Databricks SQL connector.
5+
6+ This test verifies that the token federation implementation follows the JDBC driver's approach
7+ of getting a fresh external token before exchanging it for a Databricks token during refresh.
8+ """
9+
10+ import unittest
11+ from unittest .mock import patch , MagicMock
12+ from datetime import datetime , timezone , timedelta
13+
14+ from databricks .sql .auth .token_federation import (
15+ DatabricksTokenFederationProvider ,
16+ Token
17+ )
18+
19+
20+ class RefreshingCredentialsProvider :
21+ """
22+ A credentials provider that returns different tokens on each call.
23+ This simulates providers like Azure AD that can refresh their tokens.
24+ """
25+
26+ def __init__ (self ):
27+ self .call_count = 0
28+
29+ def auth_type (self ):
30+ return "bearer"
31+
32+ def __call__ (self , * args , ** kwargs ):
33+ def get_headers ():
34+ self .call_count += 1
35+ # Return a different token each time to simulate fresh tokens
36+ return {"Authorization" : f"Bearer fresh_token_{ self .call_count } " }
37+ return get_headers
38+
39+
40+ class TestJdbcStyleTokenRefresh (unittest .TestCase ):
41+ """Tests for the JDBC-style token refresh implementation."""
42+
43+ @patch ('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims' )
44+ @patch ('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token' )
45+ @patch ('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host' )
46+ def test_refresh_gets_fresh_token (self , mock_is_same_host , mock_exchange_token , mock_parse_jwt ):
47+ """Test that token refresh first gets a fresh external token."""
48+ # Set up mocks
49+ mock_parse_jwt .return_value = {"iss" : "https://login.microsoftonline.com/tenant" }
50+ mock_is_same_host .return_value = False
51+
52+ # Create a credentials provider that returns different tokens on each call
53+ refreshing_provider = RefreshingCredentialsProvider ()
54+
55+ # Set up the token federation provider
56+ federation_provider = DatabricksTokenFederationProvider (
57+ refreshing_provider , "example.com" , "client_id"
58+ )
59+
60+ # Set up mock for token exchange
61+ future_time = datetime .now (tz = timezone .utc ) + timedelta (hours = 1 )
62+ mock_exchange_token .return_value = Token (
63+ "exchanged_token_1" , "Bearer" , expiry = future_time
64+ )
65+
66+ # First call to get initial headers and token
67+ headers_factory = federation_provider ()
68+ headers = headers_factory ()
69+
70+ # Verify the first exchange happened
71+ mock_exchange_token .assert_called_with ("fresh_token_1" , "azure" )
72+ self .assertEqual (headers ["Authorization" ], "Bearer exchanged_token_1" )
73+ self .assertEqual (refreshing_provider .call_count , 1 )
74+
75+ # Reset the mock to track the next call
76+ mock_exchange_token .reset_mock ()
77+
78+ # Now simulate an approaching expiry
79+ near_expiry = datetime .now (tz = timezone .utc ) + timedelta (minutes = 4 )
80+ federation_provider .last_exchanged_token = Token (
81+ "exchanged_token_1" , "Bearer" , expiry = near_expiry
82+ )
83+ federation_provider .last_external_token = "fresh_token_1"
84+
85+ # Set up the mock to return a different token for the refresh
86+ mock_exchange_token .return_value = Token (
87+ "exchanged_token_2" , "Bearer" , expiry = future_time
88+ )
89+
90+ # Make a second call which should trigger refresh
91+ headers = headers_factory ()
92+
93+ # With JDBC-style implementation:
94+ # 1. Should call credentials provider again to get fresh token
95+ self .assertEqual (refreshing_provider .call_count , 2 )
96+
97+ # 2. Should exchange the FRESH token (fresh_token_2), not the stored one (fresh_token_1)
98+ mock_exchange_token .assert_called_once_with ("fresh_token_2" , "azure" )
99+
100+ # 3. Should return headers with the new Databricks token
101+ self .assertEqual (headers ["Authorization" ], "Bearer exchanged_token_2" )
102+
103+
104+ if __name__ == "__main__" :
105+ unittest .main ()
0 commit comments