diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 73b8fb10..3723318f 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -139,6 +139,11 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* // Security violation; don't keep trying. return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) } + + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) + } + return asm, nil } errs = append(errs, err) diff --git a/oauthex/auth_meta_test.go b/oauthex/auth_meta_test.go index bee8f900..1e608824 100644 --- a/oauthex/auth_meta_test.go +++ b/oauthex/auth_meta_test.go @@ -7,9 +7,14 @@ package oauthex import ( + "context" "encoding/json" + "net/http" + "net/http/httptest" + "net/url" "os" "path/filepath" + "strings" "testing" ) @@ -28,3 +33,88 @@ func TestAuthMetaParse(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +func TestGetAuthServerMetaPKCESupport(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + hasPKCESupport bool + wantError string + }{ + { + name: "server_with_pkce_support", + hasPKCESupport: true, + }, + { + name: "server_without_pkce_support", + hasPKCESupport: false, + wantError: "does not implement PKCE", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Start a fake OAuth 2.1 auth server + wrapper := http.NewServeMux() + wrapper.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + u, _ := url.Parse("https://" + r.Host) + issuer := "https://localhost:" + u.Port() + metadata := AuthServerMeta{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + RegistrationEndpoint: issuer + "/register", + JWKSURI: issuer + "/.well-known/jwks.json", + ScopesSupported: []string{"openid", "profile", "email"}, + ResponseTypesSupported: []string{"code"}, + GrantTypesSupported: []string{"authorization_code"}, + TokenEndpointAuthMethodsSupported: []string{"none"}, + } + + // Add PKCE support based on test case + if tt.hasPKCESupport { + metadata.CodeChallengeMethodsSupported = []string{"S256"} + } + // If hasPKCESupport is false, CodeChallengeMethodsSupported remains empty + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metadata) + }) + ts := httptest.NewTLSServer(wrapper) + defer ts.Close() + + // The fake server sets issuer to https://localhost:, so compute that issuer. + u, _ := url.Parse(ts.URL) + issuer := "https://localhost:" + u.Port() + + // The fake server presents a cert for example.com; set ServerName accordingly. + httpClient := ts.Client() + if tr, ok := httpClient.Transport.(*http.Transport); ok { + clone := tr.Clone() + clone.TLSClientConfig.ServerName = "example.com" + httpClient.Transport = clone + } + + meta, err := GetAuthServerMeta(ctx, issuer, httpClient) + if tt.wantError != "" { + if err == nil { + t.Fatal("wanted error but got none") + } + if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("wanted error to contain %q, but got: %v", tt.wantError, err) + } + } else { + if err != nil { + t.Fatalf("unwanted error: %v", err) + } + if meta == nil { + t.Fatal("wanted metadata but got nil") + } + // Verify PKCE support is present + if len(meta.CodeChallengeMethodsSupported) == 0 { + t.Error("wanted PKCE support but CodeChallengeMethodsSupported is empty") + } + } + }) + } +}