diff --git a/examples/grpc_gnmi_client/main.go b/examples/grpc_gnmi_client/main.go index 39550b1..2de0162 100644 --- a/examples/grpc_gnmi_client/main.go +++ b/examples/grpc_gnmi_client/main.go @@ -2,10 +2,13 @@ package main import ( "context" + "encoding/base64" "encoding/json" "fmt" "log" "net" + "path/filepath" + "strings" "time" gnmi "github.com/openconfig/gnmi/proto/gnmi" @@ -14,6 +17,7 @@ import ( "github.com/universal-tool-calling-protocol/go-utcp/src/repository" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type UnifiedServer struct { @@ -21,6 +25,49 @@ type UnifiedServer struct { gnmi.UnimplementedGNMIServer } +const ( + user = "alice" + pass = "secret" +) + +func authFromContext(ctx context.Context) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return fmt.Errorf("missing metadata") + } + vals := md.Get("authorization") + if len(vals) == 0 { + return fmt.Errorf("unauthorized") + } + parts := strings.SplitN(vals[0], " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") { + return fmt.Errorf("unauthorized") + } + decoded, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return fmt.Errorf("unauthorized") + } + up := strings.SplitN(string(decoded), ":", 2) + if len(up) != 2 || up[0] != user || up[1] != pass { + return fmt.Errorf("unauthorized") + } + return nil +} + +func unaryAuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := authFromContext(ctx); err != nil { + return nil, err + } + return handler(ctx, req) +} + +func streamAuthInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := authFromContext(ss.Context()); err != nil { + return err + } + return handler(srv, ss) +} + func (s *UnifiedServer) Capabilities(ctx context.Context, req *gnmi.CapabilityRequest) (*gnmi.CapabilityResponse, error) { return &gnmi.CapabilityResponse{}, nil } @@ -65,10 +112,12 @@ func (s *UnifiedServer) CallToolStream(req *grpcpb.ToolCallRequest, stream grpcp // Create a mock update response update := map[string]interface{}{ - "timestamp": time.Now().UnixNano(), - "path": args["path"], - "value": fmt.Sprintf("mock_value_%d", counter), - "mode": args["mode"], + "timestamp": time.Now().UnixNano(), + "path": args["path"], + "value": fmt.Sprintf("mock_value_%d", counter), + "mode": args["mode"], + "sub_mode": args["sub_mode"], + "sample_interval_ns": args["sample_interval_ns"], } updateJson, err := json.Marshal(update) @@ -121,8 +170,11 @@ func (s *UnifiedServer) Subscribe(stream gnmi.GNMI_SubscribeServer) error { Update: &gnmi.Notification{ Timestamp: time.Now().UnixNano(), Update: []*gnmi.Update{{ - Path: &gnmi.Path{Element: []string{"interfaces", "interface", "eth0"}}, - Val: &gnmi.TypedValue{Value: &gnmi.TypedValue_StringVal{StringVal: state}}, + Path: &gnmi.Path{Elem: []*gnmi.PathElem{ + {Name: "interfaces"}, + {Name: "interface", Key: map[string]string{"name": "eth0"}}, + }}, + Val: &gnmi.TypedValue{Value: &gnmi.TypedValue_StringVal{StringVal: state}}, }}, }, }, @@ -203,7 +255,10 @@ func startGNMIServer(addr string) *grpc.Server { if err != nil { log.Fatalf("listen: %v", err) } - srv := grpc.NewServer() + srv := grpc.NewServer( + grpc.UnaryInterceptor(unaryAuthInterceptor), + grpc.StreamInterceptor(streamAuthInterceptor), + ) gnmi.RegisterGNMIServer(srv, &UnifiedServer{}) grpcpb.RegisterUTCPServiceServer(srv, &UnifiedServer{}) go srv.Serve(lis) @@ -217,7 +272,7 @@ func main() { ctx := context.Background() repo := repository.NewInMemoryToolRepository() - cfg := &utcp.UtcpClientConfig{ProvidersFilePath: "provider.json"} + cfg := &utcp.UtcpClientConfig{ProvidersFilePath: filepath.Join("examples", "grpc_gnmi_client", "provider.json")} client, err := utcp.NewUTCPClient(ctx, cfg, repo, nil) if err != nil { log.Fatalf("client error: %v", err) @@ -229,8 +284,10 @@ func main() { } stream, err := client.CallToolStream(ctx, "gnmi.gnmi_subscribe", map[string]any{ - "path": "/interfaces/interface/eth0", - "mode": "STREAM", + "path": "/interfaces/interface[name=eth0]", + "mode": "STREAM", + "sub_mode": "SAMPLE", + "sample_interval_ns": 500000000, }) if err != nil { log.Fatalf("call stream: %v", err) diff --git a/examples/grpc_gnmi_client/provider.json b/examples/grpc_gnmi_client/provider.json index 7636072..17b2206 100644 --- a/examples/grpc_gnmi_client/provider.json +++ b/examples/grpc_gnmi_client/provider.json @@ -6,7 +6,12 @@ "host": "127.0.0.1", "port": 9339, "service_name": "gnmi.gNMI", - "method_name": "Subscribe" + "method_name": "Subscribe", + "auth": { + "auth_type": "basic", + "username": "alice", + "password": "secret" + } } ] } diff --git a/examples/grpc_gnmi_transport/main.go b/examples/grpc_gnmi_transport/main.go index 872b2ff..a0dad2d 100644 --- a/examples/grpc_gnmi_transport/main.go +++ b/examples/grpc_gnmi_transport/main.go @@ -2,18 +2,22 @@ package main import ( "context" + "encoding/base64" "encoding/json" "fmt" "log" "net" + "strings" "time" gnmi "github.com/openconfig/gnmi/proto/gnmi" + auth "github.com/universal-tool-calling-protocol/go-utcp/src/auth" "github.com/universal-tool-calling-protocol/go-utcp/src/grpcpb" . "github.com/universal-tool-calling-protocol/go-utcp/src/providers/base" providers "github.com/universal-tool-calling-protocol/go-utcp/src/providers/grpc" transports "github.com/universal-tool-calling-protocol/go-utcp/src/transports/grpc" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type UnifiedServer struct { @@ -21,6 +25,49 @@ type UnifiedServer struct { grpcpb.UnimplementedUTCPServiceServer } +const ( + user = "alice" + pass = "secret" +) + +func authFromContext(ctx context.Context) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return fmt.Errorf("missing metadata") + } + vals := md.Get("authorization") + if len(vals) == 0 { + return fmt.Errorf("unauthorized") + } + parts := strings.SplitN(vals[0], " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") { + return fmt.Errorf("unauthorized") + } + decoded, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return fmt.Errorf("unauthorized") + } + up := strings.SplitN(string(decoded), ":", 2) + if len(up) != 2 || up[0] != user || up[1] != pass { + return fmt.Errorf("unauthorized") + } + return nil +} + +func unaryAuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := authFromContext(ctx); err != nil { + return nil, err + } + return handler(ctx, req) +} + +func streamAuthInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := authFromContext(ss.Context()); err != nil { + return err + } + return handler(srv, ss) +} + func (s *UnifiedServer) CallTool(ctx context.Context, req *grpcpb.ToolCallRequest) (*grpcpb.ToolCallResponse, error) { // Simple implementation - could be expanded based on tool name return &grpcpb.ToolCallResponse{ @@ -52,10 +99,12 @@ func (s *UnifiedServer) CallToolStream(req *grpcpb.ToolCallRequest, stream grpcp // Create a mock update response update := map[string]interface{}{ - "timestamp": time.Now().UnixNano(), - "path": args["path"], - "value": fmt.Sprintf("mock_value_%d", counter), - "mode": args["mode"], + "timestamp": time.Now().UnixNano(), + "path": args["path"], + "value": fmt.Sprintf("mock_value_%d", counter), + "mode": args["mode"], + "sub_mode": args["sub_mode"], + "sample_interval_ns": args["sample_interval_ns"], } updateJson, err := json.Marshal(update) @@ -97,14 +146,19 @@ func (s *UnifiedServer) GetManual(ctx context.Context, e *grpcpb.Empty) (*grpcpb } func (s *UnifiedServer) Subscribe(stream gnmi.GNMI_SubscribeServer) error { + ctx := stream.Context() + if _, err := stream.Recv(); err != nil { return err } resp := &gnmi.SubscribeResponse{ Response: &gnmi.SubscribeResponse_Update{ Update: &gnmi.Notification{Update: []*gnmi.Update{{ - Path: &gnmi.Path{Element: []string{"interfaces", "interface", "eth0"}}, - Val: &gnmi.TypedValue{Value: &gnmi.TypedValue_StringVal{StringVal: "UP"}}, + Path: &gnmi.Path{Elem: []*gnmi.PathElem{ + {Name: "interfaces"}, + {Name: "interface", Key: map[string]string{"name": "eth0"}}, + }}, + Val: &gnmi.TypedValue{Value: &gnmi.TypedValue_StringVal{StringVal: "UP"}}, }}}, }, } @@ -116,7 +170,10 @@ func startGNMIServer(addr string) *grpc.Server { if err != nil { log.Fatalf("listen: %v", err) } - srv := grpc.NewServer() + srv := grpc.NewServer( + grpc.UnaryInterceptor(unaryAuthInterceptor), + grpc.StreamInterceptor(streamAuthInterceptor), + ) gnmi.RegisterGNMIServer(srv, &UnifiedServer{}) grpcpb.RegisterUTCPServiceServer(srv, &UnifiedServer{}) go srv.Serve(lis) @@ -130,14 +187,15 @@ func main() { logger := func(format string, args ...interface{}) { log.Printf(format, args...) } tr := transports.NewGRPCClientTransport(logger) - prov := &providers.GRPCProvider{BaseProvider: BaseProvider{Name: "g", ProviderType: ProviderGRPC}, Host: "127.0.0.1", Port: 9339, ServiceName: "gnmi.gNMI", MethodName: "Subscribe"} + var a auth.Auth = auth.NewBasicAuth(user, pass) + prov := &providers.GRPCProvider{BaseProvider: BaseProvider{Name: "g", ProviderType: ProviderGRPC}, Host: "127.0.0.1", Port: 9339, ServiceName: "gnmi.gNMI", MethodName: "Subscribe", Auth: &a} ctx := context.Background() if _, err := tr.RegisterToolProvider(ctx, prov); err != nil { log.Fatalf("register: %v", err) } - stream, err := tr.CallToolStream(ctx, "gnmi_subscribe", map[string]any{"path": "/interfaces/interface/eth0", "mode": "STREAM"}, prov) + stream, err := tr.CallToolStream(ctx, "gnmi_subscribe", map[string]any{"path": "/interfaces/interface[name=eth0]", "mode": "STREAM", "sub_mode": "SAMPLE", "sample_interval_ns": 500000000}, prov) if err != nil { log.Fatalf("call stream: %v", err) } diff --git a/src/transports/grpc/grpc_transport.go b/src/transports/grpc/grpc_transport.go index 74be592..1e25cb6 100644 --- a/src/transports/grpc/grpc_transport.go +++ b/src/transports/grpc/grpc_transport.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "encoding/base64" "errors" "fmt" "io" @@ -16,6 +17,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" + auth "github.com/universal-tool-calling-protocol/go-utcp/src/auth" "github.com/universal-tool-calling-protocol/go-utcp/src/grpcpb" . "github.com/universal-tool-calling-protocol/go-utcp/src/providers/base" . "github.com/universal-tool-calling-protocol/go-utcp/src/providers/grpc" @@ -24,6 +26,18 @@ import ( . "github.com/universal-tool-calling-protocol/go-utcp/src/tools" ) +// addAuthToContext adds authentication metadata to the context if required +func (t *GRPCClientTransport) addAuthToContext(ctx context.Context, prov *GRPCProvider) context.Context { + if prov.Auth != nil { + switch a := (*prov.Auth).(type) { + case *auth.BasicAuth: + token := base64.StdEncoding.EncodeToString([]byte(a.Username + ":" + a.Password)) + ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Basic "+token) + } + } + return ctx +} + // GRPCClientTransport implements ClientTransport over gRPC using the UTCPService. // It expects the remote server to implement the grpcpb.UTCPService service. type GRPCClientTransport struct { @@ -66,7 +80,9 @@ func (t *GRPCClientTransport) dial(ctx context.Context, prov *GRPCProvider) (*gr return nil, errors.New("SSL not implemented") } else { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + t.logger("Using insecure gRPC transport for %s, suitable only for non-production", addr) } + return grpc.DialContext(ctx, addr, opts...) } @@ -77,8 +93,9 @@ func (t *GRPCClientTransport) RegisterToolProvider(ctx context.Context, prov Pro return nil, errors.New("GRPCClientTransport can only be used with GRPCProvider") } - // Add target to context if specified + // Add target and auth metadata to context if specified ctx = t.addTargetToContext(ctx, gp) + ctx = t.addAuthToContext(ctx, gp) conn, err := t.dial(ctx, gp) if err != nil { @@ -113,8 +130,9 @@ func (t *GRPCClientTransport) CallTool(ctx context.Context, toolName string, arg return nil, errors.New("GRPCClientTransport can only be used with GRPCProvider") } - // Add target to context if specified + // Add target and auth metadata to context if specified ctx = t.addTargetToContext(ctx, gp) + ctx = t.addAuthToContext(ctx, gp) conn, err := t.dial(ctx, gp) if err != nil { @@ -157,9 +175,11 @@ func (t *GRPCClientTransport) CallToolStream( // Route to appropriate streaming implementation if gp.ServiceName == "gnmi.gNMI" && gp.MethodName == "Subscribe" { + ctx = t.addAuthToContext(ctx, gp) return t.callGNMISubscribe(ctx, args, gp) } + ctx = t.addAuthToContext(ctx, gp) return t.callUTCPToolStream(ctx, toolName, args, gp) } @@ -216,31 +236,72 @@ func (t *GRPCClientTransport) callGNMISubscribe( // buildSubscribeRequest constructs a gNMI SubscribeRequest from arguments func (t *GRPCClientTransport) buildSubscribeRequest(args map[string]any, gp *GRPCProvider) (*gnmi.SubscribeRequest, error) { pathStr, _ := args["path"].(string) - modeStr, _ := args["mode"].(string) + listModeStr, _ := args["mode"].(string) // ONCE | POLL | STREAM - subMode := gnmi.SubscriptionList_STREAM - switch strings.ToUpper(modeStr) { + // List (outer) mode + listMode := gnmi.SubscriptionList_STREAM + switch strings.ToUpper(listModeStr) { case "ONCE": - subMode = gnmi.SubscriptionList_ONCE + listMode = gnmi.SubscriptionList_ONCE case "POLL": - subMode = gnmi.SubscriptionList_POLL + listMode = gnmi.SubscriptionList_POLL + } + + // Per-subscription mode + subMode := gnmi.SubscriptionMode_SAMPLE + if v, ok := args["sub_mode"].(string); ok { + switch strings.ToUpper(v) { + case "SAMPLE": + subMode = gnmi.SubscriptionMode_SAMPLE + case "ON_CHANGE": + subMode = gnmi.SubscriptionMode_ON_CHANGE + case "TARGET_DEFINED": + subMode = gnmi.SubscriptionMode_TARGET_DEFINED + } } + // Optional intervals / flags + toUint64 := func(x any) uint64 { + switch n := x.(type) { + case int: + return uint64(n) + case int64: + return uint64(n) + case float64: + return uint64(n) + case uint64: + return n + default: + return 0 + } + } + sampleInterval := toUint64(args["sample_interval_ns"]) + heartbeatInterval := toUint64(args["heartbeat_interval_ns"]) + suppressRedundant, _ := args["suppress_redundant"].(bool) + + // Path and subscription path := parseGNMIPath(pathStr) - subReq := &gnmi.SubscribeRequest{ + sub := &gnmi.Subscription{ + Path: path, + Mode: subMode, + SampleInterval: sampleInterval, + HeartbeatInterval: heartbeatInterval, + SuppressRedundant: suppressRedundant, + } + + req := &gnmi.SubscribeRequest{ Request: &gnmi.SubscribeRequest_Subscribe{ Subscribe: &gnmi.SubscriptionList{ - Mode: subMode, - Subscription: []*gnmi.Subscription{{Path: path}}, + Mode: listMode, + Subscription: []*gnmi.Subscription{sub}, }, }, } if gp.Target != "" { - subReq.GetSubscribe().Prefix = &gnmi.Path{Target: gp.Target} + req.GetSubscribe().Prefix = &gnmi.Path{Target: gp.Target} } - - return subReq, nil + return req, nil } // startPollingIfNeeded starts a polling goroutine for POLL mode subscriptions @@ -424,10 +485,45 @@ func (t *GRPCClientTransport) startUTCPReceiveLoop( // parseGNMIPath parses a path string into a gNMI Path func parseGNMIPath(p string) *gnmi.Path { - p = strings.TrimPrefix(p, "/") - if p == "" { + p = strings.TrimSpace(p) + if p == "" || p == "/" { return &gnmi.Path{} } - elems := strings.Split(p, "/") - return &gnmi.Path{Element: elems} + p = strings.TrimPrefix(p, "/") + segs := strings.Split(p, "/") + + elems := make([]*gnmi.PathElem, 0, len(segs)) + for _, seg := range segs { + seg = strings.TrimSpace(seg) + if seg == "" { + continue + } + name := seg + keys := map[string]string{} + + if i := strings.IndexRune(seg, '['); i >= 0 { + name = seg[:i] + rest := seg[i:] + for len(rest) > 0 { + if rest[0] != '[' { + break + } + end := strings.IndexRune(rest, ']') + if end <= 1 { + break + } + kv := rest[1:end] + rest = rest[end+1:] + if eq := strings.IndexRune(kv, '='); eq > 0 && eq < len(kv)-1 { + k := kv[:eq] + v := kv[eq+1:] + keys[k] = v + } + } + } + + elems = append(elems, &gnmi.PathElem{Name: name, Key: keys}) + } + + return &gnmi.Path{Elem: elems} } diff --git a/src/transports/grpc/grpc_transport_additional_test.go b/src/transports/grpc/grpc_transport_additional_test.go new file mode 100644 index 0000000..cb66f22 --- /dev/null +++ b/src/transports/grpc/grpc_transport_additional_test.go @@ -0,0 +1,45 @@ +package grpc + +import ( + "testing" + + gnmi "github.com/openconfig/gnmi/proto/gnmi" + provgrpc "github.com/universal-tool-calling-protocol/go-utcp/src/providers/grpc" +) + +func TestParseGNMIPath(t *testing.T) { + p := parseGNMIPath("/interfaces/interface[name=Ethernet2][subif=0]/state/oper-status") + if len(p.GetElem()) != 4 { + t.Fatalf("expected 4 elems, got %d", len(p.GetElem())) + } + if p.GetElem()[1].GetName() != "interface" { + t.Fatalf("unexpected second element name: %s", p.GetElem()[1].GetName()) + } + if p.GetElem()[1].GetKey()["name"] != "Ethernet2" || p.GetElem()[1].GetKey()["subif"] != "0" { + t.Fatalf("unexpected keys: %v", p.GetElem()[1].GetKey()) + } +} + +func TestBuildSubscribeRequest_SubMode(t *testing.T) { + tpt := NewGRPCClientTransport(nil) + gp := &provgrpc.GRPCProvider{} + args := map[string]any{ + "path": "/interfaces/interface[name=eth0]/state/oper-status", + "mode": "STREAM", + "sub_mode": "ON_CHANGE", + } + req, err := tpt.buildSubscribeRequest(args, gp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + sub := req.GetSubscribe().GetSubscription() + if len(sub) != 1 { + t.Fatalf("expected 1 subscription, got %d", len(sub)) + } + if sub[0].Mode != gnmi.SubscriptionMode_ON_CHANGE { + t.Fatalf("expected sub mode ON_CHANGE, got %v", sub[0].Mode) + } + if req.GetSubscribe().Mode != gnmi.SubscriptionList_STREAM { + t.Fatalf("expected list mode STREAM, got %v", req.GetSubscribe().Mode) + } +}