Skip to content
103 changes: 81 additions & 22 deletions oauthex/auth_meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/url"
"os"
"path/filepath"
"strings"
"testing"
)

Expand All @@ -33,32 +34,90 @@ func TestAuthMetaParse(t *testing.T) {
}
}

func TestGetAuthServerMetaRequirePKCE(t *testing.T) {
func TestGetAuthServerMetaPKCESupport(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
hasPKCESupport bool
expectError bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/expect/want/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also wantError can be a string, if empty, then no error is expected, otherwise the error should contain that string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revised

expectedError string
}{
{
name: "server_with_pkce_support",
hasPKCESupport: true,
expectError: false,
},
{
name: "server_without_pkce_support",
hasPKCESupport: false,
expectError: true,
expectedError: "does not implement PKCE",
},
}

// Start a fake OAuth 2.1 auth server that advertises PKCE (S256).
wrapper := http.NewServeMux()
wrapper.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
NewFakeMCPServerMux().ServeHTTP(w, r)
})
ts := httptest.NewTLSServer(wrapper)
defer ts.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Start a fake OAuth 2.1 auth server
wrapper := http.NewServeMux()
wrapper.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) {
u, _ := url.Parse("https://" + r.Host)
issuer := "https://localhost:" + u.Port()
metadata := AuthServerMeta{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
RegistrationEndpoint: issuer + "/register",
JWKSURI: issuer + "/.well-known/jwks.json",
ScopesSupported: []string{"openid", "profile", "email"},
ResponseTypesSupported: []string{"code"},
GrantTypesSupported: []string{"authorization_code"},
TokenEndpointAuthMethodsSupported: []string{"none"},
}

// Validate that the server supports PKCE per MCP auth requirements.
// The fake server sets issuer to https://localhost:<port>, so compute that issuer.
u, _ := url.Parse(ts.URL)
issuer := "https://localhost:" + u.Port()
// Add PKCE support based on test case
if tt.hasPKCESupport {
metadata.CodeChallengeMethodsSupported = []string{"S256"}
}
// If hasPKCESupport is false, CodeChallengeMethodsSupported remains empty

// The fake server presents a cert for example.com; set ServerName accordingly.
httpClient := ts.Client()
if tr, ok := httpClient.Transport.(*http.Transport); ok {
clone := tr.Clone()
clone.TLSClientConfig.ServerName = "example.com"
httpClient.Transport = clone
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(metadata)
})
ts := httptest.NewTLSServer(wrapper)
defer ts.Close()

if _, err := GetAuthServerMeta(ctx, issuer, httpClient); err != nil {
t.Fatal(err)
}
// The fake server sets issuer to https://localhost:<port>, so compute that issuer.
u, _ := url.Parse(ts.URL)
issuer := "https://localhost:" + u.Port()

// The fake server presents a cert for example.com; set ServerName accordingly.
httpClient := ts.Client()
if tr, ok := httpClient.Transport.(*http.Transport); ok {
clone := tr.Clone()
clone.TLSClientConfig.ServerName = "example.com"
httpClient.Transport = clone
}

meta, err := GetAuthServerMeta(ctx, issuer, httpClient)
if tt.expectError {
if err == nil {
t.Fatal("expected error but got none")
}
if !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error to contain %q, but got: %v", tt.expectedError, err)
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if meta == nil {
t.Fatal("expected metadata but got nil")
}
// Verify PKCE support is present
if len(meta.CodeChallengeMethodsSupported) == 0 {
t.Error("expected PKCE support but CodeChallengeMethodsSupported is empty")
}
}
})
}
}