Skip to content

fix import in protobuf #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 160 additions & 18 deletions wrap/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"bytes"
"errors"
"fmt"
"os"
"path"
"strings"
Expand Down Expand Up @@ -30,11 +31,69 @@
ErrWritingFile = errors.New("error writing the generated code to the file")
)

var googleProtobufs = map[string]string{

Check failure on line 34 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

googleProtobufs is a global variable (gochecknoglobals)
"google.protobuf.Any": "anypb.Any",
"google.protobuf.Api": "apipb.Api",
"google.protobuf.Method": "apipb.Method",
"google.protobuf.Mixin": "apipb.Mixin",
"google.protobuf.FileDescriptorSet": "descriptorpb.FileDescriptorSet",
"google.protobuf.FileDescriptorProto": "descriptorpb.FileDescriptorProto",
"google.protobuf.DescriptorProto": "descriptorpb.DescriptorProto",
"google.protobuf.ExtensionRangeOptions": "descriptorpb.ExtensionRangeOptions",
"google.protobuf.FieldDescriptorProto": "descriptorpb.FieldDescriptorProto",
"google.protobuf.OneofDescriptorProto": "descriptorpb.OneofDescriptorProto",
"google.protobuf.EnumDescriptorProto": "descriptorpb.EnumDescriptorProto",
"google.protobuf.EnumValueDescriptorProto": "descriptorpb.EnumValueDescriptorProto",
"google.protobuf.ServiceDescriptorProto": "descriptorpb.ServiceDescriptorProto",
"google.protobuf.MethodDescriptorProto": "descriptorpb.MethodDescriptorProto",
"google.protobuf.FileOptions": "descriptorpb.FileOptions",
"google.protobuf.MessageOptions": "descriptorpb.MessageOptions",
"google.protobuf.FieldOptions": "descriptorpb.FieldOptions",
"google.protobuf.OneofOptions": "descriptorpb.OneofOptions",
"google.protobuf.EnumOptions": "descriptorpb.EnumOptions",
"google.protobuf.EnumValueOptions": "descriptorpb.EnumValueOptions",
"google.protobuf.ServiceOptions": "descriptorpb.ServiceOptions",
"google.protobuf.MethodOptions": "descriptorpb.MethodOptions",
"google.protobuf.UninterpretedOption": "descriptorpb.UninterpretedOption",
"google.protobuf.FeatureSet": "descriptorpb.FeatureSet",
"google.protobuf.FeatureSetDefaults": "descriptorpb.FeatureSetDefaults",
"google.protobuf.SourceCodeInfo": "descriptorpb.SourceCodeInfo",
"google.protobuf.GeneratedCodeInfo": "descriptorpb.GeneratedCodeInfo",
"google.protobuf.SymbolVisibility": "descriptorpb.SymbolVisibility",
"google.protobuf.Duration": "durationpb.Duration",
"google.protobuf.Empty": "emptypb.Empty",
"google.protobuf.FieldMask": "fieldmaskpb.FieldMask",
"google.protobuf.GoFeatures": "gofeaturespb.GoFeatures",
"google.protobuf.SourceContext": "sourcecontextpb.SourceContext",
"google.protobuf.Struct": "structpb.Struct",
"google.protobuf.Value": "structpb.Value",
"google.protobuf.NullValue": "structpb.NullValue",
"google.protobuf.ListValue": "structpb.ListValue",
"google.protobuf.Timestamp": "timestamppb.Timestamp",
"google.protobuf.Type": "typepb.Type",
"google.protobuf.Field": "typepb.Field",
"google.protobuf.Enum": "typepb.Enum",
"google.protobuf.EnumValue": "typepb.EnumValue",
"google.protobuf.Option": "typepb.Option",
"google.protobuf.Syntax": "typepb.Syntax",
"google.protobuf.DoubleValue": "wrapperspb.DoubleValue",
"google.protobuf.FloatValue": "wrapperspb.FloatValue",
"google.protobuf.Int64Value": "wrapperspb.Int64Value",
"google.protobuf.UInt64Value": "wrapperspb.UInt64Value",
"google.protobuf.Int32Value": "wrapperspb.Int32Value",
"google.protobuf.UInt32Value": "wrapperspb.UInt32Value",
"google.protobuf.BoolValue": "wrapperspb.BoolValue",
"google.protobuf.StringValue": "wrapperspb.StringValue",
"google.protobuf.BytesValue": "wrapperspb.BytesValue",
}

// ServiceMethod represents a method in a proto service.
type ServiceMethod struct {
Name string
Request string
Response string
RawRequest string
RawResponse string
StreamsRequest bool
StreamsResponse bool
}
Expand All @@ -50,8 +109,14 @@
Package string
Service string
Methods []ServiceMethod
Requests []string
Requests []ServiceRequest
Source string
Imports []string
}

type ServiceRequest struct {
Request string
RawRequest string
}

type FileType struct {
Expand Down Expand Up @@ -97,6 +162,7 @@
return nil, err
}

imports := getImports(ctx, definition, protoPath)
projectPath, packageName := getPackageAndProject(ctx, definition, protoPath)
services := getServices(ctx, definition)
requests := getRequests(ctx, services)
Expand All @@ -108,6 +174,7 @@
Methods: service.Methods,
Requests: uniqueRequestTypes(ctx, service.Methods),
Source: path.Base(protoPath),
Imports: imports,
}

if err := generateFiles(ctx, projectPath, service.Name, &wrapperData, requests, options...); err != nil {
Expand All @@ -127,8 +194,7 @@
ctx.Logger.Errorf("Failed to open proto file: %v", err)
return nil, ErrOpeningProtoFile
}
defer file.Close()

Check failure on line 197 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

only one cuddle assignment allowed before defer statement (wsl)

parser := proto.NewParser(file)

definition, err := parser.Parse()
Expand All @@ -142,7 +208,7 @@

// generateFiles generates files for a given service.
func generateFiles(ctx *gofr.Context, projectPath, serviceName string, wrapperData *WrapperData,
requests []string, options ...FileType) error {
requests []ServiceRequest, options ...FileType) error {
for _, option := range options {
if option.FileSuffix == serverRequestFile {
wrapperData.Requests = requests
Expand Down Expand Up @@ -181,41 +247,47 @@
}

// getRequests extracts all unique request types from the services.
func getRequests(ctx *gofr.Context, services []ProtoService) []string {
requests := make(map[string]bool)
func getRequests(ctx *gofr.Context, services []ProtoService) []ServiceRequest {
requests := make(map[string]ServiceRequest)

for _, service := range services {
for _, method := range service.Methods {
requests[method.Request] = true
requests[method.Request] = ServiceRequest{
Request: method.Request,
RawRequest: method.RawRequest,
}
}
}

ctx.Logger.Debugf("Extracted unique request types: %v", requests)

return mapKeysToSlice(requests)
return mapValuesToSlice(requests)
}

// uniqueRequestTypes extracts unique request types from methods.
func uniqueRequestTypes(ctx *gofr.Context, methods []ServiceMethod) []string {
requests := make(map[string]bool)
func uniqueRequestTypes(ctx *gofr.Context, methods []ServiceMethod) []ServiceRequest {
requests := make(map[string]ServiceRequest)

for _, method := range methods {
requests[method.Request] = true // Include all request types
requests[method.Request] = ServiceRequest{
Request: method.Request,
RawRequest: method.RawRequest,
} // Include all request types
}

ctx.Logger.Debugf("Extracted unique request types for methods: %v", requests)

return mapKeysToSlice(requests)
return mapValuesToSlice(requests)
}

// mapKeysToSlice converts a map's keys to a slice.
func mapKeysToSlice(m map[string]bool) []string {
keys := make([]string, 0, len(m))
for key := range m {
keys = append(keys, key)
func mapValuesToSlice(m map[string]ServiceRequest) []ServiceRequest {
values := make([]ServiceRequest, 0, len(m))
for _, value := range m {
values = append(values, value)
}

return keys
return values
}

// executeTemplate executes a template with the provided data.
Expand Down Expand Up @@ -270,7 +342,7 @@
func getPackageAndProject(ctx *gofr.Context, definition *proto.Proto, protoPath string) (projectPath, packageName string) {
proto.Walk(definition,
proto.WithOption(func(opt *proto.Option) {
if opt.Name == "go_package" {

Check failure on line 345 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

string `go_package` has 2 occurrences, make it a constant (goconst)
packageName = path.Base(opt.Constant.Source)
}
}),
Expand All @@ -282,6 +354,52 @@
return projectPath, packageName
}

// getImports extracts the import directories from google protobufs and relative go_package proto definitions.
func getImports(ctx *gofr.Context, definition *proto.Proto, protoPath string) []string {
imports := []string{}
googleImports := map[string]string{
"google/protobuf/any.proto": "anypb \"google.golang.org/protobuf/types/known/anypb\"",
"google/protobuf/api.proto": "apipb \"google.golang.org/protobuf/types/known/apipb\"",
"google/protobuf/descriptor.proto": "descriptorpb \"google.golang.org/protobuf/types/descriptorpb\"",
"google/protobuf/duration.proto": "durationpb \"google.golang.org/protobuf/types/known/durationpb\"",
"google/protobuf/empty.proto": "emptypb \"google.golang.org/protobuf/types/known/emptypb\"",
"google/protobuf/field_mask.proto": "fieldmaskpb \"google.golang.org/protobuf/types/known/fieldmaskpb\"",
"google/protobuf/go_features.proto": "gofeaturespb \"google.golang.org/protobuf/types/gofeaturespb\"",
"google/protobuf/source_context.proto": "sourcecontextpb \"google.golang.org/protobuf/types/known/sourcecontextpb\"",
"google/protobuf/struct.proto": "structpb \"google.golang.org/protobuf/types/known/structpb\"",
"google/protobuf/timestamp.proto": "timestamppb \"google.golang.org/protobuf/types/known/timestamppb\"",
"google/protobuf/type.proto": "typepb \"google.golang.org/protobuf/types/known/typepb\"",
"google/protobuf/wrappers.proto": "wrapperspb \"google.golang.org/protobuf/types/known/wrapperspb\"",
}
Comment on lines +360 to +373
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move it out of the function ? As a constant ?

for _, elem := range definition.Elements {
if imported, ok := elem.(*proto.Import); ok {

Check failure on line 375 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

`if ok` has complex nested blocks (complexity: 8) (nestif)
if googleImport, ok := googleImports[imported.Filename]; ok {
imports = append(imports, googleImport)
} else {
lastIndex := strings.LastIndex(protoPath, "/")
newProto, err := parseProtoFile(ctx, protoPath[:lastIndex+1]+imported.Filename)
if err != nil {
ctx.Logger.Errorf("Failed to parse imported proto file %s: %v", imported.Filename, err)
continue
}
packageSource := ""

Check failure on line 385 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

assignments should only be cuddled with other assignments (wsl)
packageName := ""
for _, newElem := range newProto.Elements {

Check failure on line 387 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

only one cuddle assignment allowed before range statement (wsl)
if goPackage, ok := newElem.(*proto.Option); ok && goPackage.Name == "go_package" {
packageSource = goPackage.Constant.Source
}
if goPackage, ok := newElem.(*proto.Package); ok {
packageName = goPackage.Name
}
}
lastPiece := strings.LastIndex(packageName, ".")

Check failure on line 395 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

assignments should only be cuddled with other assignments (wsl)
imports = append(imports, fmt.Sprintf("%s \"%s\"", packageName[lastPiece+1:], packageSource))

Check failure on line 396 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

sprintfQuotedString: use %q instead of "%s" for quoted strings (gocritic)
}
}
}
return imports
}

// getServices extracts services from the proto definition.
func getServices(ctx *gofr.Context, definition *proto.Proto) []ProtoService {
var services []ProtoService
Expand All @@ -294,8 +412,10 @@
if rpc, ok := element.(*proto.RPC); ok {
service.Methods = append(service.Methods, ServiceMethod{
Name: rpc.Name,
Request: rpc.RequestType,
Response: rpc.ReturnsType,
Request: getProperType(rpc.RequestType),
Response: getProperType(rpc.ReturnsType),
RawRequest: getRawType(rpc.RequestType),
RawResponse: getRawType(rpc.ReturnsType),
StreamsRequest: rpc.StreamsRequest,
StreamsResponse: rpc.StreamsReturns,
})
Expand All @@ -310,3 +430,25 @@

return services
}

func getProperType(tpe string) string {
if strings.HasPrefix(tpe, "google.protobuf.") {
if protobuf, ok := googleProtobufs[tpe]; ok {
return protobuf
} else {

Check failure on line 438 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

indent-error-flow: if block ends with a return statement, so drop this else and outdent its block (move short variable declaration to its own line if necessary) (revive)
return tpe
}
} else if strings.Contains(tpe, ".") {
lastIndex := strings.LastIndex(tpe, ".")
submoduleIndex := strings.LastIndex(tpe[:lastIndex], ".")
if submoduleIndex != -1 {
return fmt.Sprintf("%s.%s", tpe[submoduleIndex+1:lastIndex], tpe[lastIndex+1:])
}
}
return tpe

Check failure on line 448 in wrap/grpc.go

View workflow job for this annotation

GitHub Actions / 🎖Code Quality️

return statements should not be cuddled if block has more than two lines (wsl)
}

func getRawType(tpe string) string {
lastIndex := strings.LastIndex(tpe, ".")
return tpe[lastIndex+1:]
}
29 changes: 17 additions & 12 deletions wrap/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ package {{ .Package }}

import (
"context"
"time"

"gofr.dev/pkg/gofr"
"gofr.dev/pkg/gofr/container"
gofrgRPC "gofr.dev/pkg/gofr/grpc"
"google.golang.org/grpc"

{{- if $hasUnary }}
Expand All @@ -31,6 +29,10 @@ import (
{{- end }}

healthpb "google.golang.org/grpc/health/grpc_health_v1"

{{- range .Imports }}
{{ . }}
{{- end }}
)

// New{{ .Service }}GoFrServer creates a new instance of {{ .Service }}GoFrServer
Expand Down Expand Up @@ -267,7 +269,7 @@ func (h *{{ $.Service }}ServerWrapper) {{ .Name }}(stream {{ $.Service }}_{{ .Na
{{- else }}
// Unary method handler for {{ .Name }}
func (h *{{ $.Service }}ServerWrapper) {{ .Name }}(ctx context.Context, req *{{ .Request }}) (*{{ .Response }}, error) {
gctx := h.getGofrContext(ctx, &{{ .Request }}Wrapper{ctx: ctx, {{ .Request }}: req})
gctx := h.getGofrContext(ctx, &{{ .RawRequest }}Wrapper{ctx: ctx, {{ .RawRequest }}: req})

res, err := h.server.{{ .Name }}(gctx)
if err != nil {
Expand Down Expand Up @@ -324,34 +326,37 @@ import (
"context"
"fmt"
"reflect"
{{- range .Imports }}
{{ . }}
{{- end }}
)

// Request Wrappers
{{- range $request := .Requests }}
type {{ $request }}Wrapper struct {
type {{ $request.RawRequest }}Wrapper struct {
ctx context.Context
*{{ $request }}
*{{ $request.Request }}
}

func (h *{{ $request }}Wrapper) Context() context.Context {
func (h *{{ $request.RawRequest }}Wrapper) Context() context.Context {
return h.ctx
}

func (h *{{ $request }}Wrapper) Param(s string) string {
func (h *{{ $request.RawRequest }}Wrapper) Param(s string) string {
return ""
}

func (h *{{ $request }}Wrapper) PathParam(s string) string {
func (h *{{ $request.RawRequest }}Wrapper) PathParam(s string) string {
return ""
}

func (h *{{ $request }}Wrapper) Bind(p interface{}) error {
func (h *{{ $request.RawRequest }}Wrapper) Bind(p interface{}) error {
ptr := reflect.ValueOf(p)
if ptr.Kind() != reflect.Ptr {
return fmt.Errorf("expected a pointer, got %T", p)
}

hValue := reflect.ValueOf(h.{{ $request }}).Elem()
hValue := reflect.ValueOf(h.{{ $request.RawRequest }}).Elem()
ptrValue := ptr.Elem()

for i := 0; i < hValue.NumField(); i++ {
Expand All @@ -368,11 +373,11 @@ func (h *{{ $request }}Wrapper) Bind(p interface{}) error {
return nil
}

func (h *{{ $request }}Wrapper) HostName() string {
func (h *{{ $request.RawRequest }}Wrapper) HostName() string {
return ""
}

func (h *{{ $request }}Wrapper) Params(s string) []string {
func (h *{{ $request.RawRequest }}Wrapper) Params(s string) []string {
return nil
}
{{- end }}`
Expand Down
Loading