@@ -41,19 +41,19 @@ def test_token_is_expired(self):
4141 self .assertFalse (token .is_expired ())
4242
4343 def test_token_needs_refresh (self ):
44- """Test Token needs_refresh method."""
44+ """Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS ."""
4545 # Token with expiry in the past
4646 past = datetime .now (tz = timezone .utc ) - timedelta (hours = 1 )
4747 token = Token ("access_token" , "Bearer" , expiry = past )
4848 self .assertTrue (token .needs_refresh ())
4949
5050 # Token with expiry in the near future (within refresh buffer)
51- near_future = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS - 60 )
51+ near_future = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS - 1 )
5252 token = Token ("access_token" , "Bearer" , expiry = near_future )
5353 self .assertTrue (token .needs_refresh ())
5454
5555 # Token with expiry far in the future
56- far_future = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS + 60 )
56+ far_future = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS + 10 )
5757 token = Token ("access_token" , "Bearer" , expiry = far_future )
5858 self .assertFalse (token .needs_refresh ())
5959
@@ -127,23 +127,19 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
127127 mock_is_same_host .return_value = False
128128 mock_detect_idp .return_value = "azure"
129129
130- # Create mock credentials provider that can return different tokens for different calls
131- mock_creds_provider = MagicMock ()
132-
133- # First call returns initial_token, second call returns fresh_token
130+ # Create the initial header factory
134131 initial_headers = {"Authorization" : "Bearer initial_token" }
135- fresh_headers = {"Authorization" : "Bearer fresh_token" }
136-
137- # Set up initial header factory
138132 initial_header_factory = MagicMock ()
139133 initial_header_factory .return_value = initial_headers
140134
141- # Set up fresh header factory for second call
135+ # Create the fresh header factory for later use
136+ fresh_headers = {"Authorization" : "Bearer fresh_token" }
142137 fresh_header_factory = MagicMock ()
143138 fresh_header_factory .return_value = fresh_headers
144139
145- # Configure the mock to return factories
146- mock_creds_provider .side_effect = [initial_header_factory , fresh_header_factory ]
140+ # Create the credentials provider that will return the header factory
141+ mock_creds_provider = MagicMock ()
142+ mock_creds_provider .return_value = initial_header_factory
147143
148144 # Set up the token federation provider
149145 federation_provider = DatabricksTokenFederationProvider (
@@ -166,16 +162,18 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
166162
167163 # Reset the mocks to track the next call
168164 mock_exchange_token .reset_mock ()
169- mock_creds_provider .reset_mock ()
170- mock_creds_provider .return_value = fresh_header_factory
171165
172166 # Now simulate an approaching expiry
173- near_expiry = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS - 60 )
167+ near_expiry = datetime .now (tz = timezone .utc ) + timedelta (seconds = TOKEN_REFRESH_BUFFER_SECONDS - 1 )
174168 federation_provider .last_exchanged_token = Token (
175169 "exchanged_token_1" , "Bearer" , expiry = near_expiry
176170 )
177171 federation_provider .last_external_token = "initial_token"
178172
173+ # For the refresh call, we need the credentials provider to return a fresh token
174+ # Update the mock to return fresh_header_factory for the second call
175+ mock_creds_provider .return_value = fresh_header_factory
176+
179177 # Set up the mock to return a different token for the refresh
180178 mock_exchange_token .return_value = Token (
181179 "exchanged_token_2" , "Bearer" , expiry = future_time
@@ -184,8 +182,7 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
184182 # Make a second call which should trigger refresh
185183 headers = headers_factory ()
186184
187- # Verify a fresh token was requested from the credentials provider
188- # and the exchange was performed with the fresh token
185+ # Verify the exchange was performed with the fresh token
189186 mock_exchange_token .assert_called_once_with ("fresh_token" , "azure" )
190187
191188 # Verify the headers contain the new token
0 commit comments