Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions oauthex/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,15 @@ func TestGetProtectedResourceMetadata(t *testing.T) {
server := httptest.NewTLSServer(h)
h.installHandlers(server.URL)
client := server.Client()
res, err := client.Get(server.URL + "/resource")
serverURL := server.URL + "/resource"
res, err := client.Get(serverURL)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusUnauthorized {
t.Fatal("want unauth")
}
prm, err := GetProtectedResourceMetadataFromHeader(ctx, res.Header, client)
prm, err := GetProtectedResourceMetadataFromHeader(ctx, serverURL, res.Header, client)
if err != nil {
t.Fatal(err)
}
Expand All @@ -240,11 +241,53 @@ func TestGetProtectedResourceMetadata(t *testing.T) {
t.Fatal("nil prm")
}
})
// Test that metadata URL and resource identifier are properly distinguished (issue #560).
t.Run("FromHeaderValidatesAgainstServerURL", func(t *testing.T) {
h := &fakeResourceHandler{serveWWWAuthenticate: true}
server := httptest.NewTLSServer(h)
h.installHandlers(server.URL)
client := server.Client()
serverURL := server.URL + "/resource"
res, err := client.Get(serverURL)
if err != nil {
t.Fatal(err)
}
// This should succeed because we validate against serverURL, not metadataURL
prm, err := GetProtectedResourceMetadataFromHeader(ctx, serverURL, res.Header, client)
if err != nil {
t.Fatalf("Expected validation to succeed, got error: %v", err)
}
if prm == nil {
t.Fatal("Expected non-nil prm")
}
if prm.Resource != serverURL {
t.Errorf("Expected resource %q, got %q", serverURL, prm.Resource)
}
})
t.Run("FromHeaderRejectsImpersonation", func(t *testing.T) {
h := &fakeResourceHandler{serveWWWAuthenticate: true, resourceOverride: "https://attacker.com/evil"}
server := httptest.NewTLSServer(h)
h.installHandlers(server.URL)
client := server.Client()
serverURL := server.URL + "/resource"
res, err := client.Get(serverURL)
if err != nil {
t.Fatal(err)
}
prm, err := GetProtectedResourceMetadataFromHeader(ctx, serverURL, res.Header, client)
if err == nil {
t.Fatal("Expected validation error for mismatched resource, got nil")
}
if prm != nil {
t.Fatal("Expected nil prm on validation failure")
}
})
}

type fakeResourceHandler struct {
http.ServeMux
serveWWWAuthenticate bool
resourceOverride string // If set, use this instead of correct resource (for testing validation)
}

func (h *fakeResourceHandler) installHandlers(serverURL string) {
Expand All @@ -258,11 +301,16 @@ func (h *fakeResourceHandler) installHandlers(serverURL string) {
}))
h.Handle("GET "+path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// If there is a WWW-Authenticate header, the resource field is the value of that header.
// If not, it's the server URL without the "/.well-known/..." part.
// Per RFC 9728 section 3.3, the resource field should contain the actual resource identifier,
// which is the URL the client uses to access the resource (serverURL + "/resource" for WWW-Authenticate case).
// For the FromID test case, it's just the serverURL.
resource := serverURL
if h.serveWWWAuthenticate {
resource = url
resource = serverURL + "/resource"
}
// Allow testing with custom resource values (e.g., impersonation attacks).
if h.resourceOverride != "" {
resource = h.resourceOverride
}
prm := &ProtectedResourceMetadata{Resource: resource}
if err := json.NewEncoder(w).Encode(prm); err != nil {
Expand Down
17 changes: 9 additions & 8 deletions oauthex/resource_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string,
// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata
// using information in the given header, using the given client (or the default
// client if nil).
// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request,
// It then validates the resource field of the resulting metadata against the given URL.
// If there is no URL in the request, it returns nil, nil.
func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request.
// Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata
// matches the serverURL (the URL that the client used to make the original request to the resource server).
// If there is no metadata URL in the header, it returns nil, nil.
func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader")
headers := header[http.CanonicalHeaderKey("WWW-Authenticate")]
if len(headers) == 0 {
Expand All @@ -66,11 +67,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Hea
if err != nil {
return nil, err
}
url := ResourceMetadataURL(cs)
if url == "" {
metadataURL := ResourceMetadataURL(cs)
if metadataURL == "" {
return nil, nil
}
return getPRM(ctx, url, c, url)
return getPRM(ctx, metadataURL, c, serverURL)
}

// getPRM makes a GET request to the given URL, and validates the response.
Expand All @@ -83,7 +84,7 @@ func getPRM(ctx context.Context, purl string, c *http.Client, wantResource strin
if err != nil {
return nil, err
}
// Validate the Resource field to thwart impersonation attacks (section 3.3).
// Validate the Resource field (see RFC 9728, section 3.3).
if prm.Resource != wantResource {
return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource)
}
Expand Down