Skip to content

Commit 9c47862

Browse files
authored
feature/goctl-api-swagger (#4780)
1 parent 801c283 commit 9c47862

38 files changed

+2445
-23
lines changed

tools/goctl/api/cmd.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import (
1010
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
1111
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
1212
"github.com/zeromicro/go-zero/tools/goctl/api/new"
13+
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
1314
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
1415
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
1516
"github.com/zeromicro/go-zero/tools/goctl/config"
1617
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
18+
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
1719
"github.com/zeromicro/go-zero/tools/goctl/plugin"
1820
)
1921

@@ -31,6 +33,7 @@ var (
3133
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
3234
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
3335
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
36+
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
3437
)
3538

3639
func init() {
@@ -46,6 +49,7 @@ func init() {
4649
pluginCmdFlags = pluginCmd.Flags()
4750
tsCmdFlags = tsCmd.Flags()
4851
validateCmdFlags = validateCmd.Flags()
52+
swaggerCmdFlags = swaggerCmd.Flags()
4953
)
5054

5155
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
@@ -97,8 +101,15 @@ func init() {
97101
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
98102
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
99103

104+
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
105+
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
106+
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
107+
100108
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
101109

102110
// Add sub-commands
103111
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
112+
if env.UseExperimental() {
113+
Cmd.AddCommand(swaggerCmd)
114+
}
104115
}

tools/goctl/api/spec/spec.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type (
2121

2222
// ApiSpec describes an api file
2323
ApiSpec struct {
24-
Info Info // Deprecated: useless expression
24+
Info Info
2525
Syntax ApiSyntax // Deprecated: useless expression
2626
Imports []Import // Deprecated: useless expression
2727
Types []Type

tools/goctl/api/swagger/annotation.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package swagger
2+
3+
import (
4+
"strconv"
5+
6+
"github.com/zeromicro/go-zero/tools/goctl/util"
7+
"google.golang.org/grpc/metadata"
8+
)
9+
10+
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
11+
if len(properties) == 0 {
12+
return def
13+
}
14+
md := metadata.New(properties)
15+
val := md.Get(key)
16+
if len(val) == 0 {
17+
return def
18+
}
19+
str := util.Unquote(val[0])
20+
if len(str) == 0 {
21+
return def
22+
}
23+
res, _ := strconv.ParseBool(str)
24+
return res
25+
}
26+
27+
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
28+
if len(properties) == 0 {
29+
return def
30+
}
31+
md := metadata.New(properties)
32+
val := md.Get(key)
33+
if len(val) == 0 {
34+
return def
35+
}
36+
str := util.Unquote(val[0])
37+
if len(str) == 0 {
38+
return def
39+
}
40+
return str
41+
}
42+
43+
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
44+
if len(properties) == 0 {
45+
return def
46+
}
47+
md := metadata.New(properties)
48+
val := md.Get(key)
49+
if len(val) == 0 {
50+
return def
51+
}
52+
53+
str := util.Unquote(val[0])
54+
if len(str) == 0 {
55+
return def
56+
}
57+
resp := util.FieldsAndTrimSpace(str, commaRune)
58+
if len(resp) == 0 {
59+
return def
60+
}
61+
return resp
62+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package swagger
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func Test_getBoolFromKVOrDefault(t *testing.T) {
10+
properties := map[string]string{
11+
"enabled": `"true"`,
12+
"disabled": `"false"`,
13+
"invalid": `"notabool"`,
14+
"empty_value": `""`,
15+
}
16+
17+
assert.True(t, getBoolFromKVOrDefault(properties, "enabled", false))
18+
assert.False(t, getBoolFromKVOrDefault(properties, "disabled", true))
19+
assert.False(t, getBoolFromKVOrDefault(properties, "invalid", false))
20+
assert.True(t, getBoolFromKVOrDefault(properties, "missing", true))
21+
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
22+
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
23+
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
24+
}
25+
26+
func Test_getStringFromKVOrDefault(t *testing.T) {
27+
properties := map[string]string{
28+
"name": `"example"`,
29+
"empty": `""`,
30+
}
31+
32+
assert.Equal(t, "example", getStringFromKVOrDefault(properties, "name", "default"))
33+
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "empty", "default"))
34+
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
35+
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
36+
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
37+
}
38+
39+
func Test_getListFromInfoOrDefault(t *testing.T) {
40+
properties := map[string]string{
41+
"list": `"a, b, c"`,
42+
"empty": `""`,
43+
}
44+
45+
assert.Equal(t, []string{"a", "b", "c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
46+
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
47+
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
48+
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
49+
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{}, "empty", []string{"default"}))
50+
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
51+
"foo": ",,",
52+
}, "foo", []string{"default"}))
53+
}

tools/goctl/api/swagger/api.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package swagger
2+
3+
import "github.com/zeromicro/go-zero/tools/goctl/api/spec"
4+
5+
func fillAllStructs(api *spec.ApiSpec) {
6+
var (
7+
tps []spec.Type
8+
structTypes = make(map[string]spec.DefineStruct)
9+
groups []spec.Group
10+
)
11+
for _, tp := range api.Types {
12+
structTypes[tp.Name()] = tp.(spec.DefineStruct)
13+
}
14+
15+
for _, tp := range api.Types {
16+
filledTP := fillStruct("", tp, structTypes)
17+
tps = append(tps, filledTP)
18+
structTypes[filledTP.Name()] = filledTP.(spec.DefineStruct)
19+
}
20+
21+
for _, group := range api.Service.Groups {
22+
var routes []spec.Route
23+
for _, route := range group.Routes {
24+
route.RequestType = fillStruct("", route.RequestType, structTypes)
25+
route.ResponseType = fillStruct("", route.ResponseType, structTypes)
26+
routes = append(routes, route)
27+
}
28+
group.Routes = routes
29+
groups = append(groups, group)
30+
}
31+
api.Service.Groups = groups
32+
api.Types = tps
33+
}
34+
35+
func fillStruct(parent string, tp spec.Type, allTypes map[string]spec.DefineStruct) spec.Type {
36+
switch val := tp.(type) {
37+
case spec.DefineStruct:
38+
var members []spec.Member
39+
for _, member := range val.Members {
40+
switch memberType := member.Type.(type) {
41+
case spec.PointerType:
42+
member.Type = spec.PointerType{
43+
RawName: memberType.RawName,
44+
Type: fillStruct(val.Name(), memberType.Type, allTypes),
45+
}
46+
case spec.ArrayType:
47+
member.Type = spec.ArrayType{
48+
RawName: memberType.RawName,
49+
Value: fillStruct(val.Name(), memberType.Value, allTypes),
50+
}
51+
case spec.MapType:
52+
member.Type = spec.MapType{
53+
RawName: memberType.RawName,
54+
Key: memberType.Key,
55+
Value: fillStruct(val.Name(), memberType.Value, allTypes),
56+
}
57+
case spec.DefineStruct:
58+
if parent != memberType.Name() { // avoid recursive struct
59+
if st, ok := allTypes[memberType.Name()]; ok {
60+
member.Type = fillStruct("", st, allTypes)
61+
}
62+
}
63+
case spec.NestedStruct:
64+
member.Type = fillStruct("", member.Type, allTypes)
65+
}
66+
members = append(members, member)
67+
}
68+
if len(members) == 0 {
69+
st, ok := allTypes[val.RawName]
70+
if ok {
71+
members = st.Members
72+
}
73+
}
74+
val.Members = members
75+
return val
76+
case spec.NestedStruct:
77+
var members []spec.Member
78+
for _, member := range val.Members {
79+
switch memberType := member.Type.(type) {
80+
case spec.PointerType:
81+
member.Type = spec.PointerType{
82+
RawName: memberType.RawName,
83+
Type: fillStruct(val.Name(), memberType.Type, allTypes),
84+
}
85+
case spec.ArrayType:
86+
member.Type = spec.ArrayType{
87+
RawName: memberType.RawName,
88+
Value: fillStruct(val.Name(), memberType.Value, allTypes),
89+
}
90+
case spec.MapType:
91+
member.Type = spec.MapType{
92+
RawName: memberType.RawName,
93+
Key: memberType.Key,
94+
Value: fillStruct(val.Name(), memberType.Value, allTypes),
95+
}
96+
case spec.DefineStruct:
97+
if parent != memberType.Name() { // avoid recursive struct
98+
if st, ok := allTypes[memberType.Name()]; ok {
99+
member.Type = fillStruct("", st, allTypes)
100+
}
101+
}
102+
case spec.NestedStruct:
103+
if parent != memberType.Name() {
104+
if st, ok := allTypes[memberType.Name()]; ok {
105+
member.Type = fillStruct("", st, allTypes)
106+
}
107+
}
108+
}
109+
members = append(members, member)
110+
}
111+
if len(members) == 0 {
112+
st, ok := allTypes[val.RawName]
113+
if ok {
114+
members = st.Members
115+
}
116+
}
117+
val.Members = members
118+
return val
119+
case spec.PointerType:
120+
return spec.PointerType{
121+
RawName: val.RawName,
122+
Type: fillStruct(parent, val.Type, allTypes),
123+
}
124+
case spec.ArrayType:
125+
return spec.ArrayType{
126+
RawName: val.RawName,
127+
Value: fillStruct(parent, val.Value, allTypes),
128+
}
129+
case spec.MapType:
130+
return spec.MapType{
131+
RawName: val.RawName,
132+
Key: val.Key,
133+
Value: fillStruct(parent, val.Value, allTypes),
134+
}
135+
default:
136+
return tp
137+
}
138+
}

tools/goctl/api/swagger/command.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package swagger
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
7+
"os"
8+
"path/filepath"
9+
"strings"
10+
11+
"github.com/spf13/cobra"
12+
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
13+
"gopkg.in/yaml.v2"
14+
)
15+
16+
var (
17+
// VarStringAPI specifies the API filename.
18+
VarStringAPI string
19+
20+
// VarStringDir specifies the directory to generate swagger file.
21+
VarStringDir string
22+
23+
// VarBoolYaml specifies whether to generate a YAML file.
24+
VarBoolYaml bool
25+
)
26+
27+
func Command(_ *cobra.Command, _ []string) error {
28+
if len(VarStringAPI) == 0 {
29+
return errors.New("missing -api")
30+
}
31+
32+
if len(VarStringDir) == 0 {
33+
return errors.New("missing -dir")
34+
}
35+
36+
api, err := parser.Parse(VarStringAPI, "")
37+
if err != nil {
38+
return err
39+
}
40+
41+
fillAllStructs(api)
42+
43+
if err := api.Validate(); err != nil {
44+
return err
45+
}
46+
swagger, err := spec2Swagger(api)
47+
if err != nil {
48+
return err
49+
}
50+
data, err := json.MarshalIndent(swagger, "", " ")
51+
if err != nil {
52+
return err
53+
}
54+
55+
err = pathx.MkdirIfNotExist(VarStringDir)
56+
if err != nil {
57+
return err
58+
}
59+
60+
base := filepath.Base(VarStringAPI)
61+
if VarBoolYaml {
62+
filename := filepath.Join(VarStringDir, strings.TrimSuffix(base, filepath.Ext(base))+".yaml")
63+
64+
var jsonObj interface{}
65+
if err := yaml.Unmarshal(data, &jsonObj); err != nil {
66+
return err
67+
}
68+
69+
data, err := yaml.Marshal(jsonObj)
70+
if err != nil {
71+
return err
72+
}
73+
return os.WriteFile(filename, data, 0644)
74+
}
75+
// generate json swagger file
76+
filename := filepath.Join(VarStringDir, strings.TrimSuffix(base, filepath.Ext(base))+".json")
77+
78+
return os.WriteFile(filename, data, 0644)
79+
}

0 commit comments

Comments
 (0)