
Commit b484720 ("pre-commit")
1 parent: 9d6db54

File tree: 2 files changed (+13, -13 lines)

pytensor/link/mlx/dispatch/blockwise.py
Lines changed: 3 additions & 3 deletions (whitespace-only: blank docstring lines lose their trailing whitespace)

@@ -7,11 +7,11 @@
 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, **kwargs):
     """Convert a Blockwise operation to an MLX function.
-
+
     This handles vectorized operations by using mx.vmap when there are batch
     dimensions, or returning the core function directly when there are no
     batch dimensions to vectorize over.
-
+
     Parameters
     ----------
     op : Blockwise
@@ -20,7 +20,7 @@ def funcify_Blockwise(op: Blockwise, node, **kwargs):
         The node containing the operation and its inputs.
     **kwargs
         Additional keyword arguments.
-
+
     Returns
     -------
     callable
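
For context, the docstring above describes the dispatch rule this module implements: wrap the core function in mx.vmap only when there are real batch dimensions to map over, and otherwise return the core function directly. A minimal sketch of that rule, assuming a core callable and a precomputed batch_ndim; the helper name make_blockwise_fn and its signature are illustrative, not the actual pytensor dispatch code:

import mlx.core as mx

def make_blockwise_fn(core_fn, batch_ndim):
    # Hypothetical helper mirroring the rule in the docstring above.
    # With no batch dimensions there is nothing to vectorize over,
    # so return the core function unchanged instead of calling mx.vmap.
    if batch_ndim == 0:
        return core_fn
    # Otherwise apply mx.vmap once per batch dimension; each wrap maps the
    # function over the leading axis of every input (in_axes defaults to 0).
    fn = core_fn
    for _ in range(batch_ndim):
        fn = mx.vmap(fn)
    return fn

When batch_ndim is zero this returns the core function itself rather than a vmap wrapper, which is the situation the two tests below exercise.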

tests/link/mlx/test_blockwise.py
Lines changed: 10 additions & 10 deletions (whitespace-only: blank lines lose their trailing whitespace)

@@ -29,41 +29,41 @@ def test_blockwise_conv1d():

 def test_blockwise_no_batch_dimensions():
     """Test that Blockwise returns the core function when there are no batch dimensions.
-
+
     This verifies the fix for the vmap dispatcher issue where mx.vmap should not
     be called when there are no batch dimensions to vectorize over.
     """
     rng = np.random.default_rng(42)
-
+
     # Create a blockwise matmul with no batch dimensions (core operation only)
     x = pt.matrix("x")
     y = pt.matrix("y")
-
+
     blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
     z = blockwise_matmul(x, y)
-
+
     x_test = rng.normal(size=(2, 3))
     y_test = rng.normal(size=(3, 4))
-
+
     compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)


 def test_blockwise_all_broadcastable_batch_dims():
     """Test that Blockwise returns the core function when all batch dims are broadcastable.
-
+
     When all batch dimensions are size-1 (broadcastable), vmap should not be called
     since there's no actual vectorization needed.
     """
     rng = np.random.default_rng(43)
-
+
     # Create inputs with size-1 batch dimensions
     x = tensor("x", shape=(1, 2, 3))
     y = tensor("y", shape=(1, 3, 4))
-
+
     blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
     z = blockwise_matmul(x, y)
-
+
     x_test = rng.normal(size=(1, 2, 3))
     y_test = rng.normal(size=(1, 3, 4))
-
+
     compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
