Skip to content

Commit 55aa0bb

Browse files
Merge pull request #386 from sabre1041/fix-transitive-azure-members
Refactor transitive member management and improvements to group querying
2 parents aed4f3b + 061c6ca commit 55aa0bb

File tree

1 file changed

+45
-39
lines changed

1 file changed

+45
-39
lines changed

pkg/syncer/azure.go

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
nethttp "net/http"
1414

15+
abstractions "github.com/microsoft/kiota-abstractions-go"
1516
userv1 "github.com/openshift/api/user/v1"
1617
redhatcopv1alpha1 "github.com/redhat-cop/group-sync-operator/api/v1alpha1"
1718
"github.com/redhat-cop/group-sync-operator/pkg/constants"
@@ -26,7 +27,6 @@ import (
2627

2728
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
2829
azidentity "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
29-
abstractions "github.com/microsoft/kiota-abstractions-go"
3030
az "github.com/microsoft/kiota-authentication-azure-go"
3131
kiota "github.com/microsoft/kiota-http-go"
3232
msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go"
@@ -36,10 +36,10 @@ import (
3636
)
3737

3838
var (
39-
azureLogger = logf.Log.WithName("syncer_azure")
40-
caser = cases.Title(language.Und, cases.NoLower)
41-
azurePageSize = int32(999)
42-
transitiveMemberHeaders = abstractions.NewRequestHeaders()
39+
azureLogger = logf.Log.WithName("syncer_azure")
40+
caser = cases.Title(language.Und, cases.NoLower)
41+
azurePageSize = int32(999)
42+
headers = abstractions.NewRequestHeaders()
4343
)
4444

4545
const (
@@ -73,7 +73,7 @@ func (a *AzureSyncer) Init() bool {
7373
a.CachedGroups = make(map[string]*graph.Group)
7474
a.CachedGroupUsers = make(map[string][]*graph.User)
7575
a.Context = context.Background()
76-
transitiveMemberHeaders.Add("ConsistencyLevel", "eventual")
76+
headers.Add("ConsistencyLevel", "eventual")
7777

7878
return false
7979
}
@@ -219,7 +219,6 @@ func (a *AzureSyncer) Sync() ([]userv1.Group, error) {
219219
filter := fmt.Sprintf("displayName eq '%s'", baseGroup)
220220
groupRequestParameters := &msgroups.GroupsRequestBuilderGetQueryParameters{
221221
Filter: &filter,
222-
Top: &azurePageSize,
223222
}
224223

225224
groupRequestConfiguration := &msgroups.GroupsRequestBuilderGetRequestConfiguration{
@@ -249,37 +248,35 @@ func (a *AzureSyncer) Sync() ([]userv1.Group, error) {
249248
// Add Base Group
250249
aadGroups = append(aadGroups, baseGroupResult[0])
251250

252-
var baseGroupMembersRequestConfiguration *msgroups.ItemMembersRequestBuilderGetRequestConfiguration
251+
var baseGroupMembersRequestConfiguration *msgroups.ItemMembersGraphGroupRequestBuilderGetRequestConfiguration
253252

254253
if a.Provider.Filter != "" {
255-
requestParameters := &msgroups.ItemMembersRequestBuilderGetQueryParameters{
254+
requestParameters := &msgroups.ItemMembersGraphGroupRequestBuilderGetQueryParameters{
256255
Filter: &a.Provider.Filter,
257256
Top: &azurePageSize,
258257
}
259-
baseGroupMembersRequestConfiguration = &msgroups.ItemMembersRequestBuilderGetRequestConfiguration{
258+
baseGroupMembersRequestConfiguration = &msgroups.ItemMembersGraphGroupRequestBuilderGetRequestConfiguration{
260259
QueryParameters: requestParameters,
261260
}
262261

263262
}
264263

265-
baseGroupMembersRequest, err := a.Client.GroupsById(*baseGroupResult[0].GetId()).Members().Get(a.Context, baseGroupMembersRequestConfiguration)
264+
baseGroupMembersRequest, err := a.Client.GroupsById(*baseGroupResult[0].GetId()).Members().GraphGroup().Get(a.Context, baseGroupMembersRequestConfiguration)
266265

267266
if err != nil {
268267
azureLogger.Error(err, "Failed to get base group members", "Provider", a.Name, "Base Group", baseGroup)
269268
return nil, err
270269
}
271270

272-
pageIterator, err := msgraphcore.NewPageIterator[interface{}](baseGroupMembersRequest, &a.Adapter.GraphRequestAdapterBase, graph.CreateGroupCollectionResponseFromDiscriminatorValue)
271+
pageIterator, err := msgraphcore.NewPageIterator[*graph.Group](baseGroupMembersRequest, &a.Adapter.GraphRequestAdapterBase, graph.CreateGroupCollectionResponseFromDiscriminatorValue)
273272

274273
if err != nil {
275274
return nil, err
276275
}
277276

278-
err = pageIterator.Iterate(a.Context, func(pageItem interface{}) bool {
277+
err = pageIterator.Iterate(a.Context, func(group *graph.Group) bool {
279278

280-
if member, ok := pageItem.(*graph.Group); ok {
281-
aadGroups = append(aadGroups, *member)
282-
}
279+
aadGroups = append(aadGroups, *group)
283280
return true
284281
})
285282

@@ -292,7 +289,11 @@ func (a *AzureSyncer) Sync() ([]userv1.Group, error) {
292289

293290
} else {
294291

295-
var groupConfiguration = msgroups.GroupsRequestBuilderGetRequestConfiguration{}
292+
var groupConfiguration = msgroups.GroupsRequestBuilderGetRequestConfiguration{
293+
QueryParameters: &msgroups.GroupsRequestBuilderGetQueryParameters{
294+
Top: &azurePageSize,
295+
},
296+
}
296297

297298
if a.Provider.Filter != "" {
298299
groupRequestParameters := &msgroups.GroupsRequestBuilderGetQueryParameters{
@@ -391,50 +392,56 @@ func (a *AzureSyncer) listGroupMembers(groupID *string) ([]string, error) {
391392
selectParameter = []string{GraphUserNameAttribute}
392393
}
393394

394-
queryParameters := msgroups.ItemTransitiveMembersRequestBuilderGetQueryParameters{
395+
queryParameters := msgroups.ItemTransitiveMembersGraphUserRequestBuilderGetQueryParameters{
395396
Select: selectParameter,
397+
Top: &azurePageSize,
396398
Count: &truthy,
397399
}
398400

399-
transitiveMembersGetConfiguration := msgroups.ItemTransitiveMembersRequestBuilderGetRequestConfiguration{
401+
transitiveMembersConfiguration := msgroups.ItemTransitiveMembersGraphUserRequestBuilderGetRequestConfiguration{
400402
QueryParameters: &queryParameters,
401-
Headers: transitiveMemberHeaders,
403+
Headers: headers,
402404
}
403405

404-
memberRequest, err := a.Client.GroupsById(*groupID).TransitiveMembers().Get(a.Context, &transitiveMembersGetConfiguration)
406+
memberRequest, err := a.Client.GroupsById(*groupID).TransitiveMembers().GraphUser().Get(a.Context, &transitiveMembersConfiguration)
405407

406408
if err != nil {
407409
return nil, err
408410
}
409411

410-
pageIterator, err := msgraphcore.NewPageIterator[interface{}](memberRequest, &a.Adapter.GraphRequestAdapterBase, graph.CreateGroupCollectionResponseFromDiscriminatorValue)
412+
for {
411413

412-
if err != nil {
413-
return nil, err
414-
}
414+
for _, member := range memberRequest.GetValue() {
415+
if username, found := a.getUsernameForUser(member); found {
416+
groupMembers = append(groupMembers, fmt.Sprintf("%v", username))
417+
}
418+
}
415419

416-
err = pageIterator.Iterate(a.Context, func(pageItem interface{}) bool {
420+
nextPageUrl := memberRequest.GetOdataNextLink()
421+
if nextPageUrl != nil {
422+
transitiveMembersConfiguration := msgroups.ItemTransitiveMembersGraphUserRequestBuilderGetRequestConfiguration{
423+
Headers: headers,
424+
}
417425

418-
if member, ok := pageItem.(*graph.User); ok {
419-
if username, found := a.getUsernameForUser(*member); found {
420-
groupMembers = append(groupMembers, fmt.Sprintf("%v", username))
426+
memberRequest, err = msgroups.NewItemTransitiveMembersGraphUserRequestBuilder(*nextPageUrl, a.Client.GetAdapter()).Get(context.Background(), &transitiveMembersConfiguration)
427+
428+
if err != nil {
429+
azureLogger.Error(err, "Failed to get iterate over group members", "Provider", a.Name, "Group ID", groupID)
430+
return nil, err
421431
}
432+
} else {
433+
break
422434
}
423-
return true
424-
})
425435

426-
if err != nil {
427-
azureLogger.Error(err, "Failed to get iterate over group members", "Provider", a.Name, "Group ID", groupID)
428-
return nil, err
429436
}
430437

431438
return groupMembers, nil
432439

433440
}
434441

435-
func (a *AzureSyncer) getUsernameForUser(user graph.User) (string, bool) {
442+
func (a *AzureSyncer) getUsernameForUser(user graph.Userable) (string, bool) {
436443

437-
userValue := reflect.ValueOf(&user)
444+
userValue := reflect.ValueOf(user)
438445

439446
if a.Provider.UserNameAttributes == nil {
440447
return a.isUsernamePresent(userValue, GraphUserNameAttribute)
@@ -490,14 +497,13 @@ func getAuthorityHost(authorityHost *string) string {
490497
func (a *AzureSyncer) getGroupsFromResults(result graph.GroupCollectionResponseable) ([]graph.Group, error) {
491498
groups := []graph.Group{}
492499

493-
pageIterator, err := msgraphcore.NewPageIterator[interface{}](result, &a.Adapter.GraphRequestAdapterBase, graph.CreateGroupCollectionResponseFromDiscriminatorValue)
500+
pageIterator, err := msgraphcore.NewPageIterator[*graph.Group](result, &a.Adapter.GraphRequestAdapterBase, graph.CreateGroupCollectionResponseFromDiscriminatorValue)
494501

495502
if err != nil {
496503
return nil, err
497504
}
498505

499-
iterateErr := pageIterator.Iterate(a.Context, func(pageItem interface{}) bool {
500-
group := pageItem.(*graph.Group)
506+
iterateErr := pageIterator.Iterate(a.Context, func(group *graph.Group) bool {
501507
groups = append(groups, *group)
502508
return true
503509
})

0 commit comments

Comments
 (0)