Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.

Kotlin friendly names #1

Merged
merged 6 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.iml
.idea
target
147 changes: 87 additions & 60 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@

package org.tensorflow.ndarray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* The shape of a Tensor or {@link NdArray}.
*
* <p>A {@code Shape} defines sizes along its axes. It may contain an unknown size for one of the
* axes or may be totally unknown, in which case not even the number of axes is known. If the size
* of an axis is unknown, {@link Shape#UNKNOWN_SIZE} should be used as its size.
* axes or may be totally unknown, in which case not even the number of axes is known. If the size of an axis is
* unknown, {@link Shape#UNKNOWN_SIZE} should be used as its size.
*/
public final class Shape {

/** The size of an unknown axis or the total unknown size for an unknown Shape. */
/**
* The size of an unknown axis or the total unknown size for an unknown Shape.
*/
public static long UNKNOWN_SIZE = -1L;

/**
Expand All @@ -53,9 +57,8 @@ public static Shape scalar() {
* Create a Shape representing a scalar or an N-dimensional value.
*
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
* the provided size for each dimension. A -1 indicates that the size of the corresponding
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
* example:
* the provided size for each dimension. A -1 indicates that the size of the corresponding dimension is unknown. If no
* sizes are provided, a Shape representing a scalar is created. For example:
*
* <pre>{@code
* // A 2-element vector.
Expand All @@ -74,8 +77,8 @@ public static Shape scalar() {
* Shape scalar = Shape.of()
* }</pre>
*
* @param dimensionSizes number of elements in each dimension of this shape, if any, or
* {@link Shape#UNKNOWN_SIZE} if unknown.
* @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link Shape#UNKNOWN_SIZE} if
* unknown.
* @return a new shape
*/
public static Shape of(long... dimensionSizes) {
Expand All @@ -91,8 +94,8 @@ public static Shape of(long... dimensionSizes) {
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
* Shape#UNKNOWN_SIZE} is returned.
*
* @return The total number of elements a Tensor with this shape would have if it can be
* calculated, else {@link Shape#UNKNOWN_SIZE}.
* @return The total number of elements a Tensor with this shape would have if it can be calculated, else {@link
* Shape#UNKNOWN_SIZE}.
*/
public long size() {
if (size == null) {
Expand All @@ -107,14 +110,13 @@ public long size() {
* <p>If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
*
* @param i the index of the dimension to get the size for. If this Shape has a known number of
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in which
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
* @param i the index of the dimension to get the size for. If this Shape has a known number of dimensions, it must be
* &lt; {@link Shape#numDimensions()}. The index may be negative, in which case the position is counted from the end
* of the shape. E.g.: {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of the
* second to last dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE} otherwise.
*/
public long size(int i) {
public long get(int i) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
} else if (i >= 0) {
Expand All @@ -125,14 +127,15 @@ public long size(int i) {
}

/**
* Returns the number of dimensions of this Shape. -1 if unknown, 0 for a scalar, 1 for a vector,
* 2 for a matrix etc.
* Returns the number of dimensions of this Shape. -1 if unknown, 0 for a scalar, 1 for a vector, 2 for a matrix etc.
*/
public int numDimensions() {
return dimensionSizes != null ? dimensionSizes.length : -1;
}

/** Returns whether one or more dimensions of this Shape have an unknown size. */
/**
* Returns whether one or more dimensions of this Shape have an unknown size.
*/
public boolean hasUnknownDimension() {
if (dimensionSizes == null) {
return true;
Expand All @@ -145,29 +148,37 @@ public boolean hasUnknownDimension() {
return false;
}

/** Returns whether this Shape represents a scalar. */
/**
* Returns whether this Shape represents a scalar.
*/
public boolean isScalar() {
return dimensionSizes != null && dimensionSizes.length == 0;
}

/** Returns whether this Shape is the shape of a vector. */
/**
* Returns whether this Shape is the shape of a vector.
*/
public boolean isVector() {
return dimensionSizes != null && dimensionSizes.length == 1;
}

/** Returns whether this Shape is the shape of a matrix */
/**
* Returns whether this Shape is the shape of a matrix
*/
public boolean isMatrix() {
return dimensionSizes != null && dimensionSizes.length == 2;
}

/** Returns whether the number of dimensions of this Shape is unknown. */
/**
* Returns whether the number of dimensions of this Shape is unknown.
*/
public boolean isUnknown() {
return dimensionSizes == null;
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change this Shape's state.
* Returns null if {@link Shape#isUnknown()} is true.
*/
public long[] asArray() {
if (this.dimensionSizes == null) {
Expand All @@ -177,6 +188,24 @@ public long[] asArray() {
}
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change this Shape's state.
* Returns null if {@link Shape#isUnknown()} is true.
*/
public List<Long> toListOrNull() {
long[] array = asArray();
if (array == null) {
return null;
}

List<Long> list = new ArrayList<>(array.length);
for (long l : array) {
list.add(l);
}

return list;
}

@Override
public int hashCode() {
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
Expand Down Expand Up @@ -212,7 +241,9 @@ public boolean equals(Object obj) {
return false;
}

/** Succinct description of the Shape meant for debugging. */
/**
* Succinct description of the Shape meant for debugging.
*/
@Override
public String toString() {
return Arrays.toString(dimensionSizes);
Expand All @@ -233,12 +264,10 @@ public Shape head() {
}

/**
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
* shape
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this shape
*
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
* this Shape
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of this Shape
*/
public Shape take(int n) {
if (n > numDimensions()) {
Expand All @@ -250,20 +279,21 @@ public Shape take(int n) {
return Shape.of(newDimensions);
}

/** Returns a new Shape, with this Shape's first dimension removed. */
/**
* Returns a new Shape, with this Shape's first dimension removed.
*/
public Shape tail() {
if (dimensionSizes.length < 2) return Shape.of();
if (dimensionSizes.length < 2) {
return Shape.of();
}
return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length));
}

/**
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
* Shape.
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this Shape.
*
* @param n the number of trailing dimensions to get, must be &lt;= than {@link
* Shape#numDimensions()}
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this
* Shape, never null
* @param n the number of trailing dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this Shape, never null
*/
public Shape takeLast(int n) {
if (n > numDimensions()) {
Expand All @@ -276,12 +306,14 @@ public Shape takeLast(int n) {
}

/**
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code
* end}.
*
* @param begin Where to start the sub-shape.
* @param end Where to end the sub-shape, exclusive.
* @return the sub-shape bounded by begin and end.
*/
public Shape subShape(int begin, int end){
public Shape subShape(int begin, int end) {
if (end > numDimensions()) {
throw new ArrayIndexOutOfBoundsException(
"End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
Expand All @@ -297,12 +329,11 @@ public Shape subShape(int begin, int end){
}

/**
* Returns a new Shape, with a new first dimension added. In order for this call to succeed,
* {@link Shape#isUnknown()} must be {@code false}.
* Returns a new Shape, with a new first dimension added. In order for this call to succeed, {@link Shape#isUnknown()}
* must be {@code false}.
*
* @param firstDimension the dimension to prepend
* @return a new shape with the given dimension first, followed by this Shape's dimensions, never
* null
* @return a new shape with the given dimension first, followed by this Shape's dimensions, never null
*/
public Shape prepend(long firstDimension) {
long[] newDimensions = new long[dimensionSizes.length + 1];
Expand All @@ -313,8 +344,8 @@ public Shape prepend(long firstDimension) {
}

/**
* Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
* Shape#isUnknown()} must be {@code false}.
* Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link Shape#isUnknown()}
* must be {@code false}.
*
* @param lastDimension the dimension to append
* @return a new Shape with this Shape's dimensions followed by the given dimension, never null
Expand All @@ -328,13 +359,11 @@ public Shape append(long lastDimension) {
}

/**
* Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
* other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
* Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
* Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the other Shape, {@link
* Shape#isUnknown()} must return false. E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
*
* @param other another Shape, must not be {@code null}, must not be unknown
* @return A new Shape consisting of the given Shape's dimensions followed by this Shape's
* dimensions, never null
* @return A new Shape consisting of the given Shape's dimensions followed by this Shape's dimensions, never null
*/
public Shape prepend(Shape other) {
long[] newDimensions = new long[other.dimensionSizes.length + dimensionSizes.length];
Expand All @@ -345,13 +374,11 @@ public Shape prepend(Shape other) {
}

/**
* Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
* other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
* Shape.of(3,4).append(Shape.of(1,2)) =&gt; Shape.of(3,4,1,2) }
* Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the other Shape, {@link
* Shape#isUnknown()} must return false. E.g. @code Shape.of(3,4).append(Shape.of(1,2)) =&gt; Shape.of(3,4,1,2) }
*
* @param other another Shape, must not be {@code null}, must not be unknown
* @return A new Shape consisting of this Shape's dimensions followed by the given Shape's
* dimensions
* @return A new Shape consisting of this Shape's dimensions followed by the given Shape's dimensions
*/
public Shape append(Shape other) {
long[] newDimensions = new long[dimensionSizes.length + other.dimensionSizes.length];
Expand Down Expand Up @@ -381,8 +408,8 @@ private static long computeSize(long[] dimensionSizes) {
* <p>
*
* <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason
* about partially-defined shapes. For example:
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason about
* partially-defined shapes. For example:
*
* <ul>
* <li><code>Shape.unknown()</code> is compatible with all shapes.
Expand Down Expand Up @@ -423,7 +450,7 @@ public boolean isCompatibleWith(Shape shape) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
if (!isCompatible(size(i), shape.size(i))) {
if (!isCompatible(get(i), shape.get(i))) {
return false;
}
}
Expand Down
Loading