diff --git a/trustchain.go b/trustchain.go index cca41ed..5ea2873 100644 --- a/trustchain.go +++ b/trustchain.go @@ -1,6 +1,8 @@ package oidfed import ( + "reflect" + "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" "github.com/zachmann/go-utils/maputils" @@ -89,9 +91,16 @@ func (c TrustChain) Metadata() (*Metadata, error) { if err != nil { return nil, err } + metadataFromSuperior := c[1].Metadata m := c[0].Metadata if m == nil { - m = &Metadata{} + if metadataFromSuperior == nil { + m = &Metadata{} + } else { + m = metadataFromSuperior + } + } else if metadataFromSuperior != nil { + mergeMetadata(m, metadataFromSuperior) } final, err := m.ApplyPolicy(combinedPolicy) if err != nil { @@ -103,6 +112,150 @@ func (c TrustChain) Metadata() (*Metadata, error) { return final, nil } +// mergeMetadata merges values from source into target, with source values taking precedence. +// Any value set in source will overwrite the corresponding value in target. +func mergeMetadata(target, source *Metadata) { + if source == nil { + return + } + + targetVal := reflect.ValueOf(target).Elem() + sourceVal := reflect.ValueOf(source).Elem() + typ := targetVal.Type() + + // Iterate through all fields of Metadata struct + for i := 0; i < targetVal.NumField(); i++ { + fieldName := typ.Field(i).Name + + // Skip the Extra field as it needs special handling + if fieldName == "Extra" { + continue + } + + targetField := targetVal.Field(i) + sourceField := sourceVal.Field(i) + + // Only proceed if source field is not nil + if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() { + if targetField.IsNil() { + // If target field is nil, just copy the source field + targetField.Set(sourceField) + } else { + // Both fields are non-nil pointers to structs, merge their values + mergeStructFields(targetField.Elem(), sourceField.Elem()) + } + } + } + + // Handle Extra field separately + if source.Extra != nil { + if target.Extra == nil { + target.Extra = make(map[string]any) + } + for k, v := range source.Extra { + target.Extra[k] = v + } + } +} + +// mergeStructFields merges values from source struct into target struct using reflection. +// Any field set in source will overwrite the corresponding field in target. +func mergeStructFields(target, source reflect.Value) { + // Get the wasSet map from source if it exists + var wasSetMap map[string]bool + if wasSetField := source.FieldByName("wasSet"); wasSetField.IsValid() && wasSetField.CanInterface() { + if m, ok := wasSetField.Interface().(map[string]bool); ok { + wasSetMap = m + } + } + + // Get the wasSet map from target if it exists + var targetWasSetMap map[string]bool + if targetWasSetField := target.FieldByName("wasSet"); targetWasSetField.IsValid() && targetWasSetField.CanInterface() { + if m, ok := targetWasSetField.Interface().(map[string]bool); ok { + targetWasSetMap = m + } + } else if targetWasSetField = target.FieldByName("wasSet"); targetWasSetField.IsValid() && targetWasSetField.CanSet() { + // If target has a wasSet field but it's nil, initialize it + newMap := make(map[string]bool) + targetWasSetField.Set(reflect.ValueOf(newMap)) + targetWasSetMap = newMap + } + + typ := source.Type() + // Iterate through all fields of the struct + for i := 0; i < source.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip the wasSet field + if fieldName == "wasSet" { + continue + } + + sourceField := source.Field(i) + targetField := target.FieldByName(fieldName) + + // If field doesn't exist in target or can't be set, skip it + if !targetField.IsValid() || !targetField.CanSet() { + continue + } + + // Check if this field was explicitly set in source + fieldWasSet := wasSetMap == nil || wasSetMap[fieldName] + + // Only overwrite if the field was set in source + if fieldWasSet && !sourceField.IsZero() { + // Handle different field types + if sourceField.Kind() == reflect.Map && targetField.Kind() == reflect.Map { + // For maps, merge the contents + if targetField.IsNil() { + targetField.Set(reflect.MakeMap(targetField.Type())) + } + + for _, key := range sourceField.MapKeys() { + value := sourceField.MapIndex(key) + targetField.SetMapIndex(key, value) + } + } else { + // For other types, just copy the value + targetField.Set(sourceField) + } + + // Update the wasSet map in target if it exists + if targetWasSetMap != nil { + targetWasSetMap[fieldName] = true + } + } + } + + // Handle Extra field separately if it exists + extraField := source.FieldByName("Extra") + if extraField.IsValid() && !extraField.IsNil() { + targetExtraField := target.FieldByName("Extra") + if targetExtraField.IsValid() { + if targetExtraField.IsNil() { + targetExtraField.Set(reflect.MakeMap(targetExtraField.Type())) + } + + // Copy all keys from source.Extra to target.Extra + for _, key := range extraField.MapKeys() { + value := extraField.MapIndex(key) + targetExtraField.SetMapIndex(key, value) + } + + // Mark these fields as set in wasSet if needed + if targetWasSetMap != nil { + for _, key := range extraField.MapKeys() { + if k, ok := key.Interface().(string); ok { + targetWasSetMap[k] = true + } + } + } + } + } +} + // Messages returns the jwts of the TrustChain func (c TrustChain) Messages() (msgs JWSMessages) { for _, cc := range c { diff --git a/trustchain_test.go b/trustchain_test.go index ef27044..d988d89 100644 --- a/trustchain_test.go +++ b/trustchain_test.go @@ -8,6 +8,156 @@ import ( "github.com/go-oidfed/lib/unixtime" ) +func TestMergeStructFields(t *testing.T) { + type NestedStruct struct { + FieldA int + FieldB *string + } + + type TestStruct struct { + SimpleField int + SliceField []int + MapField map[string]int + PtrField *NestedStruct + wasSet map[string]bool + Extra map[string]interface{} + } + + strValue := "test" + tests := []struct { + name string + target TestStruct + source TestStruct + expected TestStruct + }{ + { + name: "merge simple fields", + target: TestStruct{ + SimpleField: 1, + MapField: map[string]int{"key1": 1}, + }, + source: TestStruct{ + SimpleField: 5, + }, + expected: TestStruct{ + SimpleField: 5, + MapField: map[string]int{"key1": 1}, + }, + }, + { + name: "merge slice fields", + target: TestStruct{ + SliceField: []int{ + 1, + 2, + }, + MapField: map[string]int{"key1": 1}, + }, + source: TestStruct{ + SliceField: []int{ + 2, + 3, + }, + }, + expected: TestStruct{ + SliceField: []int{ + 2, + 3, + }, + MapField: map[string]int{"key1": 1}, + }, + }, + { + name: "merge map fields", + target: TestStruct{ + MapField: map[string]int{"key1": 1}, + }, + source: TestStruct{ + MapField: map[string]int{"key2": 2}, + }, + expected: TestStruct{ + MapField: map[string]int{ + "key1": 1, + "key2": 2, + }, + }, + }, + { + name: "merge pointer fields", + target: TestStruct{ + PtrField: &NestedStruct{ + FieldA: 10, + }, + }, + source: TestStruct{ + PtrField: &NestedStruct{ + FieldA: 20, + FieldB: &strValue, + }, + }, + expected: TestStruct{ + PtrField: &NestedStruct{ + FieldA: 20, + FieldB: &strValue, + }, + }, + }, + { + name: "merge with Extra field", + target: TestStruct{ + Extra: map[string]interface{}{ + "key1": "value1", + "key2": "value1", + }, + }, + source: TestStruct{ + Extra: map[string]interface{}{ + "key2": "value2", + "key3": "value3", + }, + }, + expected: TestStruct{ + Extra: map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + }, + }, + }, + { + name: "merge wasSet map", + target: TestStruct{ + wasSet: map[string]bool{"SimpleField": true}, + }, + source: TestStruct{ + SimpleField: 15, + wasSet: map[string]bool{"SimpleField": false}, + }, + expected: TestStruct{ + SimpleField: 15, + wasSet: map[string]bool{"SimpleField": true}, + }, + }, + } + + for _, test := range tests { + t.Run( + test.name, func(t *testing.T) { + targetVal := reflect.ValueOf(&test.target).Elem() + sourceVal := reflect.ValueOf(&test.source).Elem() + + mergeStructFields(targetVal, sourceVal) + + if !reflect.DeepEqual(test.target, test.expected) { + t.Errorf( + "mergeStructFields() failed for %q: got %+v, want %+v", test.name, test.target, test.expected, + ) + } + }, + ) + } +} + func TestTrustChains_ExpiresAt(t *testing.T) { tests := []struct { name string @@ -230,3 +380,127 @@ func TestTrustChain_MetaDataPolicyCrit(t *testing.T) { ) } } + +func TestMergeMetadata_SourceNil_NoChange(t *testing.T) { + target := &Metadata{ + OpenIDProvider: &OpenIDProviderMetadata{Issuer: "issuer-target"}, + Extra: map[string]any{"k": "v"}, + } + mergeMetadata(target, nil) + if target.OpenIDProvider == nil || target.OpenIDProvider.Issuer != "issuer-target" { + t.Fatalf("OpenIDProvider changed unexpectedly: %+v", target.OpenIDProvider) + } + if got := target.Extra["k"]; got != "v" { + t.Fatalf("Extra changed unexpectedly: got %v", got) + } +} + +func TestMergeMetadata_TargetFieldNil_CopiesSource(t *testing.T) { + source := &Metadata{ + OpenIDProvider: &OpenIDProviderMetadata{ + Issuer: "issuer-source", + MTLSEndpointAliases: map[string]string{"a": "1"}, + Extra: map[string]any{"sx": 1}, + RequestAuthenticationMethodsSupported: map[string][]string{"m1": {"a"}}, + }, + Extra: map[string]any{"top": "s"}, + } + target := &Metadata{} + mergeMetadata(target, source) + if target.OpenIDProvider == nil || target.OpenIDProvider.Issuer != "issuer-source" { + t.Fatalf("expected target OpenIDProvider copied from source, got %+v", target.OpenIDProvider) + } + if target.OpenIDProvider.MTLSEndpointAliases["a"] != "1" { + t.Fatalf("expected MTLSEndpointAliases copied, got %+v", target.OpenIDProvider.MTLSEndpointAliases) + } + if target.OpenIDProvider.Extra["sx"].(int) != 1 { + t.Fatalf("expected nested Extra copied, got %+v", target.OpenIDProvider.Extra) + } + if target.Extra["top"] != "s" { + t.Fatalf("expected top-level Extra merged, got %+v", target.Extra) + } + if _, ok := target.OpenIDProvider.RequestAuthenticationMethodsSupported["m1"]; !ok { + t.Fatalf( + "expected m1 to be present in RequestAuthenticationMethodsSupported: %+v", + target.OpenIDProvider.RequestAuthenticationMethodsSupported, + ) + } +} + +func TestMergeMetadata_BothNonNil_NestedMergeAndMapsAndExtra(t *testing.T) { + target := &Metadata{ + OpenIDProvider: &OpenIDProviderMetadata{ + Issuer: "issuer-target", + Description: "desc-target", + MTLSEndpointAliases: map[string]string{"a": "1"}, + Extra: map[string]any{ + "ek1": "ev1", + "ovr": "t", + }, + RequestAuthenticationMethodsSupported: map[string][]string{"m1": {"a"}}, + }, + Extra: map[string]any{ + "x1": "t1", + "o": "t", + }, + } + source := &Metadata{ + OpenIDProvider: &OpenIDProviderMetadata{ + Issuer: "issuer-source", + Description: "", // empty should not override + MTLSEndpointAliases: map[string]string{ + "a": "3", + "b": "2", + }, + Extra: map[string]any{ + "ovr": "s", + "ns": "v", + }, + RequestAuthenticationMethodsSupported: map[string][]string{"m2": {"b"}}, + }, + Extra: map[string]any{ + "o": "s", + "x2": "s2", + }, + } + mergeMetadata(target, source) + // Issuer should be overwritten + if target.OpenIDProvider.Issuer != "issuer-source" { + t.Fatalf("Issuer not overwritten, got %q", target.OpenIDProvider.Issuer) + } + // Description should remain from target because source is empty + if target.OpenIDProvider.Description != "desc-target" { + t.Fatalf("Description should not be overwritten by empty source, got %q", target.OpenIDProvider.Description) + } + // Map merge with override + wantMTLS := map[string]string{ + "a": "3", + "b": "2", + } + if !reflect.DeepEqual(target.OpenIDProvider.MTLSEndpointAliases, wantMTLS) { + t.Fatalf( + "MTLSEndpointAliases merge mismatch: got %+v, want %+v", target.OpenIDProvider.MTLSEndpointAliases, wantMTLS, + ) + } + // Nested Extra merged + if target.OpenIDProvider.Extra["ek1"] != "ev1" || target.OpenIDProvider.Extra["ovr"] != "s" || target.OpenIDProvider.Extra["ns"] != "v" { + t.Fatalf("nested Extra merge mismatch: %+v", target.OpenIDProvider.Extra) + } + // Top-level Extra merged + if target.Extra["x1"] != "t1" || target.Extra["o"] != "s" || target.Extra["x2"] != "s2" { + t.Fatalf("top-level Extra merge mismatch: %+v", target.Extra) + } + // Map[string][]string merge should add new keys and keep existing ones + if _, ok := target.OpenIDProvider.RequestAuthenticationMethodsSupported["m1"]; !ok { + t.Fatalf( + "expected m1 to remain in RequestAuthenticationMethodsSupported: %+v", + target.OpenIDProvider.RequestAuthenticationMethodsSupported, + ) + } + if _, ok := target.OpenIDProvider.RequestAuthenticationMethodsSupported["m2"]; !ok { + t.Fatalf( + "expected m2 to be added in RequestAuthenticationMethodsSupported: %+v", + target.OpenIDProvider.RequestAuthenticationMethodsSupported, + ) + } +}