Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"

"github.com/cenkalti/backoff/v4"
nebula "github.com/vesoft-inc/nebula-go/v3"
)

type (
Expand Down Expand Up @@ -136,3 +137,37 @@ func (c *defaultClient) Close() error {
}
return nil
}

// NewSessionPool builds a nebula-go session pool from the given client options.
//
// Every configured address must have the form "host:port" with a non-empty
// host and a numeric port; otherwise errors.ErrInvalidAddress is returned.
// Errors reported by nebula.NewSessionPoolConf and nebula.NewSessionPool are
// returned unchanged.
func NewSessionPool(opts ...Option) (*nebula.SessionPool, error) {
	// TODO: the space name and the pool size are hard-coded; they should be
	// taken from the options once those expose them.
	const (
		spaceName   = "sf300_2"
		maxPoolSize = 3000
	)

	ops := newOptions(opts...)

	hostAddresses := make([]nebula.HostAddress, 0, len(ops.addresses))
	for _, h := range ops.addresses {
		hostPort := strings.Split(h, ":")
		if len(hostPort) != 2 || hostPort[0] == "" {
			return nil, errors.ErrInvalidAddress
		}
		port, err := strconv.Atoi(hostPort[1])
		if err != nil {
			// A non-numeric port must be rejected here; continuing would
			// register the host with port 0.
			return nil, errors.ErrInvalidAddress
		}
		hostAddresses = append(hostAddresses, nebula.HostAddress{Host: hostPort[0], Port: port})
	}

	conf, err := nebula.NewSessionPoolConf(ops.user, ops.password, hostAddresses,
		spaceName, nebula.WithMaxSize(maxPoolSize))
	if err != nil {
		return nil, err
	}

	pool, err := nebula.NewSessionPool(*conf, nebula.DefaultLogger{})
	if err != nil {
		return nil, err
	}
	return pool, nil
}
6 changes: 6 additions & 0 deletions pkg/config/base/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/vesoft-inc/nebula-importer/v4/pkg/client"
"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
"github.com/vesoft-inc/nebula-importer/v4/pkg/manager"
"github.com/vesoft-inc/nebula-importer/v4/pkg/utils"
)

Expand Down Expand Up @@ -85,6 +86,11 @@ func (c *Client) BuildClientPool(opts ...client.Option) (client.Pool, error) {
}
options = append(options, opts...)
pool := newClientPool(options...)
sessionPool, err := client.NewSessionPool(options...)
if err != nil {
return nil, err
}
manager.DefaultSessionPool = sessionPool
return pool, nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/config/v3/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (c *Config) Build() error {
if err != nil {
return err
}

mgr, err = c.Manager.BuildManager(l, pool, c.Sources,
manager.WithGetClientOptions(client.WithClientInitFunc(nil)), // clean the USE SPACE in 3.x
)
Expand Down
2 changes: 2 additions & 0 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ var (
ErrUnsupportedFunction = stderrors.New("unsupported function")
ErrFilterSyntax = stderrors.New("filter syntax")
ErrUnsupportedMode = stderrors.New("unsupported mode")
ErrNoDynamicParam = stderrors.New("no dynamic param")
ErrFetchFailed = stderrors.New("fetch failed")
)
11 changes: 9 additions & 2 deletions pkg/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync/atomic"
"time"

"github.com/panjf2000/ants/v2"
nebula "github.com/vesoft-inc/nebula-go/v3"
"github.com/vesoft-inc/nebula-importer/v4/pkg/client"
"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
"github.com/vesoft-inc/nebula-importer/v4/pkg/importer"
Expand All @@ -16,8 +18,6 @@ import (
"github.com/vesoft-inc/nebula-importer/v4/pkg/source"
"github.com/vesoft-inc/nebula-importer/v4/pkg/spec"
"github.com/vesoft-inc/nebula-importer/v4/pkg/stats"

"github.com/panjf2000/ants/v2"
)

const (
Expand Down Expand Up @@ -158,6 +158,7 @@ func WithLogger(l logger.Logger) Option {
}

func (m *defaultManager) Import(s source.Source, brr reader.BatchRecordReader, importers ...importer.Importer) error {

if len(importers) == 0 {
return nil
}
Expand Down Expand Up @@ -268,6 +269,9 @@ func (m *defaultManager) Stop() (err error) {
m.importerWaitGroup.Wait()

m.logStats()
if DefaultSessionPool != nil {
DefaultSessionPool.Close()
}
return m.After()
}

Expand Down Expand Up @@ -437,3 +441,6 @@ func (m *defaultManager) logError(err error, msg string, fields ...logger.Field)
fields = append(fields, logger.MapToFields(e.Fields())...)
m.logger.SkipCaller(1).WithError(e.Cause()).Error(msg, fields...)
}

// DefaultSessionPool is a package-level nebula session pool.
//
// It is a temporary variable introduced for testing. In the visible code it is
// assigned by the client config's BuildClientPool, read by the node spec's
// batch-update statement builder, and closed by defaultManager.Stop when it is
// non-nil.
// NOTE(review): this is mutable global state with no synchronization visible
// here; presumably it is set once before any import starts — confirm.
var DefaultSessionPool *nebula.SessionPool
8 changes: 8 additions & 0 deletions pkg/spec/base/dynamicParam.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package specbase

// DynamicParam is an optional group of string settings read from the YAML
// keys "address", "user", "password" and "space".
//
// A node using the BATCH_UPDATE mode must have a non-nil DynamicParam,
// otherwise building its statement fails with ErrNoDynamicParam.
// NOTE(review): presumably these fields describe the connection used to fetch
// existing data for a batch update; the visible code only checks the struct
// for nil and does not read the individual fields — confirm the intended use.
type DynamicParam struct {
	Address  string `yaml:"address,omitempty"`
	User     string `yaml:"user,omitempty"`
	Password string `yaml:"password,omitempty"`
	Space    string `yaml:"space,omitempty"`
}
11 changes: 6 additions & 5 deletions pkg/spec/base/mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package specbase
import "strings"

const (
DefaultMode = InsertMode
InsertMode Mode = "INSERT"
UpdateMode Mode = "UPDATE"
DeleteMode Mode = "DELETE"
DefaultMode = InsertMode
InsertMode Mode = "INSERT"
UpdateMode Mode = "UPDATE"
DeleteMode Mode = "DELETE"
BatchUpdateMode Mode = "BATCH_UPDATE"
)

type Mode string
Expand All @@ -19,5 +20,5 @@ func (m Mode) Convert() Mode {
}

// IsSupport reports whether m is one of the supported modes: InsertMode,
// UpdateMode, DeleteMode or BatchUpdateMode.
func (m Mode) IsSupport() bool {
	switch m {
	case InsertMode, UpdateMode, DeleteMode, BatchUpdateMode:
		return true
	default:
		return false
	}
}
178 changes: 176 additions & 2 deletions pkg/spec/v3/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package specv3

import (
"fmt"
"sort"
"strings"

nebula "github.com/vesoft-inc/nebula-go/v3"
"github.com/vesoft-inc/nebula-importer/v4/pkg/bytebufferpool"
"github.com/vesoft-inc/nebula-importer/v4/pkg/errors"
"github.com/vesoft-inc/nebula-importer/v4/pkg/manager"
specbase "github.com/vesoft-inc/nebula-importer/v4/pkg/spec/base"
"github.com/vesoft-inc/nebula-importer/v4/pkg/utils"
)
Expand All @@ -21,13 +24,16 @@ type (

Filter *specbase.Filter `yaml:"filter,omitempty"`

Mode specbase.Mode `yaml:"mode,omitempty"`
Mode specbase.Mode `yaml:"mode,omitempty"`
DynamicParam *specbase.DynamicParam `yaml:"dynamicParam,omitempty"`

fnStatement func(records ...Record) (string, int, error)
fnStatement func(records ...Record) (string, int, error)
dynamicFnStatement func(pool *nebula.SessionPool, records ...Record) (string, int, error)
// "INSERT VERTEX name(prop_name, ..., prop_name) VALUES "
// "UPDATE VERTEX ON name "
// "DELETE TAG name FROM "
statementPrefix string
// session for batch update
}

Nodes []*Node
Expand Down Expand Up @@ -109,6 +115,11 @@ func (n *Node) Complete() {
case specbase.DeleteMode:
n.fnStatement = n.deleteStatement
n.statementPrefix = fmt.Sprintf("DELETE TAG %s FROM ", utils.ConvertIdentifier(n.Name))
case specbase.BatchUpdateMode:
// Batch update first fetches the existing node and then rewrites it with
// the new props. The statement prefix depends on the fetched columns, so it
// is built by updateBatchStatement after the fetch instead of being set here.
n.fnStatement = n.updateBatchStatement
}
}

Expand Down Expand Up @@ -279,3 +290,166 @@ func (ns Nodes) Validate() error {
}
return nil
}

// updateBatchStatement builds a single INSERT VERTEX statement that rewrites
// the vertices referenced by records.
//
// For every record it computes the vertex id and the new prop values, fetches
// the currently stored properties of those vertices through
// manager.DefaultSessionPool, overlays the new values on the fetched ones and
// emits "INSERT VERTEX name(cols) VALUES id:(values), ...".
//
// It returns the statement and the number of vertices written into it.
// errors.ErrNoDynamicParam is returned when the node has no DynamicParam, and
// an error wrapping errors.ErrFetchFailed when the session pool has not been
// initialized.
func (n *Node) updateBatchStatement(records ...Record) (statement string, nRecord int, err error) {
	if n.DynamicParam == nil {
		return "", 0, errors.ErrNoDynamicParam
	}

	// The pool is a package-level variable that is nil until it is assigned;
	// fail with an error instead of dereferencing a nil pool.
	pool := manager.DefaultSessionPool
	if pool == nil {
		return "", 0, fmt.Errorf("%w: session pool is not initialized", errors.ErrFetchFailed)
	}

	idValues := make([]string, 0, len(records))
	needUpdateRecords := make([]Record, 0, len(records))
	for _, record := range records {
		idValue, err := n.ID.Value(record)
		if err != nil {
			return "", 0, n.importError(err)
		}
		propsSetValueList, err := n.Props.ValueList(record)
		if err != nil {
			// Wrapped the same way as the id error above.
			return "", 0, n.importError(err)
		}
		idValues = append(idValues, idValue)
		needUpdateRecords = append(needUpdateRecords, propsSetValueList)
	}

	cols := make([]string, 0, len(n.Props))
	for _, prop := range n.Props {
		cols = append(cols, prop.Name)
	}

	updatedCols, updatedRecords, err := n.genDynamicUpdateRecord(pool, idValues, cols, needUpdateRecords)
	if err != nil {
		return "", 0, err
	}

	// Take the pooled buffer only after the fetch so it is not held while
	// waiting on the network.
	buff := bytebufferpool.Get()
	defer bytebufferpool.Put(buff)

	// "INSERT VERTEX name(col, ..., col) VALUES "
	buff.SetString(fmt.Sprintf("INSERT VERTEX %s(%s) VALUES ", utils.ConvertIdentifier(n.Name), strings.Join(updatedCols, ", ")))

	for index, record := range updatedRecords {
		if nRecord > 0 {
			_, _ = buff.WriteString(", ")
		}

		// id:(prop_value1, prop_value2, ...)
		_, _ = buff.WriteString(idValues[index])
		_, _ = buff.WriteString(":(")
		_, _ = buff.WriteStringSlice(record, ", ")
		_, _ = buff.WriteString(")")

		nRecord++
	}
	return buff.String(), nRecord, nil
}

// genDynamicUpdateRecord fetches the stored properties of the vertices in
// idValues and overlays the new values from records on top of them.
//
// cols names the props being updated and records[i] holds their new values
// for idValues[i]. The FETCH statement is attempted up to three times and the
// loop stops at the first successful result.
//
// It returns the full column list (fetched columns that are not updated, in
// sorted order, followed by cols) and one complete record per id, in the order
// of idValues. An error is returned when the fetch fails, when the result
// cannot be decoded, when an id is missing from the result, or when records
// and idValues differ in length.
func (n *Node) genDynamicUpdateRecord(pool *nebula.SessionPool, idValues []string, cols []string, records []Record) ([]string, []Record, error) {
	const maxAttempts = 3

	if len(records) != len(idValues) {
		return nil, nil, fmt.Errorf("records and ids mismatch, records: %d, ids: %d", len(records), len(idValues))
	}

	stat := fmt.Sprintf("FETCH PROP ON %s %s YIELD VERTEX as v;", utils.ConvertIdentifier(n.Name), strings.Join(idValues, ","))

	var (
		rs  *nebula.ResultSet
		err error
	)
	for i := 0; i < maxAttempts; i++ {
		rs, err = pool.Execute(stat)
		if err == nil && rs.IsSucceed() {
			// Stop retrying as soon as one attempt succeeds.
			break
		}
	}
	if err != nil {
		return nil, nil, err
	}
	if !rs.IsSucceed() {
		// The server message is data, not a format string.
		return nil, nil, fmt.Errorf("%s", rs.GetErrorMsg())
	}

	fetchData, err := n.getNebulaFetchData(rs)
	if err != nil {
		return nil, nil, err
	}

	// The column list is derived from a single fetched vertex.
	// NOTE(review): presumably all vertices of one tag expose the same
	// property names, so any entry gives the same columns — confirm.
	var updatedCols []string
	for _, property := range fetchData {
		updatedCols = n.getDynamicUpdateCols(cols, property)
		break
	}

	updatedRecords := make([]Record, 0, len(idValues))
	for index, id := range idValues {
		originalData, ok := fetchData[id]
		if !ok {
			return nil, nil, fmt.Errorf("cannot find id, id: %s", id)
		}
		updatedRecords = append(updatedRecords, n.getUpdateRocord(originalData, updatedCols, records[index]))
	}
	return updatedCols, updatedRecords, nil
}

// getDynamicUpdateCols returns the column order used for a rewritten vertex:
// the names in properties that are not being updated, sorted ascending,
// followed by updateCols in their given order.
func (n *Node) getDynamicUpdateCols(updateCols []string, properties map[string]*nebula.ValueWrapper) []string {
	overridden := make(map[string]struct{})
	for i := range updateCols {
		overridden[updateCols[i]] = struct{}{}
	}

	var kept []string
	for name := range properties {
		if _, skip := overridden[name]; skip {
			continue
		}
		kept = append(kept, name)
	}
	// Map iteration order is random, so sort to get a stable column order.
	sort.Strings(kept)

	return append(kept, updateCols...)
}

// getNebulaFetchData decodes a FETCH result set into a map keyed by the
// string form of each vertex id. Each value holds the properties that the
// vertex has on the tag n.Name.
//
// Only the first column of every row is read, and it must be a node. The
// first decoding error aborts the conversion and is returned.
func (n *Node) getNebulaFetchData(rs *nebula.ResultSet) (map[string]map[string]*nebula.ValueWrapper, error) {
	fetched := map[string]map[string]*nebula.ValueWrapper{}

	for idx := 0; idx < rs.GetRowSize(); idx++ {
		record, err := rs.GetRowValuesByIndex(idx)
		if err != nil {
			return nil, err
		}

		first, err := record.GetValueByIndex(0)
		if err != nil {
			return nil, err
		}

		vertex, err := first.AsNode()
		if err != nil {
			return nil, err
		}

		props, err := vertex.Properties(n.Name)
		if err != nil {
			return nil, err
		}

		fetched[vertex.GetID().String()] = props
	}

	return fetched, nil
}

// getUpdateRocord builds the complete record of one vertex in the order given
// by Columns.
//
// Each column is first filled with the stored value from original, rendered
// with its String method; datetime values are wrapped as datetime("...").
// The last len(update) positions are then overwritten with the new values
// from update, so Columns is expected to end with the updated columns.
//
// A column that is missing from original gets the placeholder NULL instead of
// causing a nil pointer dereference.
// NOTE(review): update is expected to be no longer than Columns, which holds
// for the visible caller as long as Props.ValueList yields one value per
// prop; a longer update would panic.
func (n *Node) getUpdateRocord(original map[string]*nebula.ValueWrapper, Columns []string, update Record) Record {
	r := make(Record, 0, len(Columns))
	for _, c := range Columns {
		value, ok := original[c]
		if !ok || value == nil {
			// No stored value for this column; keep the slot so positions
			// stay aligned with Columns.
			r = append(r, "NULL")
			continue
		}

		switch value.GetType() {
		// TODO should handle other type
		case "datetime":
			r = append(r, fmt.Sprintf("datetime(\"%s\")", value.String()))
		default:
			r = append(r, value.String())
		}
	}

	// Overlay the new values on the trailing columns.
	copy(r[len(Columns)-len(update):], update)
	return r
}