Skip to content

Commit ec9bc09

Browse files
authored
Merge pull request #12 from ColdWaterLW/support-1746
Support 1746
2 parents 3163799 + d7804d9 commit ec9bc09

28 files changed

+6202
-4
lines changed

ast/context.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package ast
22

3+
import "fmt"
4+
35
type Context struct {
4-
QueryType string // select, insert, update, delete
5-
Variable map[string]string
6-
Sqls map[string]*SqlNode
6+
QueryType string // select, insert, update, delete
7+
Variable map[string]string
8+
Sqls map[string]*SqlNode
9+
DefaultNamespace string // namespace of current mapper
710
}
811

912
func NewContext() *Context {
@@ -24,5 +27,10 @@ func (c *Context) SetVariable(k, v string) {
2427

2528
func (c *Context) GetSql(k string) (*SqlNode, bool) {
2629
sql, ok := c.Sqls[k]
30+
if ok {
31+
return sql, true
32+
}
33+
// 当存在跨namespace引用时,需要通过namespace区分引用的SQL id
34+
sql, ok = c.Sqls[fmt.Sprintf("%v.%v", c.DefaultNamespace, k)]
2735
return sql, ok
2836
}

ast/mapper.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) {
7070

7171
func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
7272
var stmts []string
73-
ctx.Sqls = m.SqlNodes
73+
if len(ctx.Sqls) == 0 {
74+
ctx.Sqls = m.SqlNodes
75+
}
76+
ctx.DefaultNamespace = m.NameSpace
7477
for _, a := range m.QueryNodes {
7578
data, err := a.GetStmt(ctx)
7679
if err == nil {

ast/mappers.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package ast
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
)
7+
8+
type Mappers struct {
9+
mappers []*Mapper
10+
}
11+
12+
func NewMappers() *Mappers {
13+
return &Mappers{}
14+
}
15+
16+
func (s *Mappers) AddMapper(ms ...*Mapper) error {
17+
for _, m := range ms {
18+
if m == nil {
19+
return errors.New("can not add null mapper to mappers")
20+
}
21+
s.mappers = append(s.mappers, m)
22+
}
23+
return nil
24+
}
25+
26+
func (s *Mappers) GetStmts(skipErrorQuery bool) ([]string, error) {
27+
ctx := NewContext()
28+
stmts := []string{}
29+
for _, m := range s.mappers {
30+
for id, node := range m.SqlNodes {
31+
ctx.Sqls[fmt.Sprintf("%v.%v", m.NameSpace, id)] = node
32+
}
33+
}
34+
35+
for _, m := range s.mappers {
36+
ctx.DefaultNamespace = m.NameSpace
37+
stmt, err := m.GetStmts(ctx, skipErrorQuery)
38+
if err != nil {
39+
return nil, fmt.Errorf("get sqls from mapper failed, namespace: %v, err: %v", m.NameSpace, err)
40+
}
41+
stmts = append(stmts, stmt...)
42+
}
43+
return stmts, nil
44+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ go 1.14
55
require (
66
github.com/pingcap/parser v3.0.12+incompatible
77
github.com/pingcap/tidb v0.0.0-20200312110807-8c4696b3f340 // v3.0.12
8+
github.com/stretchr/testify v1.3.0
89
)

parser.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package parser
22

33
import (
44
"encoding/xml"
5+
"errors"
56
"fmt"
67
"io"
78
"strings"
@@ -27,6 +28,47 @@ func ParseXML(data string) (string, error) {
2728
return stmt, nil
2829
}
2930

31+
// ParseXMLs is a parser for parse all query in several XML files to []string one by one;
32+
// you can set `skipErrorQuery` true to ignore invalid query.
33+
func ParseXMLs(data []string, skipErrorQuery bool) ([]string, error) {
34+
ms := ast.NewMappers()
35+
for i := range data {
36+
r := strings.NewReader(data[i])
37+
d := xml.NewDecoder(r)
38+
n, err := parse(d)
39+
if err != nil {
40+
if skipErrorQuery {
41+
continue
42+
} else {
43+
return nil, err
44+
}
45+
}
46+
47+
if n == nil {
48+
continue
49+
}
50+
51+
m, ok := n.(*ast.Mapper)
52+
if !ok {
53+
if skipErrorQuery {
54+
continue
55+
} else {
56+
return nil, errors.New("the mapper is not found")
57+
}
58+
}
59+
err = ms.AddMapper(m)
60+
if err != nil && !skipErrorQuery {
61+
return nil, fmt.Errorf("add mapper failed: %v", err)
62+
}
63+
}
64+
stmts, err := ms.GetStmts(skipErrorQuery)
65+
if err != nil {
66+
return nil, err
67+
}
68+
69+
return stmts, nil
70+
}
71+
3072
// ParseXMLQuery is a parser for parse all query in XML to []string one by one;
3173
// you can set `skipErrorQuery` true to ignore invalid query.
3274
func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {

parser_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package parser
22

33
import (
44
"testing"
5+
6+
"github.com/stretchr/testify/assert"
57
)
68

79
func testParser(t *testing.T, xmlData, expect string) {
@@ -1046,3 +1048,62 @@ func TestOtherwise_issue1193(t *testing.T) {
10461048
`, "SELECT * FROM `fruits` WHERE `name`=? AND `price`=? AND `category`=?;",
10471049
)
10481050
}
1051+
1052+
func TestParseXMLs(t *testing.T) {
1053+
xmlCommonData := `
1054+
<?xml version="1.0" encoding="UTF-8"?><!--Converted at: Mon Jun 07 09:48:24 CST 2021-->
1055+
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
1056+
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
1057+
<mapper namespace="test.common">
1058+
<sql id="prefix">
1059+
SELECT * FROM (
1060+
</sql>
1061+
1062+
<sql id="suffix">
1063+
WHERE a=1 )
1064+
</sql>
1065+
1066+
<select id="sql1" parameterType="customer" resultMap="custResultMap">
1067+
<include refid="prefix"/>
1068+
SELECT a,b FROM tb1
1069+
<include refid="suffix"/>
1070+
</select>
1071+
</mapper>
1072+
`
1073+
xmlData := `
1074+
<?xml version="1.0" encoding="UTF-8"?><!--Converted at: Tue May 10 15:50:21 CST 2022-->
1075+
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
1076+
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
1077+
<mapper namespace="employee">
1078+
<select id="queryEmpHireSepList" parameterType="employee" resultType="employeeResult">
1079+
<include refid="test.common.prefix"/>
1080+
SELECT a,b FROM tb1
1081+
<include refid="test.common.suffix"/>
1082+
</select>
1083+
</mapper>
1084+
`
1085+
1086+
sqls, err := ParseXMLs([]string{xmlCommonData, xmlData}, false)
1087+
if err != nil {
1088+
if !assert.NoError(t, err) {
1089+
t.Fatal(err)
1090+
}
1091+
}
1092+
assert.Equal(t, 2, len(sqls))
1093+
assert.Equal(t, `
1094+
SELECT * FROM (
1095+
1096+
SELECT a,b FROM tb1
1097+
1098+
WHERE a=1 )
1099+
`, sqls[0])
1100+
1101+
assert.Equal(t, `
1102+
SELECT * FROM (
1103+
1104+
SELECT a,b FROM tb1
1105+
1106+
WHERE a=1 )
1107+
`, sqls[1])
1108+
1109+
}

vendor/github.com/davecgh/go-spew/LICENSE

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/davecgh/go-spew/spew/bypass.go

Lines changed: 145 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)