@@ -453,19 +453,17 @@ public Output<?>[] whileLoop(
   synchronized SaverDef saverDef() {
     if (saverDef == null) {
       // Check to see if this graph has a restore operation
-      if (operation("save/restore_all") == null) {
+      if (operation(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP) == null) {
         // No saver, create one by mutating the graph
         saverDef = addVariableSaver(this);
       } else {
         // This graph already has saving/restoring operations,
-        // regenerate SaverDef without mutating. The names mirror
-        // the python implementation for compatibility.
-        // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
+        // regenerate SaverDef without mutating.
         saverDef =
             SaverDef.newBuilder()
-                .setFilenameTensorName("save/filename:0")
-                .setSaveTensorName("save/control_dependency")
-                .setRestoreOpName("save/restore_all")
+                .setFilenameTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0")
+                .setSaveTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP)
+                .setRestoreOpName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP)
                 .build();
       }
     }
@@ -571,6 +569,13 @@ public void remove() {
     private int position;
   }
 
+  // These names mirror the python implementation, to reduce the risk of incompatibility.
+  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
+  private static final String SAVER_DEF_SCOPE = "save";
+  private static final String SAVER_DEF_FILENAME_OP = "filename";
+  private static final String SAVER_DEF_SAVE_OP = "control_dependency";
+  private static final String SAVER_DEF_RESTORE_OP = "restore_all";
+
   private static TF_Graph allocate() {
     return TF_NewGraph();
   }
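For orientation, a small illustrative sketch (not part of the change) of the fully qualified names these constants compose into; they are exactly the strings that were hardcoded before, and the saverOpName helper below is hypothetical:

  // Hypothetical helper, shown only to make the composed names concrete.
  static String saverOpName(String op) {
    return SAVER_DEF_SCOPE + "/" + op;
  }
  // saverOpName(SAVER_DEF_FILENAME_OP) + ":0"  -> "save/filename:0"
  // saverOpName(SAVER_DEF_SAVE_OP)             -> "save/control_dependency"
  // saverOpName(SAVER_DEF_RESTORE_OP)          -> "save/restore_all"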
@@ -798,7 +803,7 @@ private static Object[] whileLoop(
   }
 
   private static SaverDef addVariableSaver(Graph graph) {
-    Ops tf = Ops.create(graph).withSubScope("save");
+    Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
 
     List<String> varNames = new ArrayList<>();
     List<Operand<?>> varOutputs = new ArrayList<>();
@@ -813,13 +818,13 @@ private static SaverDef addVariableSaver(Graph graph) {
       }
     }
 
-    Placeholder<TString> filename = tf.withName("filename").placeholder(TString.class);
+    Placeholder<TString> filename = tf.withName(SAVER_DEF_FILENAME_OP).placeholder(TString.class);
     Identity<TString> save = null;
     NoOp restore = null;
 
     if (varNames.isEmpty()) {
-      save = tf.withName("empty_save").identity(filename);
-      restore = tf.withName("restore_all").noOp();
+      save = tf.withName(SAVER_DEF_SAVE_OP).identity(filename);
+      restore = tf.withName(SAVER_DEF_RESTORE_OP).noOp();
     } else {
       String[] tmp = new String[varNames.size()];
       Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
@@ -831,8 +836,8 @@ private static SaverDef addVariableSaver(Graph graph) {
       for (int i = 0; i < varOutputs.size(); ++i) {
         restoreDeps.add(tf.assign(varOutputs.get(i), (Operand) restoreVars.tensors().get(i)));
       }
-      save = tf.withControlDependencies(saveDeps).withName("control_dependency").identity(filename);
-      restore = tf.withControlDependencies(restoreDeps).withName("restore_all").noOp();
+      save = tf.withControlDependencies(saveDeps).withName(SAVER_DEF_SAVE_OP).identity(filename);
+      restore = tf.withControlDependencies(restoreDeps).withName(SAVER_DEF_RESTORE_OP).noOp();
     }
 
     // 'Filename' must be the name of a tensor (i.e. with output index)
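As background on how a SaverDef built this way is typically consumed, here is a minimal, hypothetical usage sketch (not part of this change): it feeds a checkpoint prefix into the filename tensor and runs the save tensor and restore op by the names the SaverDef records. The checkpoint path, the method name, and the way the Graph and SaverDef are obtained are illustrative assumptions.

  // Hypothetical sketch: drive the saver ops through the names recorded in a SaverDef.
  static void saveThenRestore(Graph graph, SaverDef saverDef) {
    try (Session session = new Session(graph);
        TString checkpointPrefix = TString.scalarOf("/tmp/model/ckpt")) {
      // Save: feed the checkpoint prefix into the filename tensor and run the save op.
      session.runner()
          .feed(saverDef.getFilenameTensorName(), checkpointPrefix)
          .addTarget(saverDef.getSaveTensorName())
          .run();
      // Restore: feed the same prefix and run the restore op to reload variable values.
      session.runner()
          .feed(saverDef.getFilenameTensorName(), checkpointPrefix)
          .addTarget(saverDef.getRestoreOpName())
          .run();
    }
  }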