Skip to content

Commit 8f936f0

Browse files
authored
fix saved model path (#1718)
1 parent 8c6d30b commit 8f936f0

File tree

5 files changed

+45
-19
lines changed

5 files changed

+45
-19
lines changed

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ require (
1212

1313
github.com/fortytw2/leaktest v1.3.0
1414
github.com/go-delve/delve v1.3.2 // indirect
15-
github.com/go-openapi/spec v0.19.4 // indirect
15+
github.com/go-openapi/spec v0.19.5 // indirect
1616
github.com/go-sql-driver/mysql v1.4.1
1717
github.com/golang/protobuf v1.3.2
1818
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
19+
github.com/kr/pty v1.1.5 // indirect
1920
github.com/mattn/go-colorable v0.1.4 // indirect
2021
github.com/mattn/go-isatty v0.0.11 // indirect
2122
github.com/mattn/go-runewidth v0.0.7 // indirect
@@ -31,6 +32,7 @@ require (
3132
github.com/sirupsen/logrus v1.4.2
3233
github.com/soniakeys/quant v1.0.0 // indirect
3334
github.com/spf13/cobra v0.0.5 // indirect
35+
github.com/stretchr/objx v0.2.0 // indirect
3436
github.com/stretchr/testify v1.4.0
3537
go.starlark.net v0.0.0-20191218235703-9fcb808a6221 // indirect
3638
golang.org/x/arch v0.0.0-20191126211547-368ea8f32fff // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwoh
9595
github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
9696
github.com/go-openapi/spec v0.19.4 h1:ixzUSnHTd6hCemgtAJgluaTSGYpLNpJY4mA2DIkdOAo=
9797
github.com/go-openapi/spec v0.19.4/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo=
98+
github.com/go-openapi/spec v0.19.5 h1:Xm0Ao53uqnk9QE/LlYV5DEU09UAgpliA85QoT9LzqPw=
99+
github.com/go-openapi/spec v0.19.5/go.mod h1:Hm2Jr4jv8G1ciIAo+frC/Ft+rR2kQDh8JHKHb3gWUSk=
98100
github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
99101
github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
100102
github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY=

pkg/sql/alisa_submitter.go

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ type alisaSubmitter struct {
3737
}
3838

3939
func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
40-
_, dSName, err := database.ParseURL(s.Session.DbConnStr)
40+
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
4141
if err != nil {
4242
return err
4343
}
44-
cfg, e := goalisa.ParseDSN(dSName)
44+
cfg, e := goalisa.ParseDSN(dsName)
4545
if e != nil {
4646
return e
4747
}
@@ -59,6 +59,22 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
5959
return e
6060
}
6161

62+
func (s *alisaSubmitter) getModelPath(modelName string) (string, error) {
63+
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
64+
if err != nil {
65+
return "", err
66+
}
67+
cfg, err := goalisa.ParseDSN(dsName)
68+
if err != nil {
69+
return "", err
70+
}
71+
userID := s.Session.UserId
72+
if userID == "" {
73+
userID = "unkown"
74+
}
75+
return strings.Join([]string{cfg.Project, userID, modelName}, "/"), nil
76+
}
77+
6278
func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
6379
ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr)
6480
if e != nil {
@@ -71,12 +87,17 @@ func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
7187
return e
7288
}
7389

74-
paiCmd, e := getPAIcmd(cc, ts.Into, ts.TmpTrainTable, ts.TmpValidateTable, "")
90+
modelPath, e := s.getModelPath(ts.Into)
7591
if e != nil {
7692
return e
7793
}
7894

79-
code, e := pai.TFTrainAndSave(ts, s.Session, ts.Into)
95+
paiCmd, e := getPAIcmd(cc, ts.Into, modelPath, ts.TmpTrainTable, ts.TmpValidateTable, "")
96+
if e != nil {
97+
return e
98+
}
99+
100+
code, e := pai.TFTrainAndSave(ts, s.Session, modelPath)
80101
if e != nil {
81102
return e
82103
}
@@ -121,13 +142,15 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
121142
if e != nil {
122143
return e
123144
}
124-
125-
paiCmd, e := getPAIcmd(cc, ps.Using, ps.TmpPredictTable, "", ps.ResultTable)
145+
modelPath, e := s.getModelPath(ps.Using)
126146
if e != nil {
127147
return e
128148
}
129-
130-
code, e := pai.TFLoadAndPredict(ps, s.Session, ps.Using)
149+
paiCmd, e := getPAIcmd(cc, ps.Using, modelPath, ps.TmpPredictTable, "", ps.ResultTable)
150+
if e != nil {
151+
return e
152+
}
153+
code, e := pai.TFLoadAndPredict(ps, s.Session, modelPath)
131154
if e != nil {
132155
return e
133156
}
@@ -198,14 +221,14 @@ func odpsTables(table string) (string, error) {
198221
return fmt.Sprintf("odps://%s/tables/%s", parts[0], parts[1]), nil
199222
}
200223

201-
func getPAIcmd(cc *pai.ClusterConfig, modelName, trainTable, valTable, resTable string) (string, error) {
224+
func getPAIcmd(cc *pai.ClusterConfig, modelName, ossModelPath, trainTable, valTable, resTable string) (string, error) {
202225
jobName := strings.Replace(strings.Join([]string{"sqlflow", modelName}, "_"), ".", "_", 0)
203226
cfString, err := json.Marshal(cc)
204227
if err != nil {
205228
return "", err
206229
}
207230
cfQuote := strconv.Quote(string(cfString))
208-
ckpDir, err := pai.FormatCkptDir(modelName)
231+
ckpDir, err := pai.FormatCkptDir(ossModelPath)
209232
if err != nil {
210233
return "", err
211234
}

pkg/sql/alisa_submitter_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ func TestGetPAICmd(t *testing.T) {
5050
}
5151
os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "oss://bucket/?role_arn=xxx&host=xxx")
5252
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR")
53-
paiCmd, err := getPAIcmd(cc, "my_model", "testdb.test", "", "testdb.result")
53+
paiCmd, err := getPAIcmd(cc, "my_model", "project/12345/my_model", "testdb.test", "", "testdb.result")
5454
a.NoError(err)
55-
ckpDir, err := pai.FormatCkptDir("my_model")
55+
ckpDir, err := pai.FormatCkptDir("project/12345/my_model")
5656
a.NoError(err)
5757
expected := fmt.Sprintf("pai -name tensorflow1120 -DjobName=sqlflow_my_model -Dtags=dnn -Dscript=file://@@task.tar.gz -DentryFile=entry.py -Dtables=odps://testdb/tables/test -Doutputs=odps://testdb/tables/result -DcheckpointDir=\"%s\"", ckpDir)
5858
a.Equal(expected, paiCmd)

pkg/sql/codegen/pai/codegen.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ func FormatCkptDir(modelName string) (string, error) {
6666
}
6767
ossDir := strings.Join([]string{strings.TrimRight(ossURIParts[0], "/"), modelName}, "/")
6868
// Form URI like: oss://bucket/your/path/modelname/?args=...
69-
ossCkptDir = strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?")
70-
return ossCkptDir, nil
69+
return strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?"), nil
7170
}
7271

7372
// wrapper generates a Python program for submit TensorFlow tasks to PAI.
@@ -228,15 +227,15 @@ func Train(ir *ir.TrainStmt, session *pb.Session, modelName, cwd string) (string
228227
}
229228

230229
// TFTrainAndSave generates PAI-TF train program.
231-
func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (string, error) {
230+
func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelPath string) (string, error) {
232231
code, err := tensorflow.Train(ir, session)
233232
if err != nil {
234233
return "", err
235234
}
236235

237236
// append code snippet to save model
238237
var tpl = template.Must(template.New("SaveModel").Parse(tfSaveModelTmplText))
239-
ckptDir, err := FormatCkptDir(ir.Into)
238+
ckptDir, err := FormatCkptDir(modelPath)
240239
if err != nil {
241240
return "", err
242241
}
@@ -332,9 +331,9 @@ func Predict(ir *ir.PredictStmt, session *pb.Session, modelName, cwd string) (st
332331
}
333332

334333
// TFLoadAndPredict generates PAI-TF prediction program.
335-
func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelName string) (string, error) {
334+
func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelPath string) (string, error) {
336335
var tpl = template.Must(template.New("Predict").Parse(tfPredictTmplText))
337-
ossModelDir, err := FormatCkptDir(modelName)
336+
ossModelDir, err := FormatCkptDir(modelPath)
338337
if err != nil {
339338
return "", err
340339
}

0 commit comments

Comments
 (0)