Skip to content

Commit 90eb11a

Browse files
mattdhollowayMayorFaj
authored andcommitted
Add updating draft state to update_pull_request tool (github#774)
* initial impl of pull request draft state update * appease linter * update README * add nosec * fixed err return type for json marshalling * add gql test
1 parent 046f994 commit 90eb11a

File tree

5 files changed

+334
-63
lines changed

5 files changed

+334
-63
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default):
736736
- **update_pull_request** - Edit pull request
737737
- `base`: New base branch name (string, optional)
738738
- `body`: New description (string, optional)
739+
- `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional)
739740
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
740741
- `owner`: Repository owner (string, required)
741742
- `pullNumber`: Pull request number to update (number, required)

pkg/github/__toolsnaps__/update_pull_request.snap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"description": "New description",
1515
"type": "string"
1616
},
17+
"draft": {
18+
"description": "Mark pull request as draft (true) or ready for review (false)",
19+
"type": "boolean"
20+
},
1721
"maintainer_can_modify": {
1822
"description": "Allow maintainer edits",
1923
"type": "boolean"

pkg/github/pullrequests.go

Lines changed: 104 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203203
}
204204

205205
// UpdatePullRequest creates a tool to update an existing pull request.
206-
func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
206+
func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
207207
return mcp.NewTool("update_pull_request",
208208
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")),
209209
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232232
mcp.Description("New state"),
233233
mcp.Enum("open", "closed"),
234234
),
235+
mcp.WithBoolean("draft",
236+
mcp.Description("Mark pull request as draft (true) or ready for review (false)"),
237+
),
235238
mcp.WithString("base",
236239
mcp.Description("New base branch name"),
237240
),
@@ -259,43 +262,51 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
259262
return mcp.NewToolResultError(err.Error()), nil
260263
}
261264

262-
// Build the update struct only with provided fields
265+
draftProvided := request.GetArguments()["draft"] != nil
266+
var draftValue bool
267+
if draftProvided {
268+
draftValue, err = OptionalParam[bool](request, "draft")
269+
if err != nil {
270+
return nil, err
271+
}
272+
}
273+
263274
update := &github.PullRequest{}
264-
updateNeeded := false
275+
restUpdateNeeded := false
265276

266277
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
267278
return mcp.NewToolResultError(err.Error()), nil
268279
} else if ok {
269280
update.Title = github.Ptr(title)
270-
updateNeeded = true
281+
restUpdateNeeded = true
271282
}
272283

273284
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
274285
return mcp.NewToolResultError(err.Error()), nil
275286
} else if ok {
276287
update.Body = github.Ptr(body)
277-
updateNeeded = true
288+
restUpdateNeeded = true
278289
}
279290

280291
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
281292
return mcp.NewToolResultError(err.Error()), nil
282293
} else if ok {
283294
update.State = github.Ptr(state)
284-
updateNeeded = true
295+
restUpdateNeeded = true
285296
}
286297

287298
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
288299
return mcp.NewToolResultError(err.Error()), nil
289300
} else if ok {
290301
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
291-
updateNeeded = true
302+
restUpdateNeeded = true
292303
}
293304

294305
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
295306
return mcp.NewToolResultError(err.Error()), nil
296307
} else if ok {
297308
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
298-
updateNeeded = true
309+
restUpdateNeeded = true
299310
}
300311

301312
// Handle reviewers separately
@@ -305,82 +316,115 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
305316
}
306317

307318
// If no updates and no reviewers, return error early
308-
if !updateNeeded && len(reviewers) == 0 {
319+
if !restUpdateNeeded && len(reviewers) == 0 && !draftProvided {
309320
return mcp.NewToolResultError("No update parameters provided"), nil
310321
}
311322

312-
// Create the GitHub client
313323
client, err := getClient(ctx)
314324
if err != nil {
315325
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
316326
}
327+
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
328+
if err != nil {
329+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
330+
"failed to update pull request",
331+
resp,
332+
err,
333+
), nil
334+
}
335+
defer func() { _ = resp.Body.Close() }()
317336

318-
var pr *github.PullRequest
319-
var resp *http.Response
320-
321-
// Update the PR if needed
322-
if updateNeeded {
323-
var ghResp *github.Response
324-
pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
337+
if resp.StatusCode != http.StatusOK {
338+
body, err := io.ReadAll(resp.Body)
325339
if err != nil {
326-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
327-
"failed to update pull request",
328-
ghResp,
329-
err,
330-
), nil
340+
return nil, fmt.Errorf("failed to read response body: %w", err)
331341
}
332-
resp = ghResp.Response
333-
defer func() {
334-
if resp != nil && resp.Body != nil {
335-
_ = resp.Body.Close()
336-
}
337-
}()
342+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
343+
}
338344

339-
if resp.StatusCode != http.StatusOK {
340-
body, err := io.ReadAll(resp.Body)
341-
if err != nil {
342-
return nil, fmt.Errorf("failed to read response body: %w", err)
343-
}
344-
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
345+
if draftProvided {
346+
gqlClient, err := getGQLClient(ctx)
347+
if err != nil {
348+
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err)
345349
}
346-
}
347350

348-
// Add reviewers if specified
349-
if len(reviewers) > 0 {
350-
reviewersRequest := github.ReviewersRequest{
351-
Reviewers: reviewers,
351+
var prQuery struct {
352+
Repository struct {
353+
PullRequest struct {
354+
ID githubv4.ID
355+
IsDraft githubv4.Boolean
356+
} `graphql:"pullRequest(number: $prNum)"`
357+
} `graphql:"repository(owner: $owner, name: $repo)"`
352358
}
353359

354-
// Use the direct result of RequestReviewers which includes the requested reviewers
355-
updatedPR, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest)
360+
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{
361+
"owner": githubv4.String(owner),
362+
"repo": githubv4.String(repo),
363+
"prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers
364+
})
356365
if err != nil {
357-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
358-
"failed to request reviewers",
359-
resp,
360-
err,
361-
), nil
366+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil
362367
}
363-
defer func() {
364-
if resp != nil && resp.Body != nil {
365-
_ = resp.Body.Close()
366-
}
367-
}()
368368

369-
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
370-
body, err := io.ReadAll(resp.Body)
371-
if err != nil {
372-
return nil, fmt.Errorf("failed to read response body: %w", err)
369+
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft)
370+
371+
if currentIsDraft != draftValue {
372+
if draftValue {
373+
// Convert to draft
374+
var mutation struct {
375+
ConvertPullRequestToDraft struct {
376+
PullRequest struct {
377+
ID githubv4.ID
378+
IsDraft githubv4.Boolean
379+
}
380+
} `graphql:"convertPullRequestToDraft(input: $input)"`
381+
}
382+
383+
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{
384+
PullRequestID: prQuery.Repository.PullRequest.ID,
385+
}, nil)
386+
if err != nil {
387+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil
388+
}
389+
} else {
390+
// Mark as ready for review
391+
var mutation struct {
392+
MarkPullRequestReadyForReview struct {
393+
PullRequest struct {
394+
ID githubv4.ID
395+
IsDraft githubv4.Boolean
396+
}
397+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
398+
}
399+
400+
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{
401+
PullRequestID: prQuery.Repository.PullRequest.ID,
402+
}, nil)
403+
if err != nil {
404+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil
405+
}
373406
}
374-
return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil
375407
}
408+
}
376409

377-
// Use the updated PR with reviewers
378-
pr = updatedPR
410+
client, err := getClient(ctx)
411+
if err != nil {
412+
return nil, err
379413
}
380414

381-
r, err := json.Marshal(pr)
415+
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
382416
if err != nil {
383-
return nil, fmt.Errorf("failed to marshal response: %w", err)
417+
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil
418+
}
419+
defer func() {
420+
if resp != nil && resp.Body != nil {
421+
_ = resp.Body.Close()
422+
}
423+
}()
424+
425+
r, err := json.Marshal(finalPR)
426+
if err != nil {
427+
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil
384428
}
385429

386430
return mcp.NewToolResultText(string(r)), nil

0 commit comments

Comments
 (0)