Skip to content

Commit d5e1f48

Browse files
Add updating draft state to update_pull_request tool (#774)
* initial impl of pull request draft state update * appease linter * update README * add nosec * fixed err return type for json marshalling * add gql test
1 parent efef8ae commit d5e1f48

File tree

5 files changed

+348
-29
lines changed

5 files changed

+348
-29
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default):
736736
- **update_pull_request** - Edit pull request
737737
- `base`: New base branch name (string, optional)
738738
- `body`: New description (string, optional)
739+
- `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional)
739740
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
740741
- `owner`: Repository owner (string, required)
741742
- `pullNumber`: Pull request number to update (number, required)

pkg/github/__toolsnaps__/update_pull_request.snap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"description": "New description",
1515
"type": "string"
1616
},
17+
"draft": {
18+
"description": "Mark pull request as draft (true) or ready for review (false)",
19+
"type": "boolean"
20+
},
1721
"maintainer_can_modify": {
1822
"description": "Allow maintainer edits",
1923
"type": "boolean"

pkg/github/pullrequests.go

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203203
}
204204

205205
// UpdatePullRequest creates a tool to update an existing pull request.
206-
func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
206+
func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
207207
return mcp.NewTool("update_pull_request",
208208
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")),
209209
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232232
mcp.Description("New state"),
233233
mcp.Enum("open", "closed"),
234234
),
235+
mcp.WithBoolean("draft",
236+
mcp.Description("Mark pull request as draft (true) or ready for review (false)"),
237+
),
235238
mcp.WithString("base",
236239
mcp.Description("New base branch name"),
237240
),
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
253256
return mcp.NewToolResultError(err.Error()), nil
254257
}
255258

256-
// Build the update struct only with provided fields
259+
draftProvided := request.GetArguments()["draft"] != nil
260+
var draftValue bool
261+
if draftProvided {
262+
draftValue, err = OptionalParam[bool](request, "draft")
263+
if err != nil {
264+
return nil, err
265+
}
266+
}
267+
257268
update := &github.PullRequest{}
258-
updateNeeded := false
269+
restUpdateNeeded := false
259270

260271
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
261272
return mcp.NewToolResultError(err.Error()), nil
262273
} else if ok {
263274
update.Title = github.Ptr(title)
264-
updateNeeded = true
275+
restUpdateNeeded = true
265276
}
266277

267278
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
268279
return mcp.NewToolResultError(err.Error()), nil
269280
} else if ok {
270281
update.Body = github.Ptr(body)
271-
updateNeeded = true
282+
restUpdateNeeded = true
272283
}
273284

274285
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
275286
return mcp.NewToolResultError(err.Error()), nil
276287
} else if ok {
277288
update.State = github.Ptr(state)
278-
updateNeeded = true
289+
restUpdateNeeded = true
279290
}
280291

281292
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
282293
return mcp.NewToolResultError(err.Error()), nil
283294
} else if ok {
284295
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
285-
updateNeeded = true
296+
restUpdateNeeded = true
286297
}
287298

288299
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
289300
return mcp.NewToolResultError(err.Error()), nil
290301
} else if ok {
291302
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
292-
updateNeeded = true
303+
restUpdateNeeded = true
293304
}
294305

295-
if !updateNeeded {
306+
if !restUpdateNeeded && !draftProvided {
296307
return mcp.NewToolResultError("No update parameters provided."), nil
297308
}
298309

310+
if restUpdateNeeded {
311+
client, err := getClient(ctx)
312+
if err != nil {
313+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
314+
}
315+
316+
_, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
317+
if err != nil {
318+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
319+
"failed to update pull request",
320+
resp,
321+
err,
322+
), nil
323+
}
324+
defer func() { _ = resp.Body.Close() }()
325+
326+
if resp.StatusCode != http.StatusOK {
327+
body, err := io.ReadAll(resp.Body)
328+
if err != nil {
329+
return nil, fmt.Errorf("failed to read response body: %w", err)
330+
}
331+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
332+
}
333+
}
334+
335+
if draftProvided {
336+
gqlClient, err := getGQLClient(ctx)
337+
if err != nil {
338+
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err)
339+
}
340+
341+
var prQuery struct {
342+
Repository struct {
343+
PullRequest struct {
344+
ID githubv4.ID
345+
IsDraft githubv4.Boolean
346+
} `graphql:"pullRequest(number: $prNum)"`
347+
} `graphql:"repository(owner: $owner, name: $repo)"`
348+
}
349+
350+
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{
351+
"owner": githubv4.String(owner),
352+
"repo": githubv4.String(repo),
353+
"prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers
354+
})
355+
if err != nil {
356+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil
357+
}
358+
359+
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft)
360+
361+
if currentIsDraft != draftValue {
362+
if draftValue {
363+
// Convert to draft
364+
var mutation struct {
365+
ConvertPullRequestToDraft struct {
366+
PullRequest struct {
367+
ID githubv4.ID
368+
IsDraft githubv4.Boolean
369+
}
370+
} `graphql:"convertPullRequestToDraft(input: $input)"`
371+
}
372+
373+
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{
374+
PullRequestID: prQuery.Repository.PullRequest.ID,
375+
}, nil)
376+
if err != nil {
377+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil
378+
}
379+
} else {
380+
// Mark as ready for review
381+
var mutation struct {
382+
MarkPullRequestReadyForReview struct {
383+
PullRequest struct {
384+
ID githubv4.ID
385+
IsDraft githubv4.Boolean
386+
}
387+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
388+
}
389+
390+
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{
391+
PullRequestID: prQuery.Repository.PullRequest.ID,
392+
}, nil)
393+
if err != nil {
394+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil
395+
}
396+
}
397+
}
398+
}
399+
299400
client, err := getClient(ctx)
300401
if err != nil {
301-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
402+
return nil, err
302403
}
303-
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
404+
405+
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
304406
if err != nil {
305-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
306-
"failed to update pull request",
307-
resp,
308-
err,
309-
), nil
407+
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil
310408
}
311-
defer func() { _ = resp.Body.Close() }()
312-
313-
if resp.StatusCode != http.StatusOK {
314-
body, err := io.ReadAll(resp.Body)
315-
if err != nil {
316-
return nil, fmt.Errorf("failed to read response body: %w", err)
409+
defer func() {
410+
if resp != nil && resp.Body != nil {
411+
_ = resp.Body.Close()
317412
}
318-
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
319-
}
413+
}()
320414

321-
r, err := json.Marshal(pr)
415+
r, err := json.Marshal(finalPR)
322416
if err != nil {
323-
return nil, fmt.Errorf("failed to marshal response: %w", err)
417+
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil
324418
}
325419

326420
return mcp.NewToolResultText(string(r)), nil

0 commit comments

Comments
 (0)