Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package org.tensorflow;

import org.tensorflow.Tensor.ToStringOptions;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.Shaped;
import org.tensorflow.op.Op;
Expand Down Expand Up @@ -65,6 +66,17 @@ default T asTensor() {
return asOutput().asTensor();
}

/**
* Returns the String representation of the tensor elements at this operand.
*
* @param options overrides the default configuration
* @return the String representation of the tensor elements
* @throws IllegalStateException if this is an operand of a graph
*/
default String dataToString(ToStringOptions... options) {
return asTensor().dataToString(options);
}

/**
* Returns the tensor type of this operand
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package org.tensorflow;

import java.util.function.Consumer;
import org.tensorflow.internal.types.Tensors;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.Shaped;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
Expand Down Expand Up @@ -136,8 +135,7 @@ static <T extends TType> T of(Class<T> type, Shape shape, Consumer<T> dataInitia
* shape.
*
* <p>This could be useful for tensor types that stores data but also metadata in the tensor
* memory,
* such as the lookup table in a tensor of strings.
* memory, such as the lookup table in a tensor of strings.
*
* @param <T> the tensor type
* @param type the tensor type class
Expand Down Expand Up @@ -202,7 +200,7 @@ static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData
* Returns the String representation of elements stored in the tensor.
*
* @param options overrides the default configuration
* @return the String representation of the tensor
* @return the String representation of the tensor elements
* @throws IllegalStateException if this is an operand of a graph
*/
default String dataToString(ToStringOptions... options) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the use of a vararg to handle the optional presence of options is a good idea. Having a second method accepting no parameter would be better.

We use vararg options in the op wrappers because we want to limit the number of methods that ending up in the *Ops classes, which is already more than a thousand. But here it's fine "duplicating" it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @karllessard

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.tensorflow.internal.types;
package org.tensorflow;

import java.util.ArrayList;
import java.util.Iterator;
Expand All @@ -11,7 +11,7 @@
/**
* Tensor helper methods.
*/
public final class Tensors {
final class Tensors {

/**
* Prevent construction.
Expand All @@ -30,12 +30,20 @@ public static String toString(Tensor tensor) {
}

/**
* Returns a String representation of a tensor's data. If the output is wider than {@code
* maxWidth} characters, it is truncated and "{@code ...}" is inserted in place of the removed
* data.
*
* @param tensor a tensor
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). This
* limit may surpassed if the first or last element are too long.
* @return the String representation of the tensor
*/
public static String toString(Tensor tensor, Integer maxWidth) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about having that many ways to print a tensors (if we can call tensor.toString, why having Tensors.toString(tensor)?) It's a bit confusing and I find that that feature maybe spreads out a bit in multiple parts of the code.

I have a new design to propose, let me know what you think. But what about having this toString logic in another called, let say, TensorPrinter, which can be returned or directly invoked in the Tensor class? i.e.

class TensorPrinter {

    TensorPrinter(Tensor tensor, int maxWidth) { ... }

    TensorPrinter withMaxWidth(int maxWidth) {
        return new TensorPrinter(this.tensor, maxWidth);
    }

    String print() {
        return .... (this logic here)
    }
}

interface Tensor {

    String print() {
        return new TensorPrinter(this, null).print();
    }

    TensorPrinter printer() {
        return new TensorPrinter(this, null);
    }
}

Tensor t = TFloat32.scalarOf(10.0f);
t.print();
t.printer().maxWidth(10).anotherOption(234).print();

Copy link
Author

@cowwoc cowwoc Apr 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer your proposal. It is exactly what I was hoping to implement in the long term but I ended up pushing what you see above as a stepping stone in the right direction.

Also, consider the relationship (if any) between Ops.print() and this functionality. I may be wrong, but I believe the C++ implementation of tf.print() returns roughly what we're trying to implement in this PR. Does it make sense to have Ops.print() invoke this new code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to call it print() and not toString()?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that, then I would then expect it to work the same as Ops.print() since the two share the same name. Are we planning to merge the two?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Ops.print() simply writes a given string to an output stream, like the console, when the graph is being executed.

I picked the "print" expression in this example just to show the logic but it can for sure be named otherwise to avoid the confusion. Maybe TensorStringifier, tensor.stringify() and tensor.stringifier()? Or the methods could remain dataToString() as well.

if (tensor instanceof RawTensor) {
System.out.println("Got rawTensor: " + tensor);
tensor = ((RawTensor) tensor).asTypedTensor();
}
if (!(tensor instanceof NdArray)) {
throw new AssertionError("Expected tensor to extend NdArray");
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you only accept typed tensors for this method, the endpoint should probably be added to TType instead of Tensor.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karllessard Is there a way for me to handle all tensors (not just TType)? It's okay if not. I'm just asking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karllessard How do you create a RawTensor that doesn't inherit from NdArray? Or more precisely how do you create a plain Tensor object that doesn't inherit from NdArray without using one of the org.tensorflow.types classes?

RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1);
Still inherits from NdArray.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't user-side, but all of the Tensor.fromHandle methods don't inherit, afaik. And in places where that's used (Session, somewhere in Operand.asTensor()) if a asTypedTensor() is missed somewhere it's nice to not crash when printing. However, afaik, that's covered by the tensor = ((RawTensor) tensor).asTypedTensor(); line, so the method should support all Tensors not just TType.

Expand All @@ -62,7 +70,7 @@ public static String toString(Tensor tensor, Integer maxWidth) {
private static String toString(Iterator<? extends NdArray<?>> iterator, Shape shape,
int dimension, Integer maxWidth) {
if (dimension < shape.numDimensions() - 1) {
StringJoiner joiner = new StringJoiner(",\n", indent(dimension) + "[\n",
StringJoiner joiner = new StringJoiner("\n", indent(dimension) + "[\n",
"\n" + indent(dimension) + "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = toString(iterator, shape, dimension + 1, maxWidth);
Expand Down Expand Up @@ -92,6 +100,9 @@ private static String toString(Iterator<? extends NdArray<?>> iterator, Shape sh
}

/**
* Truncates the width of a String if it's too long, inserting "{@code ...}" in place of the
* removed data.
*
* @param input the input to truncate
* @param maxWidth the maximum width of the output in characters
* @param lengths the lengths of elements inside input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ public void fromHandle() {
// close() on both Tensors.
final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}});
try (TFloat32 src = TFloat32.tensorOf(matrix)) {
TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor();
TFloat32 cpy = (TFloat32) RawTensor.fromHandle(src.asRawTensor().nativeHandle())
.asTypedTensor();
assertEquals(src.type(), cpy.type());
assertEquals(src.dataType(), cpy.dataType());
assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions());
Expand Down Expand Up @@ -566,6 +567,17 @@ public void dataToString() {
String actual = t.dataToString(Tensor.maxWidth(12));
assertEquals("[3, 0, 1, 2]", actual);
}
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {3, 2, 1}}))) {
String actual = t.dataToString(Tensor.maxWidth(12));
assertEquals("[\n"
+ " [1, 2, 3]\n"
+ " [3, 2, 1]\n"
+ "]", actual);
}
try (RawTensor t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2})).asRawTensor()) {
String actual = t.dataToString(Tensor.maxWidth(12));
assertEquals("[3, 0, 1, 2]", actual);
}
}

// Workaround for cross compiliation
Expand Down