Skip to content

Commit 307b672

Browse files
authored
Releasing 0.3.3
2 parents 2744e6c + 1ce2dda commit 307b672

File tree

19 files changed

+190
-168
lines changed

19 files changed

+190
-168
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ systems, you should add the following dependencies:
6161
<dependency>
6262
<groupId>org.tensorflow</groupId>
6363
<artifactId>tensorflow-core-api</artifactId>
64-
<version>0.3.2</version>
64+
<version>0.3.3</version>
6565
</dependency>
6666
<dependency>
6767
<groupId>org.tensorflow</groupId>
6868
<artifactId>tensorflow-core-api</artifactId>
69-
<version>0.3.2</version>
69+
<version>0.3.3</version>
7070
<classifier>linux-x86_64${javacpp.platform.extension}</classifier>
7171
</dependency>
7272
```
@@ -77,24 +77,24 @@ native dependencies as follows:
7777
<dependency>
7878
<groupId>org.tensorflow</groupId>
7979
<artifactId>tensorflow-core-api</artifactId>
80-
<version>0.3.2</version>
80+
<version>0.3.3</version>
8181
</dependency>
8282
<dependency>
8383
<groupId>org.tensorflow</groupId>
8484
<artifactId>tensorflow-core-api</artifactId>
85-
<version>0.3.2</version>
85+
<version>0.3.3</version>
8686
<classifier>linux-x86_64${javacpp.platform.extension}</classifier>
8787
</dependency>
8888
<dependency>
8989
<groupId>org.tensorflow</groupId>
9090
<artifactId>tensorflow-core-api</artifactId>
91-
<version>0.3.2</version>
91+
<version>0.3.3</version>
9292
<classifier>macosx-x86_64${javacpp.platform.extension}</classifier>
9393
</dependency>
9494
<dependency>
9595
<groupId>org.tensorflow</groupId>
9696
<artifactId>tensorflow-core-api</artifactId>
97-
<version>0.3.2</version>
97+
<version>0.3.3</version>
9898
<classifier>windows-x86_64${javacpp.platform.extension}</classifier>
9999
</dependency>
100100
```
@@ -107,7 +107,7 @@ artifact includes transitively all the artifacts above as a single dependency:
107107
<dependency>
108108
<groupId>org.tensorflow</groupId>
109109
<artifactId>tensorflow-core-platform${javacpp.platform.extension}</artifactId>
110-
<version>0.3.2</version>
110+
<version>0.3.3</version>
111111
</dependency>
112112
```
113113

@@ -152,6 +152,7 @@ This table shows the mapping between different version of TensorFlow for Java an
152152
| 0.3.0 | 2.4.1 |
153153
| 0.3.1 | 2.4.1 |
154154
| 0.3.2 | 2.4.1 |
155+
| 0.3.3 | 2.4.1 |
155156

156157
## How to Contribute?
157158

ndarray/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ To import the NdArray library in your project, simply add the following dependen
1111
<dependency>
1212
<groupId>org.tensorflow</groupId>
1313
<artifactId>ndarray</artifactId>
14-
<version>0.3.2</version>
14+
<version>0.3.3</version>
1515
</dependency>
1616
```
1717

ndarray/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
<parent>
2323
<groupId>org.tensorflow</groupId>
2424
<artifactId>tensorflow-java</artifactId>
25-
<version>0.3.2</version>
25+
<version>0.3.3</version>
2626
</parent>
2727
<artifactId>ndarray</artifactId>
2828
<packaging>jar</packaging>

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<groupId>org.tensorflow</groupId>
77
<artifactId>tensorflow-java</artifactId>
8-
<version>0.3.2</version>
8+
<version>0.3.3</version>
99
<packaging>pom</packaging>
1010

1111
<name>TensorFlow Java Parent</name>

tensorflow-core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
<parent>
2323
<groupId>org.tensorflow</groupId>
2424
<artifactId>tensorflow-java</artifactId>
25-
<version>0.3.2</version>
25+
<version>0.3.3</version>
2626
</parent>
2727
<artifactId>tensorflow-core</artifactId>
2828
<packaging>pom</packaging>

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>org.tensorflow</groupId>
88
<artifactId>tensorflow-core</artifactId>
9-
<version>0.3.2</version>
9+
<version>0.3.3</version>
1010
</parent>
1111
<artifactId>tensorflow-core-api</artifactId>
1212
<packaging>jar</packaging>

tensorflow-core/tensorflow-core-api/pom.xml.asc

Lines changed: 0 additions & 11 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -453,19 +453,18 @@ public Output<?>[] whileLoop(
453453
synchronized SaverDef saverDef() {
454454
if (saverDef == null) {
455455
// Check to see if this graph has a restore operation
456-
if (operation("save/restore_all") == null) {
456+
if (operation(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP) == null) {
457457
// No saver, create one by mutating the graph
458458
saverDef = addVariableSaver(this);
459459
} else {
460460
// This graph already has saving/restoring operations,
461-
// regenerate SaverDef without mutating. The names mirror
462-
// the python implementation for compatibility.
463-
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
464-
saverDef = SaverDef.newBuilder()
465-
.setFilenameTensorName("save/filename")
466-
.setSaveTensorName("save/control_dependency")
467-
.setRestoreOpName("save/restore_all")
468-
.build();
461+
// regenerate SaverDef without mutating.
462+
saverDef =
463+
SaverDef.newBuilder()
464+
.setFilenameTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0")
465+
.setSaveTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP)
466+
.setRestoreOpName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP)
467+
.build();
469468
}
470469
}
471470
return saverDef;
@@ -570,6 +569,13 @@ public void remove() {
570569
private int position;
571570
}
572571

572+
// These names mirror the python implementation, to reduce the risk of incompatibility.
573+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
574+
private static final String SAVER_DEF_SCOPE = "save";
575+
private static final String SAVER_DEF_FILENAME_OP = "filename";
576+
private static final String SAVER_DEF_SAVE_OP = "control_dependency";
577+
private static final String SAVER_DEF_RESTORE_OP = "restore_all";
578+
573579
private static TF_Graph allocate() {
574580
return TF_NewGraph();
575581
}
@@ -797,7 +803,7 @@ private static Object[] whileLoop(
797803
}
798804

799805
private static SaverDef addVariableSaver(Graph graph) {
800-
Ops tf = Ops.create(graph).withSubScope("save");
806+
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
801807

802808
List<String> varNames = new ArrayList<>();
803809
List<Operand<?>> varOutputs = new ArrayList<>();
@@ -812,36 +818,35 @@ private static SaverDef addVariableSaver(Graph graph) {
812818
}
813819
}
814820

815-
// FIXME Need an easier way to initialize an NdArray from a list
816-
String[] tmp = new String[varNames.size()];
817-
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
818-
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
819-
820-
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
821-
Save saveVariables = tf.train.save(
822-
saveFilename,
823-
varNamesTensor,
824-
varSlices,
825-
varOutputs
826-
);
827-
Identity<TString> id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables))
828-
.withName("control_dependency").identity(saveFilename);
829-
Restore restoreVariables = tf.train.restore(
830-
saveFilename,
831-
varNamesTensor,
832-
varSlices,
833-
varTypes
834-
);
835-
List<Op> restoreOps = new ArrayList<>(varOutputs.size());
836-
for (int i = 0; i < varOutputs.size(); ++i) {
837-
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
821+
Placeholder<TString> filename = tf.withName(SAVER_DEF_FILENAME_OP).placeholder(TString.class);
822+
Identity<TString> save = null;
823+
NoOp restore = null;
824+
825+
if (varNames.isEmpty()) {
826+
save = tf.withName(SAVER_DEF_SAVE_OP).identity(filename);
827+
restore = tf.withName(SAVER_DEF_RESTORE_OP).noOp();
828+
} else {
829+
String[] tmp = new String[varNames.size()];
830+
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
831+
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
832+
Save saveVars = tf.train.save(filename, varNamesTensor, varSlices, varOutputs);
833+
List<Op> saveDeps = Arrays.asList(filename, saveVars);
834+
Restore restoreVars = tf.train.restore(filename, varNamesTensor, varSlices, varTypes);
835+
List<Op> restoreDeps = new ArrayList<>(varOutputs.size());
836+
for (int i = 0; i < varOutputs.size(); ++i) {
837+
restoreDeps.add(tf.assign(varOutputs.get(i), (Operand) restoreVars.tensors().get(i)));
838+
}
839+
save = tf.withControlDependencies(saveDeps).withName(SAVER_DEF_SAVE_OP).identity(filename);
840+
restore = tf.withControlDependencies(restoreDeps).withName(SAVER_DEF_RESTORE_OP).noOp();
838841
}
839-
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();
840842

843+
// 'Filename' must be the name of a tensor (i.e. with output index)
844+
// 'Save' must be an operation name, even if the field name is confusing (see SaverDef doc)
845+
// 'Restore' must be an operation name
841846
return SaverDef.newBuilder()
842-
.setFilenameTensorName(saveFilename.op().name())
843-
.setSaveTensorName(id.op().name())
844-
.setRestoreOpName(restoreAll.op().name())
847+
.setFilenameTensorName(filename.output().name())
848+
.setSaveTensorName(save.op().name())
849+
.setRestoreOpName(restore.op().name())
845850
.build();
846851
}
847852

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ public int index() {
3838
return index;
3939
}
4040

41+
/** Returns the full name of this Output (a.k.a. tensor name) */
42+
public String name() {
43+
return op().name() + ":" + index;
44+
}
45+
4146
/** Returns the DataType of the tensor referred to by this Output. */
4247
@SuppressWarnings("unchecked")
4348
public DataType dataType() {
@@ -48,7 +53,7 @@ public DataType dataType() {
4853
@SuppressWarnings("unchecked")
4954
@Override
5055
public Class<T> type() {
51-
return (Class<T>)TensorTypeRegistry.find(dataType()).type();
56+
return (Class<T>) TensorTypeRegistry.find(dataType()).type();
5257
}
5358

5459
/**
@@ -63,7 +68,10 @@ public Class<T> type() {
6368
public <U extends TType> Output<U> expect(Class<U> type) {
6469
if (type != type()) {
6570
throw new IllegalArgumentException(
66-
"Cannot cast from output of " + this.type().getSimpleName() + " to output of " + type.getSimpleName());
71+
"Cannot cast from output of "
72+
+ this.type().getSimpleName()
73+
+ " to output of "
74+
+ type.getSimpleName());
6775
}
6876
return ((Output<U>) this);
6977
}
@@ -80,17 +88,16 @@ public <U extends TType> Output<U> expect(Class<U> type) {
8088
*
8189
* @return tensor
8290
* @throws IllegalStateException if this output results from a graph
83-
* @throws ClassCastException if the type of the tensor and this output are unexpectedly incompatible
91+
* @throws ClassCastException if the type of the tensor and this output are unexpectedly
92+
* incompatible
8493
* @see EagerSession
8594
*/
8695
@SuppressWarnings("unchecked")
8796
public T asTensor() {
88-
return (T)operation.tensor(index);
97+
return (T) operation.tensor(index);
8998
}
9099

91-
/**
92-
* Returns the (possibly partially known) shape of the tensor referred to by this output.
93-
*/
100+
/** Returns the (possibly partially known) shape of the tensor referred to by this output. */
94101
@Override
95102
public Shape shape() {
96103
return operation.shape(index);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.LinkedHashMap;
3030
import java.util.List;
3131
import java.util.Map;
32+
import java.util.Objects;
3233
import java.util.stream.Collectors;
3334
import org.bytedeco.javacpp.BytePointer;
3435
import org.bytedeco.javacpp.PointerScope;
@@ -432,7 +433,7 @@ private static SavedModelBundle load(
432433
}
433434

434435
private static void validateTags(String[] tags) {
435-
if (tags == null || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
436+
if (tags == null || Arrays.stream(tags).anyMatch(Objects::isNull)) {
436437
throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
437438
}
438439
}

0 commit comments

Comments
 (0)