Skip to content

Commit d89e522

Browse files
authored
Merge branch 'main' into feat/259/assign-reviewers
2 parents 94cef70 + 89e3afd commit d89e522

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed

pkg/github/discussions_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ var (
8787
"url": "https://github.com/owner/.github/discussions/4",
8888
"category": map[string]any{"name": "General"},
8989
},
90+
9091
}
9192

9293
// Ordered mock responses

pkg/github/pullrequests.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra
321321
}
322322

323323
// Handle REST API updates
324+
}
325+
326+
if !restUpdateNeeded && !draftProvided {
327+
return mcp.NewToolResultError("No update parameters provided."), nil
328+
}
329+
324330
if restUpdateNeeded {
325331
client, err := getClient(ctx)
326332
if err != nil {
@@ -462,6 +468,90 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra
462468
}
463469
}()
464470

471+
r, err := json.Marshal(finalPR)
472+
if err != nil {
473+
}
474+
475+
if draftProvided {
476+
gqlClient, err := getGQLClient(ctx)
477+
if err != nil {
478+
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err)
479+
}
480+
481+
var prQuery struct {
482+
Repository struct {
483+
PullRequest struct {
484+
ID githubv4.ID
485+
IsDraft githubv4.Boolean
486+
} `graphql:"pullRequest(number: $prNum)"`
487+
} `graphql:"repository(owner: $owner, name: $repo)"`
488+
}
489+
490+
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{
491+
"owner": githubv4.String(owner),
492+
"repo": githubv4.String(repo),
493+
"prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers
494+
})
495+
if err != nil {
496+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil
497+
}
498+
499+
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft)
500+
501+
if currentIsDraft != draftValue {
502+
if draftValue {
503+
// Convert to draft
504+
var mutation struct {
505+
ConvertPullRequestToDraft struct {
506+
PullRequest struct {
507+
ID githubv4.ID
508+
IsDraft githubv4.Boolean
509+
}
510+
} `graphql:"convertPullRequestToDraft(input: $input)"`
511+
}
512+
513+
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{
514+
PullRequestID: prQuery.Repository.PullRequest.ID,
515+
}, nil)
516+
if err != nil {
517+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil
518+
}
519+
} else {
520+
// Mark as ready for review
521+
var mutation struct {
522+
MarkPullRequestReadyForReview struct {
523+
PullRequest struct {
524+
ID githubv4.ID
525+
IsDraft githubv4.Boolean
526+
}
527+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
528+
}
529+
530+
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{
531+
PullRequestID: prQuery.Repository.PullRequest.ID,
532+
}, nil)
533+
if err != nil {
534+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil
535+
}
536+
}
537+
}
538+
}
539+
540+
client, err := getClient(ctx)
541+
if err != nil {
542+
return nil, err
543+
}
544+
545+
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
546+
if err != nil {
547+
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil
548+
}
549+
defer func() {
550+
if resp != nil && resp.Body != nil {
551+
_ = resp.Body.Close()
552+
}
553+
}()
554+
465555
r, err := json.Marshal(finalPR)
466556
if err != nil {
467557
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil

pkg/github/pullrequests_test.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,191 @@ func Test_UpdatePullRequest_Draft(t *testing.T) {
622622
}
623623
}
624624

625+
func Test_UpdatePullRequest_Draft(t *testing.T) {
626+
// Setup mock PR for success case
627+
mockUpdatedPR := &github.PullRequest{
628+
Number: github.Ptr(42),
629+
Title: github.Ptr("Test PR Title"),
630+
State: github.Ptr("open"),
631+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
632+
Body: github.Ptr("Test PR body."),
633+
MaintainerCanModify: github.Ptr(false),
634+
Draft: github.Ptr(false), // Updated to ready for review
635+
Base: &github.PullRequestBranch{
636+
Ref: github.Ptr("main"),
637+
},
638+
}
639+
640+
tests := []struct {
641+
name string
642+
mockedClient *http.Client
643+
requestArgs map[string]interface{}
644+
expectError bool
645+
expectedPR *github.PullRequest
646+
expectedErrMsg string
647+
}{
648+
{
649+
name: "successful draft update to ready for review",
650+
mockedClient: githubv4mock.NewMockedHTTPClient(
651+
githubv4mock.NewQueryMatcher(
652+
struct {
653+
Repository struct {
654+
PullRequest struct {
655+
ID githubv4.ID
656+
IsDraft githubv4.Boolean
657+
} `graphql:"pullRequest(number: $prNum)"`
658+
} `graphql:"repository(owner: $owner, name: $repo)"`
659+
}{},
660+
map[string]any{
661+
"owner": githubv4.String("owner"),
662+
"repo": githubv4.String("repo"),
663+
"prNum": githubv4.Int(42),
664+
},
665+
githubv4mock.DataResponse(map[string]any{
666+
"repository": map[string]any{
667+
"pullRequest": map[string]any{
668+
"id": "PR_kwDOA0xdyM50BPaO",
669+
"isDraft": true, // Current state is draft
670+
},
671+
},
672+
}),
673+
),
674+
githubv4mock.NewMutationMatcher(
675+
struct {
676+
MarkPullRequestReadyForReview struct {
677+
PullRequest struct {
678+
ID githubv4.ID
679+
IsDraft githubv4.Boolean
680+
}
681+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
682+
}{},
683+
githubv4.MarkPullRequestReadyForReviewInput{
684+
PullRequestID: "PR_kwDOA0xdyM50BPaO",
685+
},
686+
nil,
687+
githubv4mock.DataResponse(map[string]any{
688+
"markPullRequestReadyForReview": map[string]any{
689+
"pullRequest": map[string]any{
690+
"id": "PR_kwDOA0xdyM50BPaO",
691+
"isDraft": false,
692+
},
693+
},
694+
}),
695+
),
696+
),
697+
requestArgs: map[string]interface{}{
698+
"owner": "owner",
699+
"repo": "repo",
700+
"pullNumber": float64(42),
701+
"draft": false,
702+
},
703+
expectError: false,
704+
expectedPR: mockUpdatedPR,
705+
},
706+
{
707+
name: "successful convert pull request to draft",
708+
mockedClient: githubv4mock.NewMockedHTTPClient(
709+
githubv4mock.NewQueryMatcher(
710+
struct {
711+
Repository struct {
712+
PullRequest struct {
713+
ID githubv4.ID
714+
IsDraft githubv4.Boolean
715+
} `graphql:"pullRequest(number: $prNum)"`
716+
} `graphql:"repository(owner: $owner, name: $repo)"`
717+
}{},
718+
map[string]any{
719+
"owner": githubv4.String("owner"),
720+
"repo": githubv4.String("repo"),
721+
"prNum": githubv4.Int(42),
722+
},
723+
githubv4mock.DataResponse(map[string]any{
724+
"repository": map[string]any{
725+
"pullRequest": map[string]any{
726+
"id": "PR_kwDOA0xdyM50BPaO",
727+
"isDraft": false, // Current state is draft
728+
},
729+
},
730+
}),
731+
),
732+
githubv4mock.NewMutationMatcher(
733+
struct {
734+
ConvertPullRequestToDraft struct {
735+
PullRequest struct {
736+
ID githubv4.ID
737+
IsDraft githubv4.Boolean
738+
}
739+
} `graphql:"convertPullRequestToDraft(input: $input)"`
740+
}{},
741+
githubv4.ConvertPullRequestToDraftInput{
742+
PullRequestID: "PR_kwDOA0xdyM50BPaO",
743+
},
744+
nil,
745+
githubv4mock.DataResponse(map[string]any{
746+
"convertPullRequestToDraft": map[string]any{
747+
"pullRequest": map[string]any{
748+
"id": "PR_kwDOA0xdyM50BPaO",
749+
"isDraft": true,
750+
},
751+
},
752+
}),
753+
),
754+
),
755+
requestArgs: map[string]interface{}{
756+
"owner": "owner",
757+
"repo": "repo",
758+
"pullNumber": float64(42),
759+
"draft": true,
760+
},
761+
expectError: false,
762+
expectedPR: mockUpdatedPR,
763+
},
764+
}
765+
766+
for _, tc := range tests {
767+
t.Run(tc.name, func(t *testing.T) {
768+
// For draft-only tests, we need to mock both GraphQL and the final REST GET call
769+
restClient := github.NewClient(mock.NewMockedHTTPClient(
770+
mock.WithRequestMatch(
771+
mock.GetReposPullsByOwnerByRepoByPullNumber,
772+
mockUpdatedPR,
773+
),
774+
))
775+
gqlClient := githubv4.NewClient(tc.mockedClient)
776+
777+
_, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
778+
779+
request := createMCPRequest(tc.requestArgs)
780+
781+
result, err := handler(context.Background(), request)
782+
783+
if tc.expectError || tc.expectedErrMsg != "" {
784+
require.NoError(t, err)
785+
require.True(t, result.IsError)
786+
errorContent := getErrorResult(t, result)
787+
if tc.expectedErrMsg != "" {
788+
assert.Contains(t, errorContent.Text, tc.expectedErrMsg)
789+
}
790+
return
791+
}
792+
793+
require.NoError(t, err)
794+
require.False(t, result.IsError)
795+
796+
textContent := getTextResult(t, result)
797+
798+
// Unmarshal and verify the successful result
799+
var returnedPR github.PullRequest
800+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
801+
require.NoError(t, err)
802+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
803+
if tc.expectedPR.Draft != nil {
804+
assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft)
805+
}
806+
})
807+
}
808+
}
809+
625810
func Test_ListPullRequests(t *testing.T) {
626811
// Verify tool definition once
627812
mockClient := github.NewClient(nil)

0 commit comments

Comments
 (0)