@@ -1474,7 +1474,10 @@ def test_pack_basic(self):
14741474 y = pt .tensor ("y" , shape = (5 ,))
14751475 z = pt .tensor ("z" , shape = (3 , 3 ))
14761476
1477- input_dict = {variable : np .zeros (variable .type .shape ) for variable in [x , y , z ]}
1477+ input_dict = {
1478+ variable : np .zeros (variable .type .shape , dtype = config .floatX )
1479+ for variable in [x , y , z ]
1480+ }
14781481
14791482 # Simple case, reduce all axes, equivalent to einops '*'
14801483 packed_tensor , packed_shapes = pack (x , y , z , axes = None )
@@ -1504,7 +1507,10 @@ def test_pack_basic(self):
15041507 y = pt .tensor ("y" , shape = (3 , 5 ))
15051508 z = pt .tensor ("z" , shape = (3 , 3 , 3 ))
15061509 packed_tensor , packed_shapes = pack (x , y , z , axes = 0 )
1507- input_dict = {variable : np .zeros (variable .type .shape ) for variable in [x , y , z ]}
1510+ input_dict = {
1511+ variable : np .zeros (variable .type .shape , dtype = config .floatX )
1512+ for variable in [x , y , z ]
1513+ }
15081514 assert packed_tensor .type .shape == (3 , 15 )
15091515 for tensor , packed_shape in zip ([x , y , z ], packed_shapes ):
15101516 assert packed_shape .type .shape == (tensor .ndim - 1 ,)
@@ -1526,7 +1532,10 @@ def test_pack_basic(self):
15261532
15271533 z = pt .tensor ("z" , shape = (3 , 1 , 7 , 2 ))
15281534 packed_tensor , packed_shapes = pack (x , y , z , axes = [0 , 3 ])
1529- input_dict = {variable : np .zeros (variable .type .shape ) for variable in [x , y , z ]}
1535+ input_dict = {
1536+ variable : np .zeros (variable .type .shape , dtype = config .floatX )
1537+ for variable in [x , y , z ]
1538+ }
15301539 assert packed_tensor .type .shape == (3 , 13 , 2 )
15311540 for tensor , packed_shape in zip ([x , y , z ], packed_shapes ):
15321541 assert packed_shape .type .shape == (tensor .ndim - 2 ,)
@@ -1546,7 +1555,9 @@ def test_pack_unpack_round_trip(self):
15461555
15471556 fn = pytensor .function ([x , y , z ], new_outputs , mode = "FAST_COMPILE" )
15481557
1549- input_vals = [rng .normal (size = var .type .shape ) for var in [x , y , z ]]
1558+ input_vals = [
1559+ rng .normal (size = var .type .shape ).astype (config .floatX ) for var in [x , y , z ]
1560+ ]
15501561 output_vals = fn (* input_vals )
15511562
15521563 for input_val , output_val in zip (input_vals , output_vals , strict = True ):
0 commit comments