From c2443e84cbbd2b76c81da4c4cd57fcd6de9675ea Mon Sep 17 00:00:00 2001 From: MahdiOR Date: Sat, 21 Jun 2025 11:16:53 +0330 Subject: [PATCH] fix import in protobuf --- wrap/grpc.go | 178 ++++++++++++++++++++++++++++++++++++++++++----- wrap/template.go | 29 ++++---- 2 files changed, 177 insertions(+), 30 deletions(-) diff --git a/wrap/grpc.go b/wrap/grpc.go index 7db055c..53ecc00 100644 --- a/wrap/grpc.go +++ b/wrap/grpc.go @@ -3,6 +3,7 @@ package wrap import ( "bytes" "errors" + "fmt" "os" "path" "strings" @@ -30,11 +31,69 @@ var ( ErrWritingFile = errors.New("error writing the generated code to the file") ) +var googleProtobufs = map[string]string{ + "google.protobuf.Any": "anypb.Any", + "google.protobuf.Api": "apipb.Api", + "google.protobuf.Method": "apipb.Method", + "google.protobuf.Mixin": "apipb.Mixin", + "google.protobuf.FileDescriptorSet": "descriptorpb.FileDescriptorSet", + "google.protobuf.FileDescriptorProto": "descriptorpb.FileDescriptorProto", + "google.protobuf.DescriptorProto": "descriptorpb.DescriptorProto", + "google.protobuf.ExtensionRangeOptions": "descriptorpb.ExtensionRangeOptions", + "google.protobuf.FieldDescriptorProto": "descriptorpb.FieldDescriptorProto", + "google.protobuf.OneofDescriptorProto": "descriptorpb.OneofDescriptorProto", + "google.protobuf.EnumDescriptorProto": "descriptorpb.EnumDescriptorProto", + "google.protobuf.EnumValueDescriptorProto": "descriptorpb.EnumValueDescriptorProto", + "google.protobuf.ServiceDescriptorProto": "descriptorpb.ServiceDescriptorProto", + "google.protobuf.MethodDescriptorProto": "descriptorpb.MethodDescriptorProto", + "google.protobuf.FileOptions": "descriptorpb.FileOptions", + "google.protobuf.MessageOptions": "descriptorpb.MessageOptions", + "google.protobuf.FieldOptions": "descriptorpb.FieldOptions", + "google.protobuf.OneofOptions": "descriptorpb.OneofOptions", + "google.protobuf.EnumOptions": "descriptorpb.EnumOptions", + "google.protobuf.EnumValueOptions": "descriptorpb.EnumValueOptions", + "google.protobuf.ServiceOptions": "descriptorpb.ServiceOptions", + "google.protobuf.MethodOptions": "descriptorpb.MethodOptions", + "google.protobuf.UninterpretedOption": "descriptorpb.UninterpretedOption", + "google.protobuf.FeatureSet": "descriptorpb.FeatureSet", + "google.protobuf.FeatureSetDefaults": "descriptorpb.FeatureSetDefaults", + "google.protobuf.SourceCodeInfo": "descriptorpb.SourceCodeInfo", + "google.protobuf.GeneratedCodeInfo": "descriptorpb.GeneratedCodeInfo", + "google.protobuf.SymbolVisibility": "descriptorpb.SymbolVisibility", + "google.protobuf.Duration": "durationpb.Duration", + "google.protobuf.Empty": "emptypb.Empty", + "google.protobuf.FieldMask": "fieldmaskpb.FieldMask", + "google.protobuf.GoFeatures": "gofeaturespb.GoFeatures", + "google.protobuf.SourceContext": "sourcecontextpb.SourceContext", + "google.protobuf.Struct": "structpb.Struct", + "google.protobuf.Value": "structpb.Value", + "google.protobuf.NullValue": "structpb.NullValue", + "google.protobuf.ListValue": "structpb.ListValue", + "google.protobuf.Timestamp": "timestamppb.Timestamp", + "google.protobuf.Type": "typepb.Type", + "google.protobuf.Field": "typepb.Field", + "google.protobuf.Enum": "typepb.Enum", + "google.protobuf.EnumValue": "typepb.EnumValue", + "google.protobuf.Option": "typepb.Option", + "google.protobuf.Syntax": "typepb.Syntax", + "google.protobuf.DoubleValue": "wrapperspb.DoubleValue", + "google.protobuf.FloatValue": "wrapperspb.FloatValue", + "google.protobuf.Int64Value": "wrapperspb.Int64Value", + "google.protobuf.UInt64Value": "wrapperspb.UInt64Value", + "google.protobuf.Int32Value": "wrapperspb.Int32Value", + "google.protobuf.UInt32Value": "wrapperspb.UInt32Value", + "google.protobuf.BoolValue": "wrapperspb.BoolValue", + "google.protobuf.StringValue": "wrapperspb.StringValue", + "google.protobuf.BytesValue": "wrapperspb.BytesValue", +} + // ServiceMethod represents a method in a proto service. type ServiceMethod struct { Name string Request string Response string + RawRequest string + RawResponse string StreamsRequest bool StreamsResponse bool } @@ -50,8 +109,14 @@ type WrapperData struct { Package string Service string Methods []ServiceMethod - Requests []string + Requests []ServiceRequest Source string + Imports []string +} + +type ServiceRequest struct { + Request string + RawRequest string } type FileType struct { @@ -97,6 +162,7 @@ func generateWrapper(ctx *gofr.Context, options ...FileType) (any, error) { return nil, err } + imports := getImports(ctx, definition, protoPath) projectPath, packageName := getPackageAndProject(ctx, definition, protoPath) services := getServices(ctx, definition) requests := getRequests(ctx, services) @@ -108,6 +174,7 @@ func generateWrapper(ctx *gofr.Context, options ...FileType) (any, error) { Methods: service.Methods, Requests: uniqueRequestTypes(ctx, service.Methods), Source: path.Base(protoPath), + Imports: imports, } if err := generateFiles(ctx, projectPath, service.Name, &wrapperData, requests, options...); err != nil { @@ -128,7 +195,6 @@ func parseProtoFile(ctx *gofr.Context, protoPath string) (*proto.Proto, error) { return nil, ErrOpeningProtoFile } defer file.Close() - parser := proto.NewParser(file) definition, err := parser.Parse() @@ -142,7 +208,7 @@ func parseProtoFile(ctx *gofr.Context, protoPath string) (*proto.Proto, error) { // generateFiles generates files for a given service. func generateFiles(ctx *gofr.Context, projectPath, serviceName string, wrapperData *WrapperData, - requests []string, options ...FileType) error { + requests []ServiceRequest, options ...FileType) error { for _, option := range options { if option.FileSuffix == serverRequestFile { wrapperData.Requests = requests @@ -181,41 +247,47 @@ func getOutputFilePath(projectPath, serviceName, fileSuffix string) string { } // getRequests extracts all unique request types from the services. -func getRequests(ctx *gofr.Context, services []ProtoService) []string { - requests := make(map[string]bool) +func getRequests(ctx *gofr.Context, services []ProtoService) []ServiceRequest { + requests := make(map[string]ServiceRequest) for _, service := range services { for _, method := range service.Methods { - requests[method.Request] = true + requests[method.Request] = ServiceRequest{ + Request: method.Request, + RawRequest: method.RawRequest, + } } } ctx.Logger.Debugf("Extracted unique request types: %v", requests) - return mapKeysToSlice(requests) + return mapValuesToSlice(requests) } // uniqueRequestTypes extracts unique request types from methods. -func uniqueRequestTypes(ctx *gofr.Context, methods []ServiceMethod) []string { - requests := make(map[string]bool) +func uniqueRequestTypes(ctx *gofr.Context, methods []ServiceMethod) []ServiceRequest { + requests := make(map[string]ServiceRequest) for _, method := range methods { - requests[method.Request] = true // Include all request types + requests[method.Request] = ServiceRequest{ + Request: method.Request, + RawRequest: method.RawRequest, + } // Include all request types } ctx.Logger.Debugf("Extracted unique request types for methods: %v", requests) - return mapKeysToSlice(requests) + return mapValuesToSlice(requests) } // mapKeysToSlice converts a map's keys to a slice. -func mapKeysToSlice(m map[string]bool) []string { - keys := make([]string, 0, len(m)) - for key := range m { - keys = append(keys, key) +func mapValuesToSlice(m map[string]ServiceRequest) []ServiceRequest { + values := make([]ServiceRequest, 0, len(m)) + for _, value := range m { + values = append(values, value) } - return keys + return values } // executeTemplate executes a template with the provided data. @@ -282,6 +354,52 @@ func getPackageAndProject(ctx *gofr.Context, definition *proto.Proto, protoPath return projectPath, packageName } +// getImports extracts the import directories from google protobufs and relative go_package proto definitions. +func getImports(ctx *gofr.Context, definition *proto.Proto, protoPath string) []string { + imports := []string{} + googleImports := map[string]string{ + "google/protobuf/any.proto": "anypb \"google.golang.org/protobuf/types/known/anypb\"", + "google/protobuf/api.proto": "apipb \"google.golang.org/protobuf/types/known/apipb\"", + "google/protobuf/descriptor.proto": "descriptorpb \"google.golang.org/protobuf/types/descriptorpb\"", + "google/protobuf/duration.proto": "durationpb \"google.golang.org/protobuf/types/known/durationpb\"", + "google/protobuf/empty.proto": "emptypb \"google.golang.org/protobuf/types/known/emptypb\"", + "google/protobuf/field_mask.proto": "fieldmaskpb \"google.golang.org/protobuf/types/known/fieldmaskpb\"", + "google/protobuf/go_features.proto": "gofeaturespb \"google.golang.org/protobuf/types/gofeaturespb\"", + "google/protobuf/source_context.proto": "sourcecontextpb \"google.golang.org/protobuf/types/known/sourcecontextpb\"", + "google/protobuf/struct.proto": "structpb \"google.golang.org/protobuf/types/known/structpb\"", + "google/protobuf/timestamp.proto": "timestamppb \"google.golang.org/protobuf/types/known/timestamppb\"", + "google/protobuf/type.proto": "typepb \"google.golang.org/protobuf/types/known/typepb\"", + "google/protobuf/wrappers.proto": "wrapperspb \"google.golang.org/protobuf/types/known/wrapperspb\"", + } + for _, elem := range definition.Elements { + if imported, ok := elem.(*proto.Import); ok { + if googleImport, ok := googleImports[imported.Filename]; ok { + imports = append(imports, googleImport) + } else { + lastIndex := strings.LastIndex(protoPath, "/") + newProto, err := parseProtoFile(ctx, protoPath[:lastIndex+1]+imported.Filename) + if err != nil { + ctx.Logger.Errorf("Failed to parse imported proto file %s: %v", imported.Filename, err) + continue + } + packageSource := "" + packageName := "" + for _, newElem := range newProto.Elements { + if goPackage, ok := newElem.(*proto.Option); ok && goPackage.Name == "go_package" { + packageSource = goPackage.Constant.Source + } + if goPackage, ok := newElem.(*proto.Package); ok { + packageName = goPackage.Name + } + } + lastPiece := strings.LastIndex(packageName, ".") + imports = append(imports, fmt.Sprintf("%s \"%s\"", packageName[lastPiece+1:], packageSource)) + } + } + } + return imports +} + // getServices extracts services from the proto definition. func getServices(ctx *gofr.Context, definition *proto.Proto) []ProtoService { var services []ProtoService @@ -294,8 +412,10 @@ func getServices(ctx *gofr.Context, definition *proto.Proto) []ProtoService { if rpc, ok := element.(*proto.RPC); ok { service.Methods = append(service.Methods, ServiceMethod{ Name: rpc.Name, - Request: rpc.RequestType, - Response: rpc.ReturnsType, + Request: getProperType(rpc.RequestType), + Response: getProperType(rpc.ReturnsType), + RawRequest: getRawType(rpc.RequestType), + RawResponse: getRawType(rpc.ReturnsType), StreamsRequest: rpc.StreamsRequest, StreamsResponse: rpc.StreamsReturns, }) @@ -310,3 +430,25 @@ func getServices(ctx *gofr.Context, definition *proto.Proto) []ProtoService { return services } + +func getProperType(tpe string) string { + if strings.HasPrefix(tpe, "google.protobuf.") { + if protobuf, ok := googleProtobufs[tpe]; ok { + return protobuf + } else { + return tpe + } + } else if strings.Contains(tpe, ".") { + lastIndex := strings.LastIndex(tpe, ".") + submoduleIndex := strings.LastIndex(tpe[:lastIndex], ".") + if submoduleIndex != -1 { + return fmt.Sprintf("%s.%s", tpe[submoduleIndex+1:lastIndex], tpe[lastIndex+1:]) + } + } + return tpe +} + +func getRawType(tpe string) string { + lastIndex := strings.LastIndex(tpe, ".") + return tpe[lastIndex+1:] +} diff --git a/wrap/template.go b/wrap/template.go index 66ff672..169c698 100644 --- a/wrap/template.go +++ b/wrap/template.go @@ -18,11 +18,9 @@ package {{ .Package }} import ( "context" - "time" "gofr.dev/pkg/gofr" "gofr.dev/pkg/gofr/container" - gofrgRPC "gofr.dev/pkg/gofr/grpc" "google.golang.org/grpc" {{- if $hasUnary }} @@ -31,6 +29,10 @@ import ( {{- end }} healthpb "google.golang.org/grpc/health/grpc_health_v1" + + {{- range .Imports }} + {{ . }} + {{- end }} ) // New{{ .Service }}GoFrServer creates a new instance of {{ .Service }}GoFrServer @@ -267,7 +269,7 @@ func (h *{{ $.Service }}ServerWrapper) {{ .Name }}(stream {{ $.Service }}_{{ .Na {{- else }} // Unary method handler for {{ .Name }} func (h *{{ $.Service }}ServerWrapper) {{ .Name }}(ctx context.Context, req *{{ .Request }}) (*{{ .Response }}, error) { - gctx := h.getGofrContext(ctx, &{{ .Request }}Wrapper{ctx: ctx, {{ .Request }}: req}) + gctx := h.getGofrContext(ctx, &{{ .RawRequest }}Wrapper{ctx: ctx, {{ .RawRequest }}: req}) res, err := h.server.{{ .Name }}(gctx) if err != nil { @@ -324,34 +326,37 @@ import ( "context" "fmt" "reflect" + {{- range .Imports }} + {{ . }} + {{- end }} ) // Request Wrappers {{- range $request := .Requests }} -type {{ $request }}Wrapper struct { +type {{ $request.RawRequest }}Wrapper struct { ctx context.Context - *{{ $request }} + *{{ $request.Request }} } -func (h *{{ $request }}Wrapper) Context() context.Context { +func (h *{{ $request.RawRequest }}Wrapper) Context() context.Context { return h.ctx } -func (h *{{ $request }}Wrapper) Param(s string) string { +func (h *{{ $request.RawRequest }}Wrapper) Param(s string) string { return "" } -func (h *{{ $request }}Wrapper) PathParam(s string) string { +func (h *{{ $request.RawRequest }}Wrapper) PathParam(s string) string { return "" } -func (h *{{ $request }}Wrapper) Bind(p interface{}) error { +func (h *{{ $request.RawRequest }}Wrapper) Bind(p interface{}) error { ptr := reflect.ValueOf(p) if ptr.Kind() != reflect.Ptr { return fmt.Errorf("expected a pointer, got %T", p) } - hValue := reflect.ValueOf(h.{{ $request }}).Elem() + hValue := reflect.ValueOf(h.{{ $request.RawRequest }}).Elem() ptrValue := ptr.Elem() for i := 0; i < hValue.NumField(); i++ { @@ -368,11 +373,11 @@ func (h *{{ $request }}Wrapper) Bind(p interface{}) error { return nil } -func (h *{{ $request }}Wrapper) HostName() string { +func (h *{{ $request.RawRequest }}Wrapper) HostName() string { return "" } -func (h *{{ $request }}Wrapper) Params(s string) []string { +func (h *{{ $request.RawRequest }}Wrapper) Params(s string) []string { return nil } {{- end }}`