Skip to content

Commit 73331a6

Browse files
authored
Fix JAX Layout tiling argument for JAX 0.7.1 (#133)
The `jax.experimental.layout.Layout` argument is renamed from `_tiling` to `tiling`.
1 parent d93b28b commit 73331a6

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,14 @@ def _create_sparsecore_distribution(
221221
if jax.__version_info__ >= (0, 6, 3)
222222
else jax_layout.DeviceLocalLayout # type: ignore
223223
)
224+
layout = (
225+
LayoutClass(major_to_minor=(0, 1), tiling=((8,),)) # type: ignore
226+
if jax.__version_info__ >= (0, 7, 1)
227+
else LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)) # type: ignore
228+
)
224229
# pylint: disable-next=protected-access
225230
sparsecore_layout._backend_layout = jax_layout.Format(
226-
LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)), # type: ignore
231+
layout, # type: ignore
227232
jax.sharding.NamedSharding(
228233
device_mesh.backend_mesh,
229234
jax.sharding.PartitionSpec(

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,14 @@ def _create_sparsecore_layout(
4646
if jax.__version_info__ >= (0, 6, 3)
4747
else jax_layout.DeviceLocalLayout # type: ignore
4848
)
49+
layout = (
50+
LayoutClass(major_to_minor=(0, 1), tiling=((8,),)) # type: ignore
51+
if jax.__version_info__ >= (0, 7, 1)
52+
else LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)) # type: ignore
53+
)
4954
# pylint: disable-next=protected-access
5055
sparsecore_layout._backend_layout = jax_layout.Format(
51-
LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)), # type: ignore
56+
layout, # type: ignore
5257
jax.sharding.NamedSharding(
5358
device_mesh.backend_mesh, jax.sharding.PartitionSpec(axes)
5459
),

0 commit comments

Comments
 (0)