Skip to content

Commit 0736ebe

Browse files
authored
Merge pull request #89 from go-oidfed/refactoring
Refactoring
2 parents 7137c4d + c244ee2 commit 0736ebe

File tree

4 files changed

+222
-123
lines changed

4 files changed

+222
-123
lines changed

.deepsource.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
version = 1
22

3+
test_patterns = [
4+
"*_test.go"
5+
]
6+
37
[[analyzers]]
48
name = "go"
59

internal/utils/slices.go

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,59 +4,78 @@ import (
44
"reflect"
55
)
66

7-
// ReflectSliceCast casts a slice to another type using reflection
7+
// ReflectSliceCast converts a slice to another type using reflection.
8+
// Parameters:
9+
// - slice: source slice to convert
10+
// - newType: target type for the slice elements
11+
//
12+
// Retur
13+
// Returns:
14+
// - converted slice or original value if input is not a slice
815
func ReflectSliceCast(slice, newType any) any {
916
if !IsSlice(slice) {
1017
return slice
1118
}
19+
1220
typeType := reflect.TypeOf(newType)
1321
sliceV := reflect.ValueOf(slice)
1422
out := reflect.MakeSlice(typeType, sliceV.Len(), sliceV.Len())
23+
1524
for i := 0; i < sliceV.Len(); i++ {
16-
vv := sliceV.Index(i)
17-
var v reflect.Value
18-
// This is stupid and has faults, but I did not find a better way
19-
switch typeType.Elem().Kind() {
20-
case reflect.Bool:
21-
v = reflect.ValueOf(vv.Interface().(bool))
22-
case reflect.Int:
23-
v = reflect.ValueOf(vv.Interface().(int))
24-
case reflect.Int8:
25-
v = reflect.ValueOf(vv.Interface().(int8))
26-
case reflect.Int16:
27-
v = reflect.ValueOf(vv.Interface().(int16))
28-
case reflect.Int32:
29-
v = reflect.ValueOf(vv.Interface().(int32))
30-
case reflect.Int64:
31-
v = reflect.ValueOf(vv.Interface().(int64))
32-
case reflect.Uint:
33-
v = reflect.ValueOf(vv.Interface().(uint))
34-
case reflect.Uint8:
35-
v = reflect.ValueOf(vv.Interface().(uint8))
36-
case reflect.Uint16:
37-
v = reflect.ValueOf(vv.Interface().(uint16))
38-
case reflect.Uint32:
39-
v = reflect.ValueOf(vv.Interface().(uint32))
40-
case reflect.Uint64:
41-
v = reflect.ValueOf(vv.Interface().(uint64))
42-
case reflect.Uintptr:
43-
v = reflect.ValueOf(vv.Interface().(*uint))
44-
case reflect.Float32:
45-
v = reflect.ValueOf(vv.Interface().(float32))
46-
case reflect.Float64:
47-
v = reflect.ValueOf(vv.Interface().(float64))
48-
case reflect.Interface:
49-
v = vv
50-
case reflect.String:
51-
v = reflect.ValueOf(vv.Interface().(string))
52-
default:
53-
v = vv.Convert(typeType.Elem())
54-
}
55-
out.Index(i).Set(v)
25+
sourceVal := sliceV.Index(i)
26+
convertedVal := convertToTargetType(sourceVal, typeType.Elem())
27+
out.Index(i).Set(convertedVal)
5628
}
29+
5730
return out.Interface()
5831
}
5932

33+
// convertToTargetType converts a reflect.Value to the target type.
34+
// It handles primitive types explicitly and falls back to generic conversion for other types.
35+
func convertToTargetType(val reflect.Value, targetType reflect.Type) reflect.Value {
36+
if targetType.Kind() == reflect.Interface {
37+
return val
38+
}
39+
40+
// Get the underlying interface value
41+
srcInterface := val.Interface()
42+
43+
// Handle primitive types
44+
switch targetType.Kind() {
45+
case reflect.Bool:
46+
return reflect.ValueOf(srcInterface.(bool))
47+
case reflect.Int:
48+
return reflect.ValueOf(srcInterface.(int))
49+
case reflect.Int8:
50+
return reflect.ValueOf(srcInterface.(int8))
51+
case reflect.Int16:
52+
return reflect.ValueOf(srcInterface.(int16))
53+
case reflect.Int32:
54+
return reflect.ValueOf(srcInterface.(int32))
55+
case reflect.Int64:
56+
return reflect.ValueOf(srcInterface.(int64))
57+
case reflect.Uint:
58+
return reflect.ValueOf(srcInterface.(uint))
59+
case reflect.Uint8:
60+
return reflect.ValueOf(srcInterface.(uint8))
61+
case reflect.Uint16:
62+
return reflect.ValueOf(srcInterface.(uint16))
63+
case reflect.Uint32:
64+
return reflect.ValueOf(srcInterface.(uint32))
65+
case reflect.Uint64:
66+
return reflect.ValueOf(srcInterface.(uint64))
67+
case reflect.Float32:
68+
return reflect.ValueOf(srcInterface.(float32))
69+
case reflect.Float64:
70+
return reflect.ValueOf(srcInterface.(float64))
71+
case reflect.String:
72+
return reflect.ValueOf(srcInterface.(string))
73+
default:
74+
// For other types, try to convert using reflection
75+
return val.Convert(targetType)
76+
}
77+
}
78+
6079
// ReflectSliceContains checks if a slice contains a value using reflection
6180
func ReflectSliceContains(v, slice any) bool {
6281
if !IsSlice(slice) {

jwx/privateKeyStorageMultiAlg.go

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,56 +76,25 @@ func (sks *privateKeyStorageMultiAlg) initKeyRotation(pks *pkCollection, pksOnCh
7676

7777
// Load loads the private keys from disk and if necessary generates missing keys
7878
func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func() error) error {
79-
populatePKFromSK := false
79+
addPublicKeysToJWKS := false
8080
if sks.signers == nil {
8181
sks.signers = make(map[jwa.SignatureAlgorithm]crypto.Signer)
8282
}
8383
if len(pks.jwks) == 0 {
8484
pks.jwks = []JWKS{NewJWKS()}
85-
populatePKFromSK = true
85+
addPublicKeysToJWKS = true
8686
}
8787
pksChanged := false
88-
// load oidc keys
88+
8989
for _, alg := range sks.algs {
90-
filePath := sks.keyFilePath(alg, false)
91-
signer, err := readSignerFromFile(filePath, alg)
90+
signer, changed, err := sks.loadOrGenerateSigner(alg, pks, addPublicKeysToJWKS)
9291
if err != nil {
93-
// could not load key, generating a new one for this alg
94-
sk, pk, err := generateKeyPair(
95-
alg, sks.rsaKeyLen, keyLifetimeConf{
96-
NowIssued: true,
97-
Expires: sks.rollover.Enabled,
98-
Lifetime: sks.rollover.Interval.Duration(),
99-
},
100-
)
101-
if err != nil {
102-
return err
103-
}
104-
if err = writeSignerToFile(sk, sks.keyFilePath(alg, false)); err != nil {
105-
return err
106-
}
107-
if err = pks.jwks[0].AddKey(pk); err != nil {
108-
return errors.WithStack(err)
109-
}
110-
pksChanged = true
111-
signer = sk
112-
} else if populatePKFromSK {
113-
pk, err := signerToPublicJWK(
114-
signer, alg, keyLifetimeConf{
115-
NowIssued: false,
116-
Expires: sks.rollover.Enabled,
117-
Lifetime: sks.rollover.Interval.Duration(),
118-
},
119-
)
120-
if err != nil {
121-
return err
122-
}
123-
if err = pks.jwks[0].AddKey(pk); err != nil {
124-
return errors.WithStack(err)
125-
}
92+
return err
12693
}
94+
pksChanged = pksChanged || changed
12795
sks.signers[alg] = signer
12896

97+
// Ensure the next key file exists for rollover
12998
if !fileutils.FileExists(sks.keyFilePath(alg, true)) {
13099
_, err = generateStoreAndSetNextPrivateKey(
131100
pks, alg, sks.rsaKeyLen, keyLifetimeConf{
@@ -140,7 +109,8 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func()
140109
}
141110
}
142111
}
143-
if populatePKFromSK || pksChanged {
112+
113+
if addPublicKeysToJWKS || pksChanged {
144114
if err := pksOnChange(); err != nil {
145115
return err
146116
}
@@ -149,6 +119,49 @@ func (sks *privateKeyStorageMultiAlg) Load(pks *pkCollection, pksOnChange func()
149119
return nil
150120
}
151121

122+
// loadOrGenerateSigner loads a signer from disk or generates a new one if it doesn't exist.
123+
// If addPublicKeysToJWKS is true, it also adds the public key to the pkCollection.
124+
func (sks *privateKeyStorageMultiAlg) loadOrGenerateSigner(
125+
alg jwa.SignatureAlgorithm, pks *pkCollection, addPublicKeysToJWKS bool,
126+
) (crypto.Signer, bool, error) {
127+
filePath := sks.keyFilePath(alg, false)
128+
signer, err := readSignerFromFile(filePath, alg)
129+
if err != nil {
130+
// Could not load key, generating a new one for this alg
131+
sk, pk, err := generateKeyPair(
132+
alg,
133+
sks.rsaKeyLen,
134+
keyLifetimeConf{
135+
NowIssued: true,
136+
Expires: sks.rollover.Enabled,
137+
Lifetime: sks.rollover.Interval.Duration(),
138+
},
139+
)
140+
if err != nil {
141+
return nil, false, err
142+
}
143+
if err = writeSignerToFile(sk, filePath); err != nil {
144+
return nil, false, err
145+
}
146+
pks.addCurrentJWK(pk)
147+
return sk, true, nil
148+
}
149+
if addPublicKeysToJWKS {
150+
pk, err := signerToPublicJWK(
151+
signer, alg, keyLifetimeConf{
152+
NowIssued: false,
153+
Expires: sks.rollover.Enabled,
154+
Lifetime: sks.rollover.Interval.Duration(),
155+
},
156+
)
157+
if err != nil {
158+
return nil, false, err
159+
}
160+
pks.addCurrentJWK(pk)
161+
}
162+
return signer, addPublicKeysToJWKS, nil
163+
}
164+
152165
// GenerateNewKeys generates a new set of keys
153166
func (sks *privateKeyStorageMultiAlg) GenerateNewKeys(pks *pkCollection, pksOnChange func() error) error {
154167
futureKeys := NewJWKS()

0 commit comments

Comments
 (0)