Skip to content

Commit 0b86851

Browse files
float32 in tests
1 parent ed60651 commit 0b86851

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tests/tensor/test_extra_ops.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,10 @@ def test_pack_basic(self):
14741474
y = pt.tensor("y", shape=(5,))
14751475
z = pt.tensor("z", shape=(3, 3))
14761476

1477-
input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]}
1477+
input_dict = {
1478+
variable: np.zeros(variable.type.shape, dtype=config.floatX)
1479+
for variable in [x, y, z]
1480+
}
14781481

14791482
# Simple case, reduce all axes, equivalent to einops '*'
14801483
packed_tensor, packed_shapes = pack(x, y, z, axes=None)
@@ -1504,7 +1507,10 @@ def test_pack_basic(self):
15041507
y = pt.tensor("y", shape=(3, 5))
15051508
z = pt.tensor("z", shape=(3, 3, 3))
15061509
packed_tensor, packed_shapes = pack(x, y, z, axes=0)
1507-
input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]}
1510+
input_dict = {
1511+
variable: np.zeros(variable.type.shape, dtype=config.floatX)
1512+
for variable in [x, y, z]
1513+
}
15081514
assert packed_tensor.type.shape == (3, 15)
15091515
for tensor, packed_shape in zip([x, y, z], packed_shapes):
15101516
assert packed_shape.type.shape == (tensor.ndim - 1,)
@@ -1526,7 +1532,10 @@ def test_pack_basic(self):
15261532

15271533
z = pt.tensor("z", shape=(3, 1, 7, 2))
15281534
packed_tensor, packed_shapes = pack(x, y, z, axes=[0, 3])
1529-
input_dict = {variable: np.zeros(variable.type.shape) for variable in [x, y, z]}
1535+
input_dict = {
1536+
variable: np.zeros(variable.type.shape, dtype=config.floatX)
1537+
for variable in [x, y, z]
1538+
}
15301539
assert packed_tensor.type.shape == (3, 13, 2)
15311540
for tensor, packed_shape in zip([x, y, z], packed_shapes):
15321541
assert packed_shape.type.shape == (tensor.ndim - 2,)
@@ -1546,7 +1555,9 @@ def test_pack_unpack_round_trip(self):
15461555

15471556
fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE")
15481557

1549-
input_vals = [rng.normal(size=var.type.shape) for var in [x, y, z]]
1558+
input_vals = [
1559+
rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z]
1560+
]
15501561
output_vals = fn(*input_vals)
15511562

15521563
for input_val, output_val in zip(input_vals, output_vals, strict=True):

0 commit comments

Comments
 (0)