Skip to content

Commit 1ce2dda

Browse files
committed
Use invariable op names for saver def
1 parent ae7194f commit 1ce2dda

File tree

1 file changed

+18
-13
lines changed
  • tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow

1 file changed

+18
-13
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -453,19 +453,17 @@ public Output<?>[] whileLoop(
453453
synchronized SaverDef saverDef() {
454454
if (saverDef == null) {
455455
// Check to see if this graph has a restore operation
456-
if (operation("save/restore_all") == null) {
456+
if (operation(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP) == null) {
457457
// No saver, create one by mutating the graph
458458
saverDef = addVariableSaver(this);
459459
} else {
460460
// This graph already has saving/restoring operations,
461-
// regenerate SaverDef without mutating. The names mirror
462-
// the python implementation for compatibility.
463-
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
461+
// regenerate SaverDef without mutating.
464462
saverDef =
465463
SaverDef.newBuilder()
466-
.setFilenameTensorName("save/filename:0")
467-
.setSaveTensorName("save/control_dependency")
468-
.setRestoreOpName("save/restore_all")
464+
.setFilenameTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0")
465+
.setSaveTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP)
466+
.setRestoreOpName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP)
469467
.build();
470468
}
471469
}
@@ -571,6 +569,13 @@ public void remove() {
571569
private int position;
572570
}
573571

572+
// These names mirror the python implementation, to reduce the risk of incompatibility.
573+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
574+
private static final String SAVER_DEF_SCOPE = "save";
575+
private static final String SAVER_DEF_FILENAME_OP = "filename";
576+
private static final String SAVER_DEF_SAVE_OP = "control_dependency";
577+
private static final String SAVER_DEF_RESTORE_OP = "restore_all";
578+
574579
private static TF_Graph allocate() {
575580
return TF_NewGraph();
576581
}
@@ -798,7 +803,7 @@ private static Object[] whileLoop(
798803
}
799804

800805
private static SaverDef addVariableSaver(Graph graph) {
801-
Ops tf = Ops.create(graph).withSubScope("save");
806+
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
802807

803808
List<String> varNames = new ArrayList<>();
804809
List<Operand<?>> varOutputs = new ArrayList<>();
@@ -813,13 +818,13 @@ private static SaverDef addVariableSaver(Graph graph) {
813818
}
814819
}
815820

816-
Placeholder<TString> filename = tf.withName("filename").placeholder(TString.class);
821+
Placeholder<TString> filename = tf.withName(SAVER_DEF_FILENAME_OP).placeholder(TString.class);
817822
Identity<TString> save = null;
818823
NoOp restore = null;
819824

820825
if (varNames.isEmpty()) {
821-
save = tf.withName("empty_save").identity(filename);
822-
restore = tf.withName("restore_all").noOp();
826+
save = tf.withName(SAVER_DEF_SAVE_OP).identity(filename);
827+
restore = tf.withName(SAVER_DEF_RESTORE_OP).noOp();
823828
} else {
824829
String[] tmp = new String[varNames.size()];
825830
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
@@ -831,8 +836,8 @@ private static SaverDef addVariableSaver(Graph graph) {
831836
for (int i = 0; i < varOutputs.size(); ++i) {
832837
restoreDeps.add(tf.assign(varOutputs.get(i), (Operand) restoreVars.tensors().get(i)));
833838
}
834-
save = tf.withControlDependencies(saveDeps).withName("control_dependency").identity(filename);
835-
restore = tf.withControlDependencies(restoreDeps).withName("restore_all").noOp();
839+
save = tf.withControlDependencies(saveDeps).withName(SAVER_DEF_SAVE_OP).identity(filename);
840+
restore = tf.withControlDependencies(restoreDeps).withName(SAVER_DEF_RESTORE_OP).noOp();
836841
}
837842

838843
// 'Filename' must be the name of a tensor (i.e. with output index)

0 commit comments

Comments
 (0)