diff --git a/src/transports/graphql/graphql_transport.go b/src/transports/graphql/graphql_transport.go index 703bb3e..3e2ef07 100644 --- a/src/transports/graphql/graphql_transport.go +++ b/src/transports/graphql/graphql_transport.go @@ -271,52 +271,46 @@ func (t *GraphQLClientTransport) RegisterToolProvider( return nil, fmt.Errorf("introspection failed: %w", err) } - // Build tool list + // Build tool list with optional filtering by operation type/name var toolsList []Tool - // Register query fields - for _, f := range resp.Schema.QueryType.Fields { + opType := strings.ToLower(prov.OperationType) + + // Helper to register a field if it matches the optional OperationName + addTool := func(fieldName string, descPtr *string) { + if prov.OperationName != nil && *prov.OperationName != fieldName { + return + } desc := "" - if f.Description != nil { - desc = *f.Description + if descPtr != nil { + desc = *descPtr } toolsList = append(toolsList, Tool{ - Name: fmt.Sprintf("%s.%s", prov.Name, f.Name), + Name: fmt.Sprintf("%s.%s", prov.Name, fieldName), Description: desc, Inputs: ToolInputOutputSchema{Required: nil}, Provider: prov, }) } + // Register query fields + if opType == "" || opType == "query" { + for _, f := range resp.Schema.QueryType.Fields { + addTool(f.Name, f.Description) + } + } + // Register mutation fields - if resp.Schema.MutationType != nil { + if (opType == "" || opType == "mutation") && resp.Schema.MutationType != nil { for _, f := range resp.Schema.MutationType.Fields { - desc := "" - if f.Description != nil { - desc = *f.Description - } - toolsList = append(toolsList, Tool{ - Name: fmt.Sprintf("%s.%s", prov.Name, f.Name), - Description: desc, - Inputs: ToolInputOutputSchema{Required: nil}, - Provider: prov, - }) + addTool(f.Name, f.Description) } } // Register subscription fields - if resp.Schema.SubscriptionType != nil { + if (opType == "" || opType == "subscription") && resp.Schema.SubscriptionType != nil { for _, f := range resp.Schema.SubscriptionType.Fields { - desc := "" - if f.Description != nil { - desc = *f.Description - } - toolsList = append(toolsList, Tool{ - Name: fmt.Sprintf("%s.%s", prov.Name, f.Name), - Description: desc, - Inputs: ToolInputOutputSchema{Required: nil}, - Provider: prov, - }) + addTool(f.Name, f.Description) } } return toolsList, nil diff --git a/src/transports/graphql/graphql_transport_test.go b/src/transports/graphql/graphql_transport_test.go index 70e3f63..dfdb922 100644 --- a/src/transports/graphql/graphql_transport_test.go +++ b/src/transports/graphql/graphql_transport_test.go @@ -60,3 +60,59 @@ func TestGraphQLClientTransport_RegisterAndCall(t *testing.T) { t.Fatalf("unexpected result: %#v", res) } } + +// TestGraphQLClientTransport_RegisterToolFiltering ensures that tools are +// filtered by provider OperationType and OperationName to avoid duplicates. +func TestGraphQLClientTransport_RegisterToolFiltering(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req struct { + Query string `json:"query"` + } + json.NewDecoder(r.Body).Decode(&req) + if strings.Contains(req.Query, "__schema") { + // Return one query field and one subscription field + resp := map[string]any{"data": map[string]any{"__schema": map[string]any{ + "queryType": map[string]any{"fields": []map[string]any{{"name": "echo"}, {"name": "ping"}}}, + "subscriptionType": map[string]any{"fields": []map[string]any{{"name": "updates"}}}, + }}} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + tr := NewGraphQLClientTransport(nil) + ctx := context.Background() + + // Query provider should only register query fields and respect OperationName + qName := "echo" + provQuery := &GraphQLProvider{ + BaseProvider: BaseProvider{Name: "gql", ProviderType: ProviderGraphQL}, + URL: server.URL, + OperationType: "query", + OperationName: &qName, + } + tools, err := tr.RegisterToolProvider(ctx, provQuery) + if err != nil { + t.Fatalf("register error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "gql.echo" { + t.Fatalf("unexpected tools: %#v", tools) + } + + // Subscription provider should only register subscription field + provSub := &GraphQLProvider{ + BaseProvider: BaseProvider{Name: "gqlsub", ProviderType: ProviderGraphQL}, + URL: server.URL, + OperationType: "subscription", + } + tools, err = tr.RegisterToolProvider(ctx, provSub) + if err != nil { + t.Fatalf("register error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "gqlsub.updates" { + t.Fatalf("unexpected tools for subscription: %#v", tools) + } +}