Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions src/transports/graphql/graphql_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,52 +271,46 @@ func (t *GraphQLClientTransport) RegisterToolProvider(
return nil, fmt.Errorf("introspection failed: %w", err)
}

// Build tool list
// Build tool list with optional filtering by operation type/name
var toolsList []Tool

// Register query fields
for _, f := range resp.Schema.QueryType.Fields {
opType := strings.ToLower(prov.OperationType)

// Helper to register a field if it matches the optional OperationName
addTool := func(fieldName string, descPtr *string) {
if prov.OperationName != nil && *prov.OperationName != fieldName {
return
}
desc := ""
if f.Description != nil {
desc = *f.Description
if descPtr != nil {
desc = *descPtr
}
toolsList = append(toolsList, Tool{
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
Name: fmt.Sprintf("%s.%s", prov.Name, fieldName),
Description: desc,
Inputs: ToolInputOutputSchema{Required: nil},
Provider: prov,
})
}

// Register query fields
if opType == "" || opType == "query" {
for _, f := range resp.Schema.QueryType.Fields {
addTool(f.Name, f.Description)
}
}

// Register mutation fields
if resp.Schema.MutationType != nil {
if (opType == "" || opType == "mutation") && resp.Schema.MutationType != nil {
for _, f := range resp.Schema.MutationType.Fields {
desc := ""
if f.Description != nil {
desc = *f.Description
}
toolsList = append(toolsList, Tool{
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
Description: desc,
Inputs: ToolInputOutputSchema{Required: nil},
Provider: prov,
})
addTool(f.Name, f.Description)
}
}

// Register subscription fields
if resp.Schema.SubscriptionType != nil {
if (opType == "" || opType == "subscription") && resp.Schema.SubscriptionType != nil {
for _, f := range resp.Schema.SubscriptionType.Fields {
desc := ""
if f.Description != nil {
desc = *f.Description
}
toolsList = append(toolsList, Tool{
Name: fmt.Sprintf("%s.%s", prov.Name, f.Name),
Description: desc,
Inputs: ToolInputOutputSchema{Required: nil},
Provider: prov,
})
addTool(f.Name, f.Description)
}
}
return toolsList, nil
Expand Down
56 changes: 56 additions & 0 deletions src/transports/graphql/graphql_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,59 @@ func TestGraphQLClientTransport_RegisterAndCall(t *testing.T) {
t.Fatalf("unexpected result: %#v", res)
}
}

// TestGraphQLClientTransport_RegisterToolFiltering ensures that tools are
// filtered by provider OperationType and OperationName to avoid duplicates.
func TestGraphQLClientTransport_RegisterToolFiltering(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req struct {
Query string `json:"query"`
}
json.NewDecoder(r.Body).Decode(&req)
if strings.Contains(req.Query, "__schema") {
// Return one query field and one subscription field
resp := map[string]any{"data": map[string]any{"__schema": map[string]any{
"queryType": map[string]any{"fields": []map[string]any{{"name": "echo"}, {"name": "ping"}}},
"subscriptionType": map[string]any{"fields": []map[string]any{{"name": "updates"}}},
}}}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
return
}
http.Error(w, "bad request", http.StatusBadRequest)
}))
defer server.Close()

tr := NewGraphQLClientTransport(nil)
ctx := context.Background()

// Query provider should only register query fields and respect OperationName
qName := "echo"
provQuery := &GraphQLProvider{
BaseProvider: BaseProvider{Name: "gql", ProviderType: ProviderGraphQL},
URL: server.URL,
OperationType: "query",
OperationName: &qName,
}
tools, err := tr.RegisterToolProvider(ctx, provQuery)
if err != nil {
t.Fatalf("register error: %v", err)
}
if len(tools) != 1 || tools[0].Name != "gql.echo" {
t.Fatalf("unexpected tools: %#v", tools)
}

// Subscription provider should only register subscription field
provSub := &GraphQLProvider{
BaseProvider: BaseProvider{Name: "gqlsub", ProviderType: ProviderGraphQL},
URL: server.URL,
OperationType: "subscription",
}
tools, err = tr.RegisterToolProvider(ctx, provSub)
if err != nil {
t.Fatalf("register error: %v", err)
}
if len(tools) != 1 || tools[0].Name != "gqlsub.updates" {
t.Fatalf("unexpected tools for subscription: %#v", tools)
}
}
Loading