Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,103 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
except StopAsyncIteration:
pass # Expected - generator should complete

@pytest.mark.anyio
async def test_prm_endpoint_not_implemented_fallthrough(self, oauth_provider: OAuthClientProvider):
"""Test that PRM endpoint failures fall through without raising errors (backward compatibility)."""
# Ensure no tokens are stored
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

# Mock client info to skip DCR
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="existing_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")

# Mock the auth flow
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response to trigger the OAuth flow
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

# Next request should be to discover protected resource metadata
discovery_request = await auth_flow.asend(response)
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert discovery_request.method == "GET"

# Send a 404 response - PRM endpoint not implemented (legacy server)
# This should NOT raise an error, but fall through to legacy OAuth discovery
prm_404_response = httpx.Response(
404,
content=b"Not Found",
request=discovery_request,
)

# Next request should fall through to legacy OAuth discovery fallback
# Since PRM failed, it should try OAuth metadata discovery
oauth_metadata_request = await auth_flow.asend(prm_404_response)
assert oauth_metadata_request.method == "GET"
# Should try one of the fallback URLs
assert ".well-known/oauth-authorization-server" in str(oauth_metadata_request.url)

# Send a successful OAuth metadata response to continue the flow
oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://api.example.com", '
b'"authorization_endpoint": "https://api.example.com/authorize", '
b'"token_endpoint": "https://api.example.com/token"}'
),
request=oauth_metadata_request,
)

# Mock the authorization process
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

# Next request should be token exchange (mocked authorization, so goes straight to token)
token_request = await auth_flow.asend(oauth_metadata_response)
assert str(token_request.url) == "https://api.example.com/token"
assert token_request.method == "POST"

# Send a successful token response
token_response = httpx.Response(
200,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)

# After OAuth flow completes, the original request is retried with auth header
final_request = await auth_flow.asend(token_response)
assert final_request.headers["Authorization"] == "Bearer new_access_token"
assert final_request.method == "GET"
assert str(final_request.url) == "https://api.example.com/v1/mcp"

# Send final success response to properly close the generator
final_response = httpx.Response(200, request=final_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass # Expected - generator should complete

@pytest.mark.anyio
async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider):
"""Test successful metadata response handling."""
Expand Down
Loading