Skip to content

Commit fcb9b85

Browse files
authored
Merge branch 'master' into fix-mask-parameter-issue-21154
2 parents 99b8b64 + 6d06085 commit fcb9b85

File tree

30 files changed

+447
-83
lines changed

30 files changed

+447
-83
lines changed

README.md

-4.99 KB
Loading

keras/api/_tf_keras/keras/distribution/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
distribute_tensor as distribute_tensor,
1616
)
1717
from keras.src.distribution.distribution_lib import distribution as distribution
18+
from keras.src.distribution.distribution_lib import (
19+
get_device_count as get_device_count,
20+
)
1821
from keras.src.distribution.distribution_lib import initialize as initialize
1922
from keras.src.distribution.distribution_lib import list_devices as list_devices
2023
from keras.src.distribution.distribution_lib import (

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
from keras.src.ops.numpy import argpartition as argpartition
141141
from keras.src.ops.numpy import argsort as argsort
142142
from keras.src.ops.numpy import array as array
143+
from keras.src.ops.numpy import array_split as array_split
143144
from keras.src.ops.numpy import average as average
144145
from keras.src.ops.numpy import bartlett as bartlett
145146
from keras.src.ops.numpy import bincount as bincount

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras.src.ops.numpy import argpartition as argpartition
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
29+
from keras.src.ops.numpy import array_split as array_split
2930
from keras.src.ops.numpy import average as average
3031
from keras.src.ops.numpy import bartlett as bartlett
3132
from keras.src.ops.numpy import bincount as bincount

keras/api/distribution/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
distribute_tensor as distribute_tensor,
1616
)
1717
from keras.src.distribution.distribution_lib import distribution as distribution
18+
from keras.src.distribution.distribution_lib import (
19+
get_device_count as get_device_count,
20+
)
1821
from keras.src.distribution.distribution_lib import initialize as initialize
1922
from keras.src.distribution.distribution_lib import list_devices as list_devices
2023
from keras.src.distribution.distribution_lib import (

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
from keras.src.ops.numpy import argpartition as argpartition
141141
from keras.src.ops.numpy import argsort as argsort
142142
from keras.src.ops.numpy import array as array
143+
from keras.src.ops.numpy import array_split as array_split
143144
from keras.src.ops.numpy import average as average
144145
from keras.src.ops.numpy import bartlett as bartlett
145146
from keras.src.ops.numpy import bincount as bincount

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras.src.ops.numpy import argpartition as argpartition
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
29+
from keras.src.ops.numpy import array_split as array_split
2930
from keras.src.ops.numpy import average as average
3031
from keras.src.ops.numpy import bartlett as bartlett
3132
from keras.src.ops.numpy import bincount as bincount

keras/src/backend/common/variables.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def assign(self, value):
282282
"The shape of the target variable and "
283283
"the shape of the target value in "
284284
"`variable.assign(value)` must match. "
285-
f"variable.shape={self.value.shape}, "
285+
f"variable.shape={self.shape}, "
286286
f"Received: value.shape={value.shape}. "
287287
f"Target variable: {self}"
288288
)
@@ -399,7 +399,11 @@ def constraint(self, value):
399399
def __repr__(self):
400400
value = None
401401
if hasattr(self, "_value") and self._value is not None:
402-
value = backend.core.convert_to_numpy(self._value)
402+
try:
403+
value = backend.core.convert_to_numpy(self._value)
404+
except:
405+
# In some cases the conversion to numpy can fail.
406+
pass
403407
value_str = f", value={value}" if value is not None else ""
404408
return (
405409
f"<Variable path={self.path}, shape={self.shape}, "

keras/src/backend/jax/distribution_lib.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def list_devices(device_type=None):
2727
return [f"{device.platform}:{device.id}" for device in jax_devices]
2828

2929

30+
def get_device_count(device_type=None):
31+
"""Returns the number of available JAX devices.
32+
Args:
33+
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
34+
If `None`, it defaults to counting "gpu" or "tpu" devices if
35+
available, otherwise it counts "cpu" devices. It does not
36+
return the sum of all device types.
37+
Returns:
38+
int: The total number of JAX devices for the specified type.
39+
"""
40+
device_type = device_type.lower() if device_type else None
41+
return jax.device_count(device_type)
42+
43+
3044
def distribute_variable(value, layout):
3145
"""Create a distributed variable for JAX.
3246

keras/src/backend/jax/distribution_lib_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030

3131
@pytest.mark.skipif(
32-
backend.backend() != "jax",
33-
reason="Backend specific test",
32+
backend.backend() != "jax" or len(jax.devices()) != 8,
33+
reason="Backend specific test and requires 8 devices",
3434
)
3535
class JaxDistributionLibTest(testing.TestCase):
3636
def _create_jax_layout(self, sharding):
@@ -42,6 +42,10 @@ def _create_jax_layout(self, sharding):
4242

4343
return sharding
4444

45+
def test_get_device_count(self):
46+
self.assertEqual(backend_dlib.get_device_count(), 8)
47+
self.assertEqual(backend_dlib.get_device_count("cpu"), 8)
48+
4549
def test_list_devices(self):
4650
self.assertEqual(len(distribution_lib.list_devices()), 8)
4751
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)

0 commit comments

Comments
 (0)