@@ -344,20 +344,26 @@ def test_token_exchange_success(self, federation_provider):
344344 """Test successful token exchange."""
345345 # Mock successful response
346346 with patch ("databricks.sql.auth.token_federation.requests.post" ) as mock_post :
347+ # Create a token with a valid expiry
348+ expiry_timestamp = int (
349+ (datetime .now (tz = timezone .utc ) + timedelta (hours = 1 )).timestamp ()
350+ )
351+
347352 # Configure mock response
348353 mock_response = MagicMock ()
349354 mock_response .status_code = 200
350355 mock_response .json .return_value = {
351356 "access_token" : "new_token" ,
352357 "token_type" : "Bearer" ,
353358 "refresh_token" : "refresh_value" ,
354- "expires_in" : 3600 ,
355359 }
356360 mock_post .return_value = mock_response
357361
358- # Patch the _get_expiry_from_jwt method to return None (forcing use of expires_in)
362+ # Mock JWT expiry extraction to return a valid expiry
359363 with patch .object (
360- federation_provider , "_get_expiry_from_jwt" , return_value = None
364+ federation_provider ,
365+ "_get_expiry_from_jwt" ,
366+ return_value = datetime .fromtimestamp (expiry_timestamp , tz = timezone .utc ),
361367 ):
362368 # Call the exchange method
363369 token = federation_provider ._exchange_token ("original_token" )
@@ -367,10 +373,11 @@ def test_token_exchange_success(self, federation_provider):
367373 assert token .token_type == "Bearer"
368374 assert token .refresh_token == "refresh_value"
369375
370- # Verify expiry time (should be ~1 hour in future)
371- now = datetime .now (tz = timezone .utc )
372- assert token .expiry > now
373- assert token .expiry < now + timedelta (seconds = 3601 )
376+ # Verify expiry time is correctly set
377+ expiry_datetime = datetime .fromtimestamp (
378+ expiry_timestamp , tz = timezone .utc
379+ )
380+ assert token .expiry == expiry_datetime
374381
375382 def test_token_exchange_failure (self , federation_provider ):
376383 """Test token exchange failure handling."""
0 commit comments