@@ -29,41 +29,41 @@ def test_blockwise_conv1d():
 
 def test_blockwise_no_batch_dimensions():
     """Test that Blockwise returns the core function when there are no batch dimensions.
-    
+
     This verifies the fix for the vmap dispatcher issue where mx.vmap should not
     be called when there are no batch dimensions to vectorize over.
     """
     rng = np.random.default_rng(42)
-    
+
     # Create a blockwise matmul with no batch dimensions (core operation only)
     x = pt.matrix("x")
     y = pt.matrix("y")
-    
+
     blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
     z = blockwise_matmul(x, y)
-    
+
     x_test = rng.normal(size=(2, 3))
     y_test = rng.normal(size=(3, 4))
-    
+
     compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
 
 
 def test_blockwise_all_broadcastable_batch_dims():
     """Test that Blockwise returns the core function when all batch dims are broadcastable.
-    
+
     When all batch dimensions are size-1 (broadcastable), vmap should not be called
     since there's no actual vectorization needed.
     """
     rng = np.random.default_rng(43)
-    
+
     # Create inputs with size-1 batch dimensions
     x = tensor("x", shape=(1, 2, 3))
     y = tensor("y", shape=(1, 3, 4))
-    
+
     blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
     z = blockwise_matmul(x, y)
-    
+
     x_test = rng.normal(size=(1, 2, 3))
     y_test = rng.normal(size=(1, 3, 4))
-    
+
     compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
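
For reviewers skimming the diff, here is a minimal sketch of the dispatch rule these two tests pin down. It is an illustration under stated assumptions, not the actual PyTensor dispatcher: `blockwise_to_mlx`, `core_fn`, `batch_ndim`, and `all_batch_dims_broadcastable` are hypothetical names, and it assumes a single-output core op whose inputs all carry the same leading batch axes.

```python
import mlx.core as mx


def blockwise_to_mlx(core_fn, batch_ndim, all_batch_dims_broadcastable):
    """Hypothetical sketch: decide whether mx.vmap is needed for a Blockwise op."""
    if batch_ndim == 0:
        # First test above: nothing to vectorize over, so mx.vmap must not
        # be called and the core function is returned unchanged.
        return core_fn
    if all_batch_dims_broadcastable:
        # Second test above: every batch dimension is size-1, so the core op
        # runs once; the size-1 axes are stripped from the inputs and
        # restored on the output instead of mapping with mx.vmap.
        def fn(*args):
            out = core_fn(*(mx.reshape(a, a.shape[batch_ndim:]) for a in args))
            return mx.reshape(out, [1] * batch_ndim + list(out.shape))

        return fn
    # General case: wrap the core function in mx.vmap once per batch dimension.
    fn = core_fn
    for _ in range(batch_ndim):
        fn = mx.vmap(fn)
    return fn


# Usage mirroring the first test: a plain (2, 3) @ (3, 4) matmul runs through
# the core function directly, with no mx.vmap wrapper involved.
dot = blockwise_to_mlx(mx.matmul, batch_ndim=0, all_batch_dims_broadcastable=False)
z = dot(mx.random.normal((2, 3)), mx.random.normal((3, 4)))  # shape (2, 4)
```

In both skip branches the invariant is the same: mx.vmap is only introduced when there is a real axis to map over.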