Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ func applyPolicy(metadata any, policy MetadataPolicy, ownTag string) (any, error
// FindEntityMetadata finds metadata for the specified entity type in the
// metadata and decodes it into the provided metadata object.
func (m *Metadata) FindEntityMetadata(entityType string, metadata any) error {
// Validate that metadata is a pointer
metadataValue := reflect.ValueOf(metadata)
if metadataValue.Kind() != reflect.Ptr || metadataValue.IsNil() {
return errors.New("metadata parameter must be a non-nil pointer")
}

// Check if the entity type indicates one of the explicit struct fields.
v := reflect.ValueOf(m)
t := v.Elem().Type()
Expand All @@ -329,16 +335,23 @@ func (m *Metadata) FindEntityMetadata(entityType string, metadata any) error {
if j != entityType {
continue
}
if j == entityType {
fmt.Printf("found entity type %s\n", entityType)
}

value := v.Elem().FieldByName(t.Field(i).Name)
if value.IsZero() {
continue
}

metadata = value.Interface()
// Get the field value and set it to the metadata parameter
fieldValue := value.Interface()
sourceValue := reflect.ValueOf(fieldValue).Elem()

// Create a new instance of the same type as the field
targetValue := metadataValue.Elem()
if !sourceValue.Type().AssignableTo(targetValue.Type()) {
return errors.Errorf("cannot assign %v to %v", sourceValue.Type(), targetValue.Type())
}

targetValue.Set(sourceValue)
return nil
}

Expand All @@ -353,10 +366,13 @@ func (m *Metadata) FindEntityMetadata(entityType string, metadata any) error {
// struct so we can use RTTI to give the caller a richer representation.
jsonMetadata, err := json.Marshal(metadataMap)
if err != nil {
return errors.Errorf("failed to marshal metadata: %s", err)
return errors.Wrapf(err, "failed to marshal metadata")
}

return json.Unmarshal(jsonMetadata, metadata)
// Unmarshal the JSON data into the new instance
if err = json.Unmarshal(jsonMetadata, metadata); err != nil {
return errors.Wrapf(err, "failed to unmarshal metadata")
}
return nil
}

// OAuthClientMetadata is a type for holding the metadata about an oauth client
Expand Down
97 changes: 89 additions & 8 deletions metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidfed
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"testing"

Expand Down Expand Up @@ -333,48 +334,128 @@ func TestMetadata_UnmarshalJSON(t *testing.T) {
}

func TestMetadata_FindEntityMetadata(t *testing.T) {
type AnotherEntityMetadata struct{}
type AnotherEntityMetadata struct {
AKey string `json:"a-key"`
}
var metadata Metadata
if err := json.Unmarshal(metadataMarshalData["extra metadata"].Data, &metadata); err != nil {
t.Fatal(err)
}

testCases := map[string]struct {
metadataType string
deserializeInto any
deserializeInto string
shouldSucceed bool
expectedResult any
}{
"Metadata is present and in an explicit struct field": {
metadataType: "federation_entity",
deserializeInto: FederationEntityMetadata{},
deserializeInto: "federation_entity_metadata",
shouldSucceed: true,
expectedResult: FederationEntityMetadata{
FederationFetchEndpoint: "https://federation.endpoint/fetch",
wasSet: map[string]bool{
"FederationFetchEndpoint": true,
},
},
},
"Metadata handle nil pointer parameter": {
metadataType: "federation_entity",
deserializeInto: "",
shouldSucceed: false,
expectedResult: nil,
},
"Metadata present in Extra but incorrectly typed": {
metadataType: "another-entity",
deserializeInto: "struct name",
shouldSucceed: true,
expectedResult: struct{ Name string }{},
},
"Struct field present but empty": {
metadataType: "openid_provider",
deserializeInto: "openid_provider_metadata",
shouldSucceed: false,
expectedResult: OpenIDProviderMetadata{},
},
"Extra metadata with invalid entity type": {
metadataType: "invalid-extra",
deserializeInto: "struct",
shouldSucceed: false,
expectedResult: struct{}{},
},
"Metadata is present and in extra metadata": {
metadataType: "another-entity",
deserializeInto: AnotherEntityMetadata{},
deserializeInto: "another_entity_metadata",
shouldSucceed: true,
expectedResult: AnotherEntityMetadata{
AKey: "a-value",
},
},
"Metadata is absent and would be in an explicit struct field": {
metadataType: "openid_provider",
deserializeInto: OpenIDProviderMetadata{},
deserializeInto: "openid_provider_metadata",
shouldSucceed: false,
expectedResult: OpenIDProviderMetadata{},
},
"Metadata is absent and would be in extra metadata": {
metadataType: "no-such-metadata",
deserializeInto: struct{}{},
deserializeInto: "struct",
shouldSucceed: false,
expectedResult: struct{}{},
},
}

for name, testCase := range testCases {
t.Run(
name, func(t *testing.T) {
err := metadata.FindEntityMetadata(testCase.metadataType, &testCase.deserializeInto)
var err error
var result any
switch testCase.deserializeInto {
case "federation_entity_metadata":
var deserialize FederationEntityMetadata
err = metadata.FindEntityMetadata(testCase.metadataType, &deserialize)
result = deserialize
case "openid_provider_metadata":
var deserialize OpenIDProviderMetadata
err = metadata.FindEntityMetadata(testCase.metadataType, &deserialize)
result = deserialize
case "another_entity_metadata":
var deserialize AnotherEntityMetadata
err = metadata.FindEntityMetadata(testCase.metadataType, &deserialize)
result = deserialize
case "struct":
var deserialize struct{}
err = metadata.FindEntityMetadata(testCase.metadataType, &deserialize)
result = deserialize
case "struct name":
var deserialize struct{ Name string }
err = metadata.FindEntityMetadata(testCase.metadataType, &deserialize)
result = deserialize
default:
err = metadata.FindEntityMetadata(testCase.metadataType, nil)
}
fmt.Printf("Result: %T %+v\n", result, result)

if testCase.shouldSucceed && err != nil {
t.Error(err)
} else if !testCase.shouldSucceed && err == nil {
t.Errorf("finding %s metadata should fail", testCase.metadataType)
t.Logf("%+v", testCase.deserializeInto)
if result != nil {
// t.Logf("%+v", reflect.ValueOf(testValue).Elem().Interface())
t.Logf("%+v", result)
}
}

if testCase.expectedResult != nil && result != nil {
if !reflect.DeepEqual(result, testCase.expectedResult) {
t.Errorf(
"Result not as expected.\nExpected: %T: %+v\n Got: %T: %+v",
testCase.expectedResult,
testCase.expectedResult,
result,
result,
)
}
}
},
)
Expand Down