Skip to content

Commit 9d6db54

Browse files
committed
Fix Blockwise vmap dispatch for no batch dimensions
Updates comments in funcify_Blockwise to avoid confusion about behavior. Adds tests to verify correct behavior for these cases.
1 parent f83c05b commit 9d6db54

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,58 @@
66

77
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Convert a Blockwise operation to an MLX function.

    When the node carries batch dimensions, the core function is wrapped in
    ``mx.vmap``; otherwise (or when every batch dimension is broadcastable,
    i.e. size 1) the core function is returned unwrapped.

    Parameters
    ----------
    op : Blockwise
        The Blockwise operation to convert.
    node : Apply
        The node containing the operation and its inputs.
    **kwargs
        Additional keyword arguments.

    Returns
    -------
    callable
        An MLX function implementing the Blockwise operation: the core
        function itself, or the core function vectorized with mx.vmap.
    """
    # Core (unbatched) implementation of the wrapped op.
    dummy_core_node = op._create_dummy_core_node(node.inputs)
    core_fn = mlx_funcify(op.core_op, dummy_core_node)

    # No batch dimensions at all -> nothing to vectorize over.
    if op.batch_ndim(node) == 0:
        return core_fn

    def _vmap_axis(inp, sig):
        """Return 0 to vmap over an input's leading axis, or None to keep it static."""
        extra_ndim = inp.type.ndim - len(sig)
        if extra_ndim == 0:
            # Input has no batch dimensions of its own.
            return None
        # If every batch dim is broadcastable (size 1), treat the input as
        # static; otherwise vectorize over the first (leading) dimension.
        leading_bcast = inp.type.broadcastable[:extra_ndim]
        return None if all(leading_bcast) else 0

    vmap_axes = tuple(
        _vmap_axis(inp, sig) for inp, sig in zip(node.inputs, op.inputs_sig)
    )

    # mx.vmap rejects an all-None in_axes ("ValueError: At least one of
    # in_axes must be non-None"), so when nothing actually needs
    # vectorization fall back to the core function directly.
    if not any(axis == 0 for axis in vmap_axes):
        return core_fn

    # Vectorize the core function over the batched inputs.
    return mx.vmap(core_fn, in_axes=vmap_axes)

tests/link/mlx/test_blockwise.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,45 @@ def test_blockwise_conv1d():
2525

2626
# assert isinstance(out.owner.op, Blockwise)
2727
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)
28+
29+
30+
def test_blockwise_no_batch_dimensions():
    """Blockwise with no batch dimensions dispatches to the core function.

    Exercises the vmap-dispatcher fix: mx.vmap must not be invoked when
    there is nothing to vectorize over.
    """
    rng = np.random.default_rng(42)

    # Plain matrices: the Blockwise has only core dimensions, no batch dims.
    x = pt.matrix("x")
    y = pt.matrix("y")

    batched_dot = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
    out = batched_dot(x, y)

    test_values = [
        rng.normal(size=(2, 3)),
        rng.normal(size=(3, 4)),
    ]

    compare_mlx_and_py([x, y], [out], test_values, must_be_device_array=True)
49+
50+
51+
def test_blockwise_all_broadcastable_batch_dims():
52+
"""Test that Blockwise returns the core function when all batch dims are broadcastable.
53+
54+
When all batch dimensions are size-1 (broadcastable), vmap should not be called
55+
since there's no actual vectorization needed.
56+
"""
57+
rng = np.random.default_rng(43)
58+
59+
# Create inputs with size-1 batch dimensions
60+
x = tensor("x", shape=(1, 2, 3))
61+
y = tensor("y", shape=(1, 3, 4))
62+
63+
blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
64+
z = blockwise_matmul(x, y)
65+
66+
x_test = rng.normal(size=(1, 2, 3))
67+
y_test = rng.normal(size=(1, 3, 4))
68+
69+
compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)

0 commit comments

Comments
 (0)