@@ -437,6 +437,207 @@ func Test_UpdatePullRequest(t *testing.T) {
437437 }
438438}
439439
440+ func Test_UpdatePullRequest_Draft (t * testing.T ) {
441+ // Setup mock PR for success case
442+ mockUpdatedPR := & github.PullRequest {
443+ Number : github .Ptr (42 ),
444+ Title : github .Ptr ("Test PR Title" ),
445+ State : github .Ptr ("open" ),
446+ HTMLURL : github .Ptr ("https://github.com/owner/repo/pull/42" ),
447+ Body : github .Ptr ("Test PR body." ),
448+ MaintainerCanModify : github .Ptr (false ),
449+ Draft : github .Ptr (false ), // Updated to ready for review
450+ Base : & github.PullRequestBranch {
451+ Ref : github .Ptr ("main" ),
452+ },
453+ }
454+
455+ tests := []struct {
456+ name string
457+ mockedClient * http.Client
458+ requestArgs map [string ]interface {}
459+ expectError bool
460+ expectedPR * github.PullRequest
461+ expectedErrMsg string
462+ }{
463+ {
464+ name : "successful draft update to ready for review" ,
465+ mockedClient : githubv4mock .NewMockedHTTPClient (
466+ githubv4mock .NewQueryMatcher (
467+ struct {
468+ Repository struct {
469+ PullRequest struct {
470+ ID githubv4.ID
471+ IsDraft githubv4.Boolean
472+ } `graphql:"pullRequest(number: $prNum)"`
473+ } `graphql:"repository(owner: $owner, name: $repo)"`
474+ }{},
475+ map [string ]any {
476+ "owner" : githubv4 .String ("owner" ),
477+ "repo" : githubv4 .String ("repo" ),
478+ "prNum" : githubv4 .Int (42 ),
479+ },
480+ githubv4mock .DataResponse (map [string ]any {
481+ "repository" : map [string ]any {
482+ "pullRequest" : map [string ]any {
483+ "id" : "PR_kwDOA0xdyM50BPaO" ,
484+ "isDraft" : true , // Current state is draft
485+ },
486+ },
487+ }),
488+ ),
489+ githubv4mock .NewMutationMatcher (
490+ struct {
491+ MarkPullRequestReadyForReview struct {
492+ PullRequest struct {
493+ ID githubv4.ID
494+ IsDraft githubv4.Boolean
495+ }
496+ } `graphql:"markPullRequestReadyForReview(input: $input)"`
497+ }{},
498+ githubv4.MarkPullRequestReadyForReviewInput {
499+ PullRequestID : "PR_kwDOA0xdyM50BPaO" ,
500+ },
501+ nil ,
502+ githubv4mock .DataResponse (map [string ]any {
503+ "markPullRequestReadyForReview" : map [string ]any {
504+ "pullRequest" : map [string ]any {
505+ "id" : "PR_kwDOA0xdyM50BPaO" ,
506+ "isDraft" : false ,
507+ },
508+ },
509+ }),
510+ ),
511+ ),
512+ requestArgs : map [string ]interface {}{
513+ "owner" : "owner" ,
514+ "repo" : "repo" ,
515+ "pullNumber" : float64 (42 ),
516+ "draft" : false ,
517+ },
518+ expectError : false ,
519+ expectedPR : mockUpdatedPR ,
520+ },
521+ {
522+ name : "successful convert pull request to draft" ,
523+ mockedClient : githubv4mock .NewMockedHTTPClient (
524+ githubv4mock .NewQueryMatcher (
525+ struct {
526+ Repository struct {
527+ PullRequest struct {
528+ ID githubv4.ID
529+ IsDraft githubv4.Boolean
530+ } `graphql:"pullRequest(number: $prNum)"`
531+ } `graphql:"repository(owner: $owner, name: $repo)"`
532+ }{},
533+ map [string ]any {
534+ "owner" : githubv4 .String ("owner" ),
535+ "repo" : githubv4 .String ("repo" ),
536+ "prNum" : githubv4 .Int (42 ),
537+ },
538+ githubv4mock .DataResponse (map [string ]any {
539+ "repository" : map [string ]any {
540+ "pullRequest" : map [string ]any {
541+ "id" : "PR_kwDOA0xdyM50BPaO" ,
542+ "isDraft" : false , // Current state is draft
543+ },
544+ },
545+ }),
546+ ),
547+ githubv4mock .NewMutationMatcher (
548+ struct {
549+ ConvertPullRequestToDraft struct {
550+ PullRequest struct {
551+ ID githubv4.ID
552+ IsDraft githubv4.Boolean
553+ }
554+ } `graphql:"convertPullRequestToDraft(input: $input)"`
555+ }{},
556+ githubv4.ConvertPullRequestToDraftInput {
557+ PullRequestID : "PR_kwDOA0xdyM50BPaO" ,
558+ },
559+ nil ,
560+ githubv4mock .DataResponse (map [string ]any {
561+ "convertPullRequestToDraft" : map [string ]any {
562+ "pullRequest" : map [string ]any {
563+ "id" : "PR_kwDOA0xdyM50BPaO" ,
564+ "isDraft" : true ,
565+ },
566+ },
567+ }),
568+ ),
569+ ),
570+ requestArgs : map [string ]interface {}{
571+ "owner" : "owner" ,
572+ "repo" : "repo" ,
573+ "pullNumber" : float64 (42 ),
574+ "draft" : true ,
575+ },
576+ expectError : false ,
577+ expectedPR : mockUpdatedPR ,
578+ },
579+ }
580+
581+ for _ , tc := range tests {
582+ t .Run (tc .name , func (t * testing.T ) {
583+ // For draft-only tests, we need to mock both GraphQL and the final REST GET call
584+ restClient := github .NewClient (mock .NewMockedHTTPClient (
585+ mock .WithRequestMatch (
586+ mock .GetReposPullsByOwnerByRepoByPullNumber ,
587+ mockUpdatedPR ,
588+ ),
589+ ))
590+ gqlClient := githubv4 .NewClient (tc .mockedClient )
591+
592+ _ , handler := UpdatePullRequest (stubGetClientFn (restClient ), stubGetGQLClientFn (gqlClient ), translations .NullTranslationHelper )
593+
594+ request := createMCPRequest (tc .requestArgs )
595+
596+ result , err := handler (context .Background (), request )
597+
598+ if tc .expectError || tc .expectedErrMsg != "" {
599+ require .NoError (t , err )
600+ require .True (t , result .IsError )
601+ errorContent := getErrorResult (t , result )
602+ if tc .expectedErrMsg != "" {
603+ assert .Contains (t , errorContent .Text , tc .expectedErrMsg )
604+ }
605+ return
606+ }
607+
608+ require .NoError (t , err )
609+ require .False (t , result .IsError )
610+
611+ textContent := getTextResult (t , result )
612+
613+ // Unmarshal and verify the successful result
614+ var returnedPR github.PullRequest
615+ err = json .Unmarshal ([]byte (textContent .Text ), & returnedPR )
616+ require .NoError (t , err )
617+ assert .Equal (t , * tc .expectedPR .Number , * returnedPR .Number )
618+ if tc .expectedPR .Title != nil {
619+ assert .Equal (t , * tc .expectedPR .Title , * returnedPR .Title )
620+ }
621+ if tc .expectedPR .Body != nil {
622+ assert .Equal (t , * tc .expectedPR .Body , * returnedPR .Body )
623+ }
624+ if tc .expectedPR .State != nil {
625+ assert .Equal (t , * tc .expectedPR .State , * returnedPR .State )
626+ }
627+ if tc .expectedPR .Base != nil && tc .expectedPR .Base .Ref != nil {
628+ assert .NotNil (t , returnedPR .Base )
629+ assert .Equal (t , * tc .expectedPR .Base .Ref , * returnedPR .Base .Ref )
630+ }
631+ if tc .expectedPR .MaintainerCanModify != nil {
632+ assert .Equal (t , * tc .expectedPR .MaintainerCanModify , * returnedPR .MaintainerCanModify )
633+ }
634+ if tc .expectedPR .Draft != nil {
635+ assert .Equal (t , * tc .expectedPR .Draft , * returnedPR .Draft )
636+ }
637+ })
638+ }
639+ }
640+
440641func Test_ListPullRequests (t * testing.T ) {
441642 // Verify tool definition once
442643 mockClient := github .NewClient (nil )
0 commit comments