Skip to content

Commit 5f89ee1

Browse files
committed
Fix the tests.
1 parent 5e92bc1 commit 5f89ee1

File tree

9 files changed

+28
-81
lines changed

9 files changed

+28
-81
lines changed

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java

Lines changed: 0 additions & 27 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,13 @@ public void testGradientsGraph() {
205205

206206
try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
207207
TFloat32 c2 = TFloat32.scalarOf(2.0f);
208-
AutoCloseableList<Tensor> outputs =
209-
new AutoCloseableList<>(
210-
s.runner()
208+
Session.Result outputs = s.runner()
211209
.feed(x1, c1)
212210
.feed(x2, c2)
213211
.fetch(grads0[0])
214212
.fetch(grads1[0])
215213
.fetch(grads1[1])
216-
.run())) {
214+
.run()) {
217215

218216
assertEquals(3, outputs.size());
219217
assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ public void testCustomGradient() {
6666
assertEquals(DataType.DT_FLOAT, grads0[0].dataType());
6767

6868
try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f);
69-
AutoCloseableList<Tensor> outputs =
70-
new AutoCloseableList<>(s.runner().feed(x, c1).fetch(grads0[0]).run())) {
69+
Session.Result outputs = s.runner().feed(x, c1).fetch(grads0[0]).run()) {
7170

7271
assertEquals(1, outputs.size());
7372
assertEquals(0.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ public void withDeviceMethod() {
5353
.abs(aOps)
5454
.asOutput();
5555

56-
try (AutoCloseableList<Tensor> t =
57-
new AutoCloseableList<>(session.runner().fetch(absOps).run())) {
56+
try (Session.Result t = session.runner().fetch(absOps).run()) {
5857
assertEquals(1, ((TInt32)t.get(0)).getInt());
5958
}
6059
}
@@ -85,8 +84,7 @@ public void withEmptyDeviceSpec() {
8584
.abs(aOps)
8685
.asOutput();
8786

88-
try (AutoCloseableList<Tensor> t =
89-
new AutoCloseableList<>(session.runner().fetch(absOps).run())) {
87+
try (Session.Result t = session.runner().fetch(absOps).run()) {
9088
assertEquals(1, ((TInt32)t.get(0)).getInt());
9189
}
9290
}
@@ -131,8 +129,7 @@ public void withTwoScopes() {
131129
.mul(absOps, bOps)
132130
.asOutput();
133131

134-
try (AutoCloseableList<Tensor> t =
135-
new AutoCloseableList<>(session.runner().fetch(mulOps).run())) {
132+
try (Session.Result t = session.runner().fetch(mulOps).run()) {
136133
assertEquals(10, ((TInt32)t.get(0)).getInt());
137134
}
138135
}
@@ -179,8 +176,7 @@ public void withIncorrectDeviceSpec() {
179176
.mul(absOps, bOps)
180177
.asOutput();
181178

182-
try (AutoCloseableList<Tensor> t =
183-
new AutoCloseableList<>(session.runner().fetch(mulOps).run())) {
179+
try (Session.Result t = session.runner().fetch(mulOps).run()) {
184180
fail();
185181
} catch (TFInvalidArgumentException e) {
186182
// ok
@@ -212,8 +208,7 @@ public void withDeviceSpecInScope() {
212208
.abs(aOps)
213209
.asOutput();
214210

215-
try (AutoCloseableList<Tensor> t =
216-
new AutoCloseableList<>(session.runner().fetch(absOps).run())) {
211+
try (Session.Result t = session.runner().fetch(absOps).run()) {
217212
assertEquals(1, ((TInt32)t.get(0)).getInt());
218213
}
219214
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,13 @@ public void graphDefRoundTripWithInit() {
8484

8585
Operand<TInt32> variable2 = init.withName("var2").variable(init.constant(4));
8686

87-
try (Session s = new Session(g, true)) {
88-
List<Tensor> results = s.runner().fetch("result").fetch("var2").run();
87+
try (Session s = new Session(g, true);
88+
Session.Result results = s.runner().fetch("result").fetch("var2").run()) {
8989
TInt32 result = (TInt32) results.get(0);
9090
assertEquals(6, result.getInt());
9191

9292
TInt32 var2Result = (TInt32) results.get(1);
9393
assertEquals(4, var2Result.getInt());
94-
95-
results.forEach(Tensor::close);
9694
}
9795
}
9896
}
@@ -266,15 +264,13 @@ public void addGradientsToGraph() {
266264

267265
try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
268266
TFloat32 c2 = TFloat32.scalarOf(2.0f);
269-
AutoCloseableList<Tensor> outputs =
270-
new AutoCloseableList<>(
271-
s.runner()
267+
Session.Result outputs = s.runner()
272268
.feed(x1, c1)
273269
.feed(x2, c2)
274270
.fetch(grads0[0])
275271
.fetch(grads1[0])
276272
.fetch(grads1[1])
277-
.run())) {
273+
.run()) {
278274
assertEquals(3, outputs.size());
279275
assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
280276
assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
@@ -418,14 +414,12 @@ public void buildWhileLoopMultipleInputs() {
418414

419415
try (TInt32 c1 = TInt32.scalarOf(2);
420416
TInt32 c2 = TInt32.scalarOf(5);
421-
AutoCloseableList<Tensor> outputs =
422-
new AutoCloseableList<>(
423-
s.runner()
417+
Session.Result outputs = s.runner()
424418
.feed(input1, c1)
425419
.feed(input2, c2)
426420
.fetch(loopOutputs[0])
427421
.fetch(loopOutputs[1])
428-
.run())) {
422+
.run()) {
429423
assertEquals(2, outputs.size());
430424
assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2)
431425
assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2)

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void testBooleanMaskUpdateSlice() {
5050
Operand<TInt32> bcastOutput =
5151
BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
5252

53-
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
53+
Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run();
5454
try (TInt32 result = (TInt32) results.get(0);
5555
TInt32 bcastResult = (TInt32) results.get(1)) {
5656

@@ -89,7 +89,7 @@ public void testBooleanMaskUpdateSliceWithBroadcast() {
8989
Operand<TInt32> bcastOutput =
9090
BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
9191

92-
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
92+
Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run();
9393
try (TInt32 result = (TInt32) results.get(0);
9494
TInt32 bcastResult = (TInt32) results.get(1)) {
9595

@@ -131,7 +131,7 @@ public void testBooleanMaskUpdateAxis() {
131131
BooleanMaskUpdate.create(
132132
scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2));
133133

134-
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
134+
Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run();
135135
try (TInt32 result = (TInt32) results.get(0);
136136
TInt32 bcastResult = (TInt32) results.get(1)) {
137137

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919

2020
import java.io.IOException;
2121
import org.junit.jupiter.api.Test;
22-
import org.tensorflow.AutoCloseableList;
2322
import org.tensorflow.EagerSession;
2423
import org.tensorflow.Graph;
2524
import org.tensorflow.Operand;
2625
import org.tensorflow.Session;
27-
import org.tensorflow.Tensor;
2826
import org.tensorflow.ndarray.DoubleNdArray;
2927
import org.tensorflow.ndarray.FloatNdArray;
3028
import org.tensorflow.ndarray.IntNdArray;
@@ -66,8 +64,7 @@ public void createInts() {
6664
Scope scope = new OpScope(g);
6765
Constant<TInt32> op1 = Constant.tensorOf(scope, shape, buffer);
6866
Constant<TInt32> op2 = Constant.tensorOf(scope, array);
69-
try (AutoCloseableList<Tensor> t =
70-
new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) {
67+
try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) {
7168
assertEquals(array, t.get(0));
7269
assertEquals(array, t.get(1));
7370
}
@@ -85,8 +82,7 @@ public void createFloats() {
8582
Scope scope = new OpScope(g);
8683
Constant<TFloat32> op1 = Constant.tensorOf(scope, shape, buffer);
8784
Constant<TFloat32> op2 = Constant.tensorOf(scope, array);
88-
try (AutoCloseableList<Tensor> t =
89-
new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) {
85+
try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) {
9086
assertEquals(array, t.get(0));
9187
assertEquals(array, t.get(1));
9288
}
@@ -104,8 +100,7 @@ public void createDoubles() {
104100
Scope scope = new OpScope(g);
105101
Constant<TFloat64> op1 = Constant.tensorOf(scope, shape, buffer);
106102
Constant<TFloat64> op2 = Constant.tensorOf(scope, array);
107-
try (AutoCloseableList<Tensor> t =
108-
new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) {
103+
try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) {
109104
assertEquals(array, t.get(0));
110105
assertEquals(array, t.get(1));
111106
}
@@ -123,8 +118,7 @@ public void createLongs() {
123118
Scope scope = new OpScope(g);
124119
Constant<TInt64> op1 = Constant.tensorOf(scope, shape, buffer);
125120
Constant<TInt64> op2 = Constant.tensorOf(scope, array);
126-
try (AutoCloseableList<Tensor> t =
127-
new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) {
121+
try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) {
128122
assertEquals(array, t.get(0));
129123
assertEquals(array, t.get(1));
130124
}
@@ -142,8 +136,7 @@ public void createStrings() throws IOException {
142136
Scope scope = new OpScope(g);
143137
Constant<TString> op1 = Constant.tensorOf(scope, shape, buffer);
144138
Constant<TString> op2 = Constant.tensorOf(scope, array);
145-
try (AutoCloseableList<Tensor> t =
146-
new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) {
139+
try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) {
147140
assertEquals(array, t.get(0));
148141
assertEquals(array, t.get(1));
149142
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,9 @@
2121

2222
import java.util.Arrays;
2323
import org.junit.jupiter.api.Test;
24-
import org.tensorflow.AutoCloseableList;
2524
import org.tensorflow.Graph;
2625
import org.tensorflow.Output;
2726
import org.tensorflow.Session;
28-
import org.tensorflow.Tensor;
2927
import org.tensorflow.op.Ops;
3028
import org.tensorflow.types.TFloat32;
3129

@@ -48,9 +46,8 @@ public void createGradients() {
4846
assertEquals(2, grads.dy().size());
4947

5048
try (TFloat32 c = TFloat32.scalarOf(3.0f);
51-
AutoCloseableList<Tensor> outputs =
52-
new AutoCloseableList<>(
53-
sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
49+
Session.Result outputs = sess.runner()
50+
.feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) {
5451

5552
assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
5653
assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f);
@@ -75,8 +72,7 @@ public void createGradientsWithSum() {
7572
assertEquals(1, grads.dy().size());
7673

7774
try (TFloat32 c = TFloat32.scalarOf(3.0f);
78-
AutoCloseableList<Tensor> outputs =
79-
new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
75+
Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) {
8076

8177
assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
8278
}
@@ -101,9 +97,8 @@ public void createGradientsWithInitialValues() {
10197
assertEquals(1, grads1.dy().size());
10298

10399
try (TFloat32 c = TFloat32.scalarOf(3.0f);
104-
AutoCloseableList<Tensor> outputs =
105-
new AutoCloseableList<>(
106-
sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
100+
Session.Result outputs =
101+
sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) {
107102

108103
assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
109104
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public void operationsComposingZerosAreCorrectlyNamed() {
134134
long[] shape = {2, 2};
135135
Zeros<TFloat32> zeros =
136136
Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class);
137-
List<?> results =
137+
Session.Result results =
138138
sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
139139
}
140140
}

0 commit comments

Comments
 (0)