Skip to content

Commit be359ea

Browse files
committed
test: add unit tests for updating pull request draft state
1 parent 5ea322b commit be359ea

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed

pkg/github/pullrequests_test.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,207 @@ func Test_UpdatePullRequest(t *testing.T) {
437437
}
438438
}
439439

440+
func Test_UpdatePullRequest_Draft(t *testing.T) {
441+
// Setup mock PR for success case
442+
mockUpdatedPR := &github.PullRequest{
443+
Number: github.Ptr(42),
444+
Title: github.Ptr("Test PR Title"),
445+
State: github.Ptr("open"),
446+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
447+
Body: github.Ptr("Test PR body."),
448+
MaintainerCanModify: github.Ptr(false),
449+
Draft: github.Ptr(false), // Updated to ready for review
450+
Base: &github.PullRequestBranch{
451+
Ref: github.Ptr("main"),
452+
},
453+
}
454+
455+
tests := []struct {
456+
name string
457+
mockedClient *http.Client
458+
requestArgs map[string]interface{}
459+
expectError bool
460+
expectedPR *github.PullRequest
461+
expectedErrMsg string
462+
}{
463+
{
464+
name: "successful draft update to ready for review",
465+
mockedClient: githubv4mock.NewMockedHTTPClient(
466+
githubv4mock.NewQueryMatcher(
467+
struct {
468+
Repository struct {
469+
PullRequest struct {
470+
ID githubv4.ID
471+
IsDraft githubv4.Boolean
472+
} `graphql:"pullRequest(number: $prNum)"`
473+
} `graphql:"repository(owner: $owner, name: $repo)"`
474+
}{},
475+
map[string]any{
476+
"owner": githubv4.String("owner"),
477+
"repo": githubv4.String("repo"),
478+
"prNum": githubv4.Int(42),
479+
},
480+
githubv4mock.DataResponse(map[string]any{
481+
"repository": map[string]any{
482+
"pullRequest": map[string]any{
483+
"id": "PR_kwDOA0xdyM50BPaO",
484+
"isDraft": true, // Current state is draft
485+
},
486+
},
487+
}),
488+
),
489+
githubv4mock.NewMutationMatcher(
490+
struct {
491+
MarkPullRequestReadyForReview struct {
492+
PullRequest struct {
493+
ID githubv4.ID
494+
IsDraft githubv4.Boolean
495+
}
496+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
497+
}{},
498+
githubv4.MarkPullRequestReadyForReviewInput{
499+
PullRequestID: "PR_kwDOA0xdyM50BPaO",
500+
},
501+
nil,
502+
githubv4mock.DataResponse(map[string]any{
503+
"markPullRequestReadyForReview": map[string]any{
504+
"pullRequest": map[string]any{
505+
"id": "PR_kwDOA0xdyM50BPaO",
506+
"isDraft": false,
507+
},
508+
},
509+
}),
510+
),
511+
),
512+
requestArgs: map[string]interface{}{
513+
"owner": "owner",
514+
"repo": "repo",
515+
"pullNumber": float64(42),
516+
"draft": false,
517+
},
518+
expectError: false,
519+
expectedPR: mockUpdatedPR,
520+
},
521+
{
522+
name: "successful convert pull request to draft",
523+
mockedClient: githubv4mock.NewMockedHTTPClient(
524+
githubv4mock.NewQueryMatcher(
525+
struct {
526+
Repository struct {
527+
PullRequest struct {
528+
ID githubv4.ID
529+
IsDraft githubv4.Boolean
530+
} `graphql:"pullRequest(number: $prNum)"`
531+
} `graphql:"repository(owner: $owner, name: $repo)"`
532+
}{},
533+
map[string]any{
534+
"owner": githubv4.String("owner"),
535+
"repo": githubv4.String("repo"),
536+
"prNum": githubv4.Int(42),
537+
},
538+
githubv4mock.DataResponse(map[string]any{
539+
"repository": map[string]any{
540+
"pullRequest": map[string]any{
541+
"id": "PR_kwDOA0xdyM50BPaO",
542+
"isDraft": false, // Current state is draft
543+
},
544+
},
545+
}),
546+
),
547+
githubv4mock.NewMutationMatcher(
548+
struct {
549+
ConvertPullRequestToDraft struct {
550+
PullRequest struct {
551+
ID githubv4.ID
552+
IsDraft githubv4.Boolean
553+
}
554+
} `graphql:"convertPullRequestToDraft(input: $input)"`
555+
}{},
556+
githubv4.ConvertPullRequestToDraftInput{
557+
PullRequestID: "PR_kwDOA0xdyM50BPaO",
558+
},
559+
nil,
560+
githubv4mock.DataResponse(map[string]any{
561+
"convertPullRequestToDraft": map[string]any{
562+
"pullRequest": map[string]any{
563+
"id": "PR_kwDOA0xdyM50BPaO",
564+
"isDraft": true,
565+
},
566+
},
567+
}),
568+
),
569+
),
570+
requestArgs: map[string]interface{}{
571+
"owner": "owner",
572+
"repo": "repo",
573+
"pullNumber": float64(42),
574+
"draft": true,
575+
},
576+
expectError: false,
577+
expectedPR: mockUpdatedPR,
578+
},
579+
}
580+
581+
for _, tc := range tests {
582+
t.Run(tc.name, func(t *testing.T) {
583+
// For draft-only tests, we need to mock both GraphQL and the final REST GET call
584+
restClient := github.NewClient(mock.NewMockedHTTPClient(
585+
mock.WithRequestMatch(
586+
mock.GetReposPullsByOwnerByRepoByPullNumber,
587+
mockUpdatedPR,
588+
),
589+
))
590+
gqlClient := githubv4.NewClient(tc.mockedClient)
591+
592+
_, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
593+
594+
request := createMCPRequest(tc.requestArgs)
595+
596+
result, err := handler(context.Background(), request)
597+
598+
if tc.expectError || tc.expectedErrMsg != "" {
599+
require.NoError(t, err)
600+
require.True(t, result.IsError)
601+
errorContent := getErrorResult(t, result)
602+
if tc.expectedErrMsg != "" {
603+
assert.Contains(t, errorContent.Text, tc.expectedErrMsg)
604+
}
605+
return
606+
}
607+
608+
require.NoError(t, err)
609+
require.False(t, result.IsError)
610+
611+
textContent := getTextResult(t, result)
612+
613+
// Unmarshal and verify the successful result
614+
var returnedPR github.PullRequest
615+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
616+
require.NoError(t, err)
617+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
618+
if tc.expectedPR.Title != nil {
619+
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
620+
}
621+
if tc.expectedPR.Body != nil {
622+
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
623+
}
624+
if tc.expectedPR.State != nil {
625+
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
626+
}
627+
if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil {
628+
assert.NotNil(t, returnedPR.Base)
629+
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
630+
}
631+
if tc.expectedPR.MaintainerCanModify != nil {
632+
assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify)
633+
}
634+
if tc.expectedPR.Draft != nil {
635+
assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft)
636+
}
637+
})
638+
}
639+
}
640+
440641
func Test_ListPullRequests(t *testing.T) {
441642
// Verify tool definition once
442643
mockClient := github.NewClient(nil)

0 commit comments

Comments
 (0)