Skip to content

Commit 7cc1ff4

Browse files
committed
use GraphQL API to get file SHA
1 parent e650bb0 commit 7cc1ff4

File tree

3 files changed

+70
-13
lines changed

3 files changed

+70
-13
lines changed

pkg/github/repositories.go

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@ import (
1616
"github.com/google/go-github/v72/github"
1717
"github.com/mark3labs/mcp-go/mcp"
1818
"github.com/mark3labs/mcp-go/server"
19+
"github.com/shurcooL/githubv4"
1920
)
2021

22+
// getFileSHAFunc is a package-level variable that holds the getFileSHA function
23+
// This allows tests to mock the behavior
24+
var getFileSHAFunc = getFileSHA
25+
2126
func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
2227
return mcp.NewTool("get_commit",
2328
mcp.WithDescription(t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository")),
@@ -446,7 +451,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun
446451
}
447452

448453
// GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository.
449-
func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
454+
func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
450455
return mcp.NewTool("get_file_contents",
451456
mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository")),
452457
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -507,15 +512,13 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t
507512
// If the path is (most likely) not to be a directory, we will
508513
// first try to get the raw content from the GitHub raw content API.
509514
if path != "" && !strings.HasSuffix(path, "/") {
510-
// First, get file info from Contents API to retrieve SHA
511-
var fileSHA string
512-
opts := &github.RepositoryContentGetOptions{Ref: ref}
513-
fileContent, _, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, opts)
514-
if respContents != nil {
515-
defer func() { _ = respContents.Body.Close() }()
515+
gqlClient, err := getGQLClient(ctx)
516+
if err != nil {
517+
return mcp.NewToolResultError("failed to get GitHub GraphQL client"), nil
516518
}
517-
if err == nil && respContents.StatusCode == http.StatusOK && fileContent != nil && fileContent.SHA != nil {
518-
fileSHA = *fileContent.SHA
519+
fileSHA, err := getFileSHAFunc(ctx, gqlClient, owner, repo, path, rawOpts)
520+
if err != nil {
521+
return mcp.NewToolResultError(fmt.Sprintf("failed to get file SHA: %s", err)), nil
519522
}
520523

521524
rawClient, err := getRawClient(ctx)
@@ -1383,3 +1386,34 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
13831386
// Use provided ref, or it will be empty which defaults to the default branch
13841387
return &raw.ContentOpts{Ref: ref, SHA: sha}, nil
13851388
}
1389+
1390+
// getFileSHA retrieves the Blob SHA of a file.
1391+
func getFileSHA(ctx context.Context, client *githubv4.Client, owner, repo, path string, opts *raw.ContentOpts) (string, error) {
1392+
var query struct {
1393+
Repository struct {
1394+
Object struct {
1395+
Blob struct {
1396+
OID string
1397+
} `graphql:"... on Blob"`
1398+
} `graphql:"object(expression: $expression)"`
1399+
} `graphql:"repository(owner: $owner, name: $repo)"`
1400+
}
1401+
1402+
// Prepare the expression based on the provided options
1403+
expression := githubv4.String(path)
1404+
if opts != nil && opts.SHA != "" {
1405+
expression = githubv4.String(fmt.Sprintf("%s:%s", opts.SHA, path))
1406+
} else if opts != nil && opts.Ref != "" {
1407+
expression = githubv4.String(fmt.Sprintf("%s:%s", opts.Ref, path))
1408+
}
1409+
1410+
variables := map[string]interface{}{
1411+
"owner": githubv4.String(owner),
1412+
"repo": githubv4.String(repo),
1413+
"expression": expression,
1414+
}
1415+
if err := client.Query(ctx, &query, variables); err != nil {
1416+
return "", fmt.Errorf("failed to query file SHA: %w", err)
1417+
}
1418+
return query.Repository.Object.Blob.OID, nil
1419+
}

pkg/github/repositories_test.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,24 @@ import (
1515
"github.com/google/go-github/v72/github"
1616
"github.com/mark3labs/mcp-go/mcp"
1717
"github.com/migueleliasweb/go-github-mock/src/mock"
18+
"github.com/shurcooL/githubv4"
1819
"github.com/stretchr/testify/assert"
1920
"github.com/stretchr/testify/require"
2021
)
2122

23+
// mockGetFileSHA is a test helper that mocks the getFileSHA function so that we don't need to mock the GraphQL client in Test_GetFileContents.
24+
func mockGetFileSHA(expectedSHA string) func(context.Context, *githubv4.Client, string, string, string, *raw.ContentOpts) (string, error) {
25+
return func(_ context.Context, _ *githubv4.Client, _, _, _ string, _ *raw.ContentOpts) (string, error) {
26+
return expectedSHA, nil
27+
}
28+
}
29+
2230
func Test_GetFileContents(t *testing.T) {
2331
// Verify tool definition once
2432
mockClient := github.NewClient(nil)
2533
mockRawClient := raw.NewClient(mockClient, &url.URL{Scheme: "https", Host: "raw.githubusercontent.com", Path: "/"})
26-
tool, _ := GetFileContents(stubGetClientFn(mockClient), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper)
34+
mockGQLClient := githubv4.NewClient(nil)
35+
tool, _ := GetFileContents(stubGetClientFn(mockClient), stubGetRawClientFn(mockRawClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper)
2736
require.NoError(t, toolsnaps.Test(tool.Name, tool))
2837

2938
assert.Equal(t, "get_file_contents", tool.Name)
@@ -56,10 +65,10 @@ func Test_GetFileContents(t *testing.T) {
5665
HTMLURL: github.Ptr("https://github.com/owner/repo/tree/main/src"),
5766
},
5867
}
59-
6068
tests := []struct {
6169
name string
6270
mockedClient *http.Client
71+
mockFileSHA string
6372
requestArgs map[string]interface{}
6473
expectError bool
6574
expectedResult interface{}
@@ -98,6 +107,7 @@ func Test_GetFileContents(t *testing.T) {
98107
}),
99108
),
100109
),
110+
mockFileSHA: "abc123",
101111
requestArgs: map[string]interface{}{
102112
"owner": "owner",
103113
"repo": "repo",
@@ -143,6 +153,7 @@ func Test_GetFileContents(t *testing.T) {
143153
}),
144154
),
145155
),
156+
mockFileSHA: "def456",
146157
requestArgs: map[string]interface{}{
147158
"owner": "owner",
148159
"repo": "repo",
@@ -188,6 +199,7 @@ func Test_GetFileContents(t *testing.T) {
188199
),
189200
),
190201
),
202+
mockFileSHA: "", // Directory content doesn't need SHA
191203
requestArgs: map[string]interface{}{
192204
"owner": "owner",
193205
"repo": "repo",
@@ -221,6 +233,7 @@ func Test_GetFileContents(t *testing.T) {
221233
}),
222234
),
223235
),
236+
mockFileSHA: "", // Error case doesn't need SHA
224237
requestArgs: map[string]interface{}{
225238
"owner": "owner",
226239
"repo": "repo",
@@ -234,10 +247,20 @@ func Test_GetFileContents(t *testing.T) {
234247

235248
for _, tc := range tests {
236249
t.Run(tc.name, func(t *testing.T) {
250+
// Mock the getFileSHA function if mockFileSHA is provided
251+
if tc.mockFileSHA != "" {
252+
originalGetFileSHA := getFileSHAFunc
253+
getFileSHAFunc = mockGetFileSHA(tc.mockFileSHA)
254+
defer func() {
255+
getFileSHAFunc = originalGetFileSHA
256+
}()
257+
}
258+
237259
// Setup client with mock
238260
client := github.NewClient(tc.mockedClient)
239261
mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"})
240-
_, handler := GetFileContents(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper)
262+
mockGQLClient := githubv4.NewClient(tc.mockedClient)
263+
_, handler := GetFileContents(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper)
241264

242265
// Create call request
243266
request := createMCPRequest(tc.requestArgs)

pkg/github/tools.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG
2424
repos := toolsets.NewToolset("repos", "GitHub Repository related tools").
2525
AddReadTools(
2626
toolsets.NewServerTool(SearchRepositories(getClient, t)),
27-
toolsets.NewServerTool(GetFileContents(getClient, getRawClient, t)),
27+
toolsets.NewServerTool(GetFileContents(getClient, getRawClient, getGQLClient, t)),
2828
toolsets.NewServerTool(ListCommits(getClient, t)),
2929
toolsets.NewServerTool(SearchCode(getClient, t)),
3030
toolsets.NewServerTool(GetCommit(getClient, t)),

0 commit comments

Comments
 (0)