From ddd408f2078811b6f68cfe4a964088b65d714b9f Mon Sep 17 00:00:00 2001 From: ShohamBit Date: Wed, 13 Nov 2024 12:49:14 +0000 Subject: [PATCH 1/2] support one flag for connection to tracee --- cmd/event.go | 2 ++ cmd/event_test.go | 10 +++++----- cmd/root.go | 27 +++++++++++++-------------- cmd/stream.go | 1 - pkg/client/client.go | 44 +++++++++++++++++++++++++++++++++++++++----- 5 files changed, 59 insertions(+), 25 deletions(-) diff --git a/cmd/event.go b/cmd/event.go index 23a5653..2f3f6ef 100644 --- a/cmd/event.go +++ b/cmd/event.go @@ -110,6 +110,7 @@ func listEvents(cmd *cobra.Command, args []string) { var traceeClient client.ServiceClient if err := traceeClient.NewServiceClient(serverInfo); err != nil { cmd.PrintErrln("Error creating client: ", err) + return } defer traceeClient.CloseConnection() response, err := traceeClient.GetEventDefinitions(context.Background(), &pb.GetEventDefinitionsRequest{EventNames: args}) @@ -134,6 +135,7 @@ func getEventDescriptions(cmd *cobra.Command, args []string) { var traceeClient client.ServiceClient if err := traceeClient.NewServiceClient(serverInfo); err != nil { cmd.PrintErrln("Error creating client: ", err) + return } defer traceeClient.CloseConnection() response, err := traceeClient.GetEventDefinitions(context.Background(), &pb.GetEventDefinitionsRequest{EventNames: args}) diff --git a/cmd/event_test.go b/cmd/event_test.go index 75e9763..7e08774 100644 --- a/cmd/event_test.go +++ b/cmd/event_test.go @@ -29,20 +29,20 @@ func TestEvent(t *testing.T) { { TestName: "No events describe", OutputSlice: []string{"event", "describe", "--format", "json"}, - ExpectedPrinter: "", + ExpectedPrinter: nil, ExpectedError: fmt.Errorf("accepts 1 arg(s), received 0"), }, { - TestName: "event describe event", + TestName: "describe ", OutputSlice: []string{"event", "describe", "event_test1", "--format", "json"}, ExpectedPrinter: "event_test1", ExpectedError: nil, }, //event enable { - TestName: "No events enable", + TestName: "No events enable", OutputSlice: []string{"event", "enable"}, - ExpectedPrinter: "", + ExpectedPrinter: nil, ExpectedError: fmt.Errorf("accepts 1 arg(s), received 0"), // Update expected output }, @@ -56,7 +56,7 @@ func TestEvent(t *testing.T) { { TestName: "No disable events", OutputSlice: []string{"event", "disable"}, - ExpectedPrinter: "", + ExpectedPrinter: nil, ExpectedError: fmt.Errorf("accepts 1 arg(s), received 0"), // Update expected output }, { diff --git a/cmd/root.go b/cmd/root.go index c41f86e..158a593 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "fmt" "os" "github.com/ShohamBit/traceectl/pkg/client" @@ -10,12 +11,13 @@ import ( "github.com/spf13/cobra" ) -// var outputFlag string +var formatFlag string +var outputFlag string +var serverFlag string var ( serverInfo client.ServerInfo = client.ServerInfo{ ConnectionType: client.PROTOCOL_UNIX, - UnixSocketPath: client.SOCKET, - ADDR: client.DefaultIP + ":" + client.DefaultPort, + ADDR: client.SOCKET, } rootCmd = &cobra.Command{ @@ -45,17 +47,13 @@ func init() { rootCmd.AddCommand(configCmd) rootCmd.AddCommand(versionCmd) - //flags - rootCmd.PersistentFlags().StringVarP(&serverInfo.ConnectionType, "connectionType", "c", client.PROTOCOL_UNIX, "Connection type (unix|tcp)") - rootCmd.RegisterFlagCompletionFunc("connectionType", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return []string{client.PROTOCOL_TCP, client.PROTOCOL_UNIX}, cobra.ShellCompDirectiveNoFileComp - }) - //TODO: add an option to ony use this flag par connection type - //unix connection type flag - rootCmd.PersistentFlags().StringVar(&serverInfo.UnixSocketPath, "socketPath", client.SOCKET, "Path of the unix socket") - //tcp connection type flag - rootCmd.PersistentFlags().StringVarP(&serverInfo.ADDR, "server", "s", client.DefaultIP+":"+client.DefaultPort, "The address and port of the Kubernetes API server") - //rootCmd.PersistentFlags().StringVarP(&outputFlag, "output", "o", "", "Specify the output file path (default is stdout)") //if empty stdout + //one global flag for server connection(connection Type: tcp or unix socket) + //no default for tcp, only for unix socket + //for tcp + //for unix socket + rootCmd.PersistentFlags().StringVar(&serverInfo.ADDR, "server", fmt.Sprintf("%s", client.SOCKET), `Server connection path or address. + for unix socket (default: /tmp/tracee.sock) + for tcp `) } @@ -154,6 +152,7 @@ func displayVersion(cmd *cobra.Command, _ []string) { var traceeClient client.ServiceClient if err := traceeClient.NewServiceClient(serverInfo); err != nil { cmd.PrintErrln("Error creating client: ", err) + return } defer traceeClient.CloseConnection() //get version diff --git a/cmd/stream.go b/cmd/stream.go index 1aea77b..e4dd087 100644 --- a/cmd/stream.go +++ b/cmd/stream.go @@ -127,7 +127,6 @@ func stream(cmd *cobra.Command, args []string) { err := traceeClient.NewServiceClient(serverInfo) if err != nil { cmd.PrintErrln("Error creating client: ", err) - traceeClient.CloseConnection() return } defer traceeClient.CloseConnection() diff --git a/pkg/client/client.go b/pkg/client/client.go index c11669b..afb0c71 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -3,6 +3,8 @@ package client import ( "fmt" "log" + "net" + "strings" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -14,14 +16,11 @@ const ( PROTOCOL_UNIX = "unix" PROTOCOL_TCP = "tcp" SOCKET = "/tmp/tracee.sock" - DefaultIP = "localhost" - DefaultPort = "4466" ) type ServerInfo struct { ConnectionType string // Field to specify connection type (e.g., "unix" or "tcp") - UnixSocketPath string // Path for the Unix socket, if using Unix connection - ADDR string // Address for the connection + ADDR string // Address for the connection // Path for the Unix socket, if using Unix connection or IP and port for tcp } // this function use grpc to connect the server @@ -33,10 +32,14 @@ func connectToServer(serverInfo ServerInfo) (*grpc.ClientConn, error) { // Use switch case to determine connection type var conn *grpc.ClientConn var err error + err = determineConnectionType(serverInfo) + if err != nil { + return nil, err + } switch serverInfo.ConnectionType { case PROTOCOL_UNIX: // Dial a Unix socket - address := fmt.Sprintf("unix://%s", serverInfo.UnixSocketPath) + address := fmt.Sprintf("unix://%s", serverInfo.ADDR) conn, err = grpc.NewClient(address, opts...) if err != nil { @@ -62,3 +65,34 @@ func connectToServer(serverInfo ServerInfo) (*grpc.ClientConn, error) { } return conn, nil } + +func determineConnectionType(serverInfo ServerInfo) error { + if strings.Contains(serverInfo.ADDR, ":") && isValidTCPAddress(serverInfo.ADDR) { + // It's a TCP address + serverInfo.ConnectionType = PROTOCOL_TCP + return nil + } + if strings.HasPrefix(serverInfo.ADDR, "/") { + // It's a Unix socket path + serverInfo.ConnectionType = PROTOCOL_UNIX + return nil + } + + return fmt.Errorf("unsupported connection type: %s", serverInfo.ADDR) + +} + +// isValidTCPAddress checks if the address is a valid IP:PORT format +func isValidTCPAddress(addr string) bool { + host, port, err := net.SplitHostPort(addr) + if err != nil || host == "" || port == "" { + return false + } + + // Validate port number + if _, err := net.LookupPort("tcp", port); err != nil { + return false + } + + return true +} From 2432c6e8f7f6b63ef1d6d5dafc7a9bae8dcb5182 Mon Sep 17 00:00:00 2001 From: ShohamBit Date: Wed, 13 Nov 2024 12:50:15 +0000 Subject: [PATCH 2/2] support tests with new flag and new connection configuration --- pkg/cmd/formatter/formatter.go | 5 +++-- pkg/mock/server.go | 13 ++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/cmd/formatter/formatter.go b/pkg/cmd/formatter/formatter.go index 0bd0094..26d61c5 100644 --- a/pkg/cmd/formatter/formatter.go +++ b/pkg/cmd/formatter/formatter.go @@ -76,8 +76,9 @@ func initOutput(cmd *cobra.Command, output string) error { } } else { // If no file is specified, use stdout - cmd.SetOut(os.Stdout) - cmd.SetErr(os.Stderr) + //NOTE: those commands brakes test do nothing + //cmd.SetOut(os.Stdout) + //cmd.SetErr(os.Stderr) } return nil } diff --git a/pkg/mock/server.go b/pkg/mock/server.go index 2014d0c..25d901f 100644 --- a/pkg/mock/server.go +++ b/pkg/mock/server.go @@ -14,8 +14,7 @@ import ( var ( ExpectedVersion string = "v0.22.0-15-gd09d7fca0d" // Match the output format serverInfo client.ServerInfo = client.ServerInfo{ - ADDR: client.DefaultIP + ":" + client.DefaultPort, - UnixSocketPath: client.SOCKET, + ADDR: client.SOCKET, } ) @@ -32,14 +31,14 @@ type MockDiagnosticServer struct { // CreateMockServer initializes the gRPC server and binds it to a Unix socket listener func CreateMockServer() (*grpc.Server, net.Listener, error) { // Check for existing Unix socket and remove it if necessary - if _, err := os.Stat(serverInfo.UnixSocketPath); err == nil { - if err := os.Remove(serverInfo.UnixSocketPath); err != nil { - return nil, nil, fmt.Errorf("failed to cleanup gRPC listening address (%s): %v", serverInfo.UnixSocketPath, err) + if _, err := os.Stat(serverInfo.ADDR); err == nil { + if err := os.Remove(serverInfo.ADDR); err != nil { + return nil, nil, fmt.Errorf("failed to cleanup gRPC listening address (%s): %v", serverInfo.ADDR, err) } } // Create the Unix socket listener - listener, err := net.Listen("unix", serverInfo.UnixSocketPath) + listener, err := net.Listen("unix", serverInfo.ADDR) if err != nil { return nil, nil, fmt.Errorf("failed to create Unix socket listener: %v", err) } @@ -74,7 +73,7 @@ func StartMockServer() (*grpc.Server, error) { // StopMockServer stops the server and removes the Unix socket func StopMockServer(server *grpc.Server) { server.GracefulStop() - if err := os.Remove(serverInfo.UnixSocketPath); err != nil { + if err := os.Remove(serverInfo.ADDR); err != nil { fmt.Printf("failed to remove Unix socket: %v\n", err) } }