Skip to content

Commit 1eb414f

Browse files
committed
Remove custom mlx_linalg_mode from test files
Eliminated the definition and usage of the custom mlx_linalg_mode in test_nlinalg.py and test_slinalg.py, replacing it with the default mlx_mode. This simplifies the test setup and removes unnecessary configuration.
1 parent 364261b commit 1eb414f

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

tests/link/mlx/test_nlinalg.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
99

1010

11-
mlx_linalg_mode = mlx_mode.including("blockwise")
12-
13-
1411
@pytest.mark.parametrize("compute_uv", [True, False])
1512
def test_mlx_svd(compute_uv):
1613
rng = np.random.default_rng(15)
@@ -25,7 +22,7 @@ def test_mlx_svd(compute_uv):
2522
[A],
2623
out,
2724
[A_val],
28-
mlx_mode=mlx_linalg_mode,
25+
mlx_mode=mlx_mode,
2926
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
3027
)
3128

@@ -42,7 +39,7 @@ def test_mlx_kron():
4239
[A, B],
4340
[out],
4441
[A_val, B_val],
45-
mlx_mode=mlx_linalg_mode,
42+
mlx_mode=mlx_mode,
4643
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
4744
)
4845

@@ -62,7 +59,7 @@ def test_mlx_inv(op):
6259
[A],
6360
[out],
6461
[A_val],
65-
mlx_mode=mlx_linalg_mode,
62+
mlx_mode=mlx_mode,
6663
assert_fn=partial(
6764
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
6865
),

tests/link/mlx/test_slinalg.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
1010

1111

12-
# mlx complains about useless vmap (when there are no batch dims), so we need to include
13-
# local_remove_useless_blockwise rewrite for these tests
14-
mlx_linalg_mode = mlx_mode.including("blockwise")
15-
16-
1712
@pytest.mark.parametrize("lower", [True, False])
1813
def test_mlx_cholesky(lower):
1914
rng = np.random.default_rng(15)
@@ -29,7 +24,7 @@ def test_mlx_cholesky(lower):
2924
[A],
3025
[out],
3126
[A_val],
32-
mlx_mode=mlx_linalg_mode,
27+
mlx_mode=mlx_mode,
3328
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
3429
)
3530

@@ -62,7 +57,7 @@ def test_mlx_solve(assume_a):
6257
[A, b],
6358
[out],
6459
[A_val, b_val],
65-
mlx_mode=mlx_linalg_mode,
60+
mlx_mode=mlx_mode,
6661
assert_fn=partial(
6762
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
6863
),
@@ -90,7 +85,7 @@ def test_mlx_SolveTriangular(lower, trans):
9085
[A, b],
9186
[out],
9287
[A_val, b_val],
93-
mlx_mode=mlx_linalg_mode,
88+
mlx_mode=mlx_mode,
9489
assert_fn=partial(
9590
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
9691
),
@@ -109,6 +104,6 @@ def test_mlx_LU():
109104
[A],
110105
out,
111106
[A_val],
112-
mlx_mode=mlx_linalg_mode,
107+
mlx_mode=mlx_mode,
113108
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
114109
)

0 commit comments

Comments
 (0)