Skip to content

Commit b122dd0

Browse files
committed
Change the signature of TensorFunction.call to return Result. Migrate Session.Result to be a top level class.
1 parent 5f89ee1 commit b122dd0

File tree

20 files changed

+249
-206
lines changed

20 files changed

+249
-206
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2020-2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -295,7 +295,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
295295
}
296296

297297
@Override
298-
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
298+
public Result call(Map<String, Tensor> arguments) {
299299
// FIXME need to manage input/output operand lifetimes
300300
Ops tf = Ops.create();
301301
Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());
@@ -305,11 +305,11 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments) {
305305
inputs.put(inputName, tf.constantOf((TType) argument));
306306
}
307307
Map<String, Operand<?>> outputs = tf.call(this, inputs);
308-
Map<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
308+
LinkedHashMap<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
309309
for (String outputName : outputs.keySet()) {
310310
tensorOutputs.put(outputName, outputs.get(outputName).asTensor());
311311
}
312-
return tensorOutputs;
312+
return new Result(tensorOutputs);
313313
}
314314

315315
/**
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
3+
Copyright 2022 The TensorFlow Authors. All Rights Reserved.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
=======================================================================
17+
*/
18+
package org.tensorflow;
19+
20+
import org.tensorflow.proto.framework.RunMetadata;
21+
22+
import java.util.ArrayList;
23+
import java.util.Collections;
24+
import java.util.Iterator;
25+
import java.util.LinkedHashMap;
26+
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Optional;
29+
import java.util.Set;
30+
import java.util.logging.Logger;
31+
32+
/**
33+
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
34+
*
35+
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a
36+
* reference to a value after this object has been closed it will throw an {@link
37+
* IllegalStateException} upon access.
38+
*/
39+
public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
40+
@Override
41+
public void close() {
42+
if (!closed) {
43+
closed = true;
44+
for (Tensor t : map.values()) {
45+
t.close();
46+
}
47+
} else {
48+
logger.warning("Closing an already closed Result");
49+
}
50+
}
51+
52+
@Override
53+
public Iterator<Map.Entry<String, Tensor>> iterator() {
54+
if (!closed) {
55+
return map.entrySet().iterator();
56+
} else {
57+
throw new IllegalStateException("Result is closed");
58+
}
59+
}
60+
61+
/**
62+
* Returns the number of outputs in this Result.
63+
*
64+
* @return The number of outputs.
65+
*/
66+
public int size() {
67+
return map.size();
68+
}
69+
70+
/**
71+
* Gets the set containing all the tensor names.
72+
* @return The tensor names set.
73+
*/
74+
public Set<String> keySet() {
75+
return Collections.unmodifiableSet(map.keySet());
76+
}
77+
78+
/**
79+
* Does this result object have a tensor for the supplied key?
80+
* @param key The key to check.
81+
* @return True if this result object has a tensor for this key.
82+
*/
83+
public boolean containsKey(String key) {
84+
return map.containsKey(key);
85+
}
86+
87+
/**
88+
* Gets the value from the container at the specified index.
89+
*
90+
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
91+
* IndexOutOfBoundsException} if the index is invalid.
92+
*
93+
* @param index The index to lookup.
94+
* @return The value at the index.
95+
*/
96+
public Tensor get(int index) {
97+
if (!closed) {
98+
return list.get(index);
99+
} else {
100+
throw new IllegalStateException("Result is closed");
101+
}
102+
}
103+
104+
/**
105+
* Gets the value from the container assuming it's not been closed.
106+
*
107+
* <p>Throws {@link IllegalStateException} if the container has been closed.
108+
*
109+
* @param key The key to lookup.
110+
* @return Optional.of the value if it exists.
111+
*/
112+
public Optional<Tensor> get(String key) {
113+
if (!closed) {
114+
Tensor value = map.get(key);
115+
if (value != null) {
116+
return Optional.of(value);
117+
} else {
118+
return Optional.empty();
119+
}
120+
} else {
121+
throw new IllegalStateException("Result is closed");
122+
}
123+
}
124+
125+
/**
126+
* Metadata about the run.
127+
*
128+
* <p>A <a
129+
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
130+
* protocol buffer</a>.
131+
*/
132+
public Optional<RunMetadata> getMetadata() {
133+
return Optional.ofNullable(metadata);
134+
}
135+
136+
/**
137+
* Creates a Result from the names and values produced by {@link Session.Runner#run()}.
138+
*
139+
* @param names The output names.
140+
* @param values The output values.
141+
* @param metadata The run metadata, may be null.
142+
*/
143+
Result(List<String> names, List<Tensor> values, RunMetadata metadata) {
144+
this.map = new LinkedHashMap<>();
145+
this.list = new ArrayList<>(values);
146+
147+
if (names.size() != values.size()) {
148+
throw new IllegalArgumentException(
149+
"Expected same number of names and values, found names.length = "
150+
+ names.size()
151+
+ ", values.length = "
152+
+ values.size());
153+
}
154+
155+
for (int i = 0; i < names.size(); i++) {
156+
this.map.put(names.get(i), values.get(i));
157+
}
158+
this.metadata = metadata;
159+
this.closed = false;
160+
}
161+
162+
/**
163+
* Creates a Result from the names and values.
164+
*
165+
* @param outputs The run outputs.
166+
*/
167+
Result(LinkedHashMap<String,Tensor> outputs) {
168+
this.map = outputs;
169+
this.list = new ArrayList<>(outputs.size());
170+
for (Map.Entry<String, Tensor> e : outputs.entrySet()) {
171+
list.add(e.getValue());
172+
}
173+
this.metadata = null;
174+
this.closed = false;
175+
}
176+
177+
private final Map<String, Tensor> map;
178+
179+
private final List<Tensor> list;
180+
181+
private final RunMetadata metadata;
182+
183+
private boolean closed;
184+
185+
private static final Logger logger = Logger.getLogger(Result.class.getName());
186+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 3 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2019-2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -22,17 +22,12 @@
2222

2323
import com.google.protobuf.InvalidProtocolBufferException;
2424

25-
import java.sql.Array;
2625
import java.util.ArrayList;
2726
import java.util.Collections;
28-
import java.util.Iterator;
29-
import java.util.LinkedHashMap;
3027
import java.util.LinkedHashSet;
3128
import java.util.List;
3229
import java.util.Map;
33-
import java.util.Optional;
3430
import java.util.Set;
35-
import java.util.logging.Logger;
3631

3732
import org.bytedeco.javacpp.BytePointer;
3833
import org.bytedeco.javacpp.Pointer;
@@ -654,8 +649,9 @@ public SessionFunction function(Signature signature) {
654649
*
655650
* @param signature the signature of the function
656651
* @param arguments the arguments to call with.
652+
* @return The results of the function call.
657653
*/
658-
public Map<String, Tensor> run(Signature signature, Map<String, Tensor> arguments) {
654+
public Result run(Signature signature, Map<String, Tensor> arguments) {
659655
return function(signature).call(arguments);
660656
}
661657

@@ -704,130 +700,6 @@ public void restore(String prefix) {
704700
setInitialized();
705701
}
706702

707-
/**
708-
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
709-
*
710-
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a
711-
* reference to a value after this object has been closed it will throw an {@link
712-
* IllegalStateException} upon access.
713-
*/
714-
public static final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
715-
@Override
716-
public void close() {
717-
if (!closed) {
718-
closed = true;
719-
for (Tensor t : map.values()) {
720-
t.close();
721-
}
722-
} else {
723-
logger.warning("Closing an already closed Result");
724-
}
725-
}
726-
727-
@Override
728-
public Iterator<Map.Entry<String, Tensor>> iterator() {
729-
if (!closed) {
730-
return map.entrySet().iterator();
731-
} else {
732-
throw new IllegalStateException("Result is closed");
733-
}
734-
}
735-
736-
/**
737-
* Gets the value from the container at the specified index.
738-
*
739-
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
740-
* IndexOutOfBoundsException} if the index is invalid.
741-
*
742-
* @param index The index to lookup.
743-
* @return The value at the index.
744-
*/
745-
public Tensor get(int index) {
746-
if (!closed) {
747-
return list.get(index);
748-
} else {
749-
throw new IllegalStateException("Result is closed");
750-
}
751-
}
752-
753-
/**
754-
* Returns the number of outputs in this Result.
755-
*
756-
* @return The number of outputs.
757-
*/
758-
public int size() {
759-
return map.size();
760-
}
761-
762-
/**
763-
* Gets the value from the container assuming it's not been closed.
764-
*
765-
* <p>Throws {@link IllegalStateException} if the container has been closed.
766-
*
767-
* @param key The key to lookup.
768-
* @return Optional.of the value if it exists.
769-
*/
770-
public Optional<Tensor> get(String key) {
771-
if (!closed) {
772-
Tensor value = map.get(key);
773-
if (value != null) {
774-
return Optional.of(value);
775-
} else {
776-
return Optional.empty();
777-
}
778-
} else {
779-
throw new IllegalStateException("Result is closed");
780-
}
781-
}
782-
783-
/**
784-
* Metadata about the run.
785-
*
786-
* <p>A <a
787-
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
788-
* protocol buffer</a>.
789-
*/
790-
public Optional<RunMetadata> getMetadata() {
791-
return Optional.ofNullable(metadata);
792-
}
793-
794-
/**
795-
* Creates a Result from the names and values produced by {@link Session.Runner#run()}.
796-
*
797-
* @param names The output names.
798-
* @param values The output values.
799-
* @param metadata The run metadata, may be null.
800-
*/
801-
Result(List<String> names, List<Tensor> values, RunMetadata metadata) {
802-
this.map = new LinkedHashMap<>();
803-
this.list = new ArrayList<>(values);
804-
805-
if (names.size() != values.size()) {
806-
throw new IllegalArgumentException(
807-
"Expected same number of names and values, found names.length = "
808-
+ names.size()
809-
+ ", values.length = "
810-
+ values.size());
811-
}
812-
813-
for (int i = 0; i < names.size(); i++) {
814-
this.map.put(names.get(i), values.get(i));
815-
}
816-
this.metadata = metadata;
817-
this.closed = false;
818-
}
819-
820-
private final Map<String, Tensor> map;
821-
822-
private final List<Tensor> list;
823-
824-
private final RunMetadata metadata;
825-
826-
private boolean closed;
827-
828-
private static final Logger logger = Logger.getLogger(Result.class.getName());
829-
}
830-
831703
Graph graph() {
832704
return graph;
833705
}

0 commit comments

Comments
 (0)