Skip to content

Commit f2d4516

Browse files
committed
update test
1 parent c37cd01 commit f2d4516

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

tests/unit/test_token_federation.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,19 @@ def test_token_is_expired(self):
4141
self.assertFalse(token.is_expired())
4242

4343
def test_token_needs_refresh(self):
44-
"""Test Token needs_refresh method."""
44+
"""Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS."""
4545
# Token with expiry in the past
4646
past = datetime.now(tz=timezone.utc) - timedelta(hours=1)
4747
token = Token("access_token", "Bearer", expiry=past)
4848
self.assertTrue(token.needs_refresh())
4949

5050
# Token with expiry in the near future (within refresh buffer)
51-
near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60)
51+
near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1)
5252
token = Token("access_token", "Bearer", expiry=near_future)
5353
self.assertTrue(token.needs_refresh())
5454

5555
# Token with expiry far in the future
56-
far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 60)
56+
far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10)
5757
token = Token("access_token", "Bearer", expiry=far_future)
5858
self.assertFalse(token.needs_refresh())
5959

@@ -127,23 +127,19 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
127127
mock_is_same_host.return_value = False
128128
mock_detect_idp.return_value = "azure"
129129

130-
# Create mock credentials provider that can return different tokens for different calls
131-
mock_creds_provider = MagicMock()
132-
133-
# First call returns initial_token, second call returns fresh_token
130+
# Create the initial header factory
134131
initial_headers = {"Authorization": "Bearer initial_token"}
135-
fresh_headers = {"Authorization": "Bearer fresh_token"}
136-
137-
# Set up initial header factory
138132
initial_header_factory = MagicMock()
139133
initial_header_factory.return_value = initial_headers
140134

141-
# Set up fresh header factory for second call
135+
# Create the fresh header factory for later use
136+
fresh_headers = {"Authorization": "Bearer fresh_token"}
142137
fresh_header_factory = MagicMock()
143138
fresh_header_factory.return_value = fresh_headers
144139

145-
# Configure the mock to return factories
146-
mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory]
140+
# Create the credentials provider that will return the header factory
141+
mock_creds_provider = MagicMock()
142+
mock_creds_provider.return_value = initial_header_factory
147143

148144
# Set up the token federation provider
149145
federation_provider = DatabricksTokenFederationProvider(
@@ -166,16 +162,18 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
166162

167163
# Reset the mocks to track the next call
168164
mock_exchange_token.reset_mock()
169-
mock_creds_provider.reset_mock()
170-
mock_creds_provider.return_value = fresh_header_factory
171165

172166
# Now simulate an approaching expiry
173-
near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60)
167+
near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1)
174168
federation_provider.last_exchanged_token = Token(
175169
"exchanged_token_1", "Bearer", expiry=near_expiry
176170
)
177171
federation_provider.last_external_token = "initial_token"
178172

173+
# For the refresh call, we need the credentials provider to return a fresh token
174+
# Update the mock to return fresh_header_factory for the second call
175+
mock_creds_provider.return_value = fresh_header_factory
176+
179177
# Set up the mock to return a different token for the refresh
180178
mock_exchange_token.return_value = Token(
181179
"exchanged_token_2", "Bearer", expiry=future_time
@@ -184,8 +182,7 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t
184182
# Make a second call which should trigger refresh
185183
headers = headers_factory()
186184

187-
# Verify a fresh token was requested from the credentials provider
188-
# and the exchange was performed with the fresh token
185+
# Verify the exchange was performed with the fresh token
189186
mock_exchange_token.assert_called_once_with("fresh_token", "azure")
190187

191188
# Verify the headers contain the new token

0 commit comments

Comments
 (0)