Skip to content

Commit 3973b15

Browse files
authored
Support scalar view for tf backend. (#21802)
* Support scalar view for tf backend. * update last dim
1 parent 72dfa36 commit 3973b15

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

keras/src/backend/tensorflow/numpy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,11 @@ def view(x, dtype=None):
10101010
old_itemsize = old_dtype.size
10111011
new_itemsize = new_dtype.size
10121012

1013-
if list(x.shape)[-1] * old_itemsize % new_itemsize != 0:
1013+
old_shape = list(shape_op(x))
1014+
last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
1015+
if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
1016+
last_dim_size * old_itemsize % new_itemsize != 0
1017+
):
10141018
raise ValueError(
10151019
f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
10161020
f"as dtype {new_dtype} because the total number of bytes "
@@ -1027,8 +1031,6 @@ def view(x, dtype=None):
10271031
cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
10281032
return tf.reshape(cast_tensor, new_shape)
10291033
else:
1030-
old_shape = list(shape_op(x))
1031-
last_dim_size = old_shape[-1]
10321034
ratio = new_itemsize // old_itemsize
10331035
if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
10341036
raise ValueError(

keras/src/ops/numpy_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4136,6 +4136,16 @@ def test_concatenate(self):
41364136
)
41374137

41384138
def test_view(self):
4139+
x = np.array(1, dtype="int16")
4140+
result = knp.view(x, dtype="float16")
4141+
assert backend.standardize_dtype(result.dtype) == "float16"
4142+
4143+
with self.assertRaises(Exception):
4144+
result = knp.view(x, dtype="int8")
4145+
4146+
with self.assertRaises(Exception):
4147+
result = knp.view(x, dtype="int32")
4148+
41394149
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype="int16")
41404150
result = knp.view(x, dtype="int16")
41414151
assert backend.standardize_dtype(result.dtype) == "int16"

0 commit comments

Comments
 (0)