diff --git a/keras/api/_tf_keras/keras/distribution/__init__.py b/keras/api/_tf_keras/keras/distribution/__init__.py
index 66fed24c761d..25ca527ebb32 100644
--- a/keras/api/_tf_keras/keras/distribution/__init__.py
+++ b/keras/api/_tf_keras/keras/distribution/__init__.py
@@ -15,6 +15,9 @@
     distribute_tensor as distribute_tensor,
 )
 from keras.src.distribution.distribution_lib import distribution as distribution
+from keras.src.distribution.distribution_lib import (
+    get_device_count as get_device_count,
+)
 from keras.src.distribution.distribution_lib import initialize as initialize
 from keras.src.distribution.distribution_lib import list_devices as list_devices
 from keras.src.distribution.distribution_lib import (
diff --git a/keras/api/distribution/__init__.py b/keras/api/distribution/__init__.py
index 66fed24c761d..25ca527ebb32 100644
--- a/keras/api/distribution/__init__.py
+++ b/keras/api/distribution/__init__.py
@@ -15,6 +15,9 @@
     distribute_tensor as distribute_tensor,
 )
 from keras.src.distribution.distribution_lib import distribution as distribution
+from keras.src.distribution.distribution_lib import (
+    get_device_count as get_device_count,
+)
 from keras.src.distribution.distribution_lib import initialize as initialize
 from keras.src.distribution.distribution_lib import list_devices as list_devices
 from keras.src.distribution.distribution_lib import (
diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 6b5bf37314c0..1407c008910e 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -27,6 +27,22 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]
 
 
+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    device_type = device_type.lower() if device_type else None
+    return jax.device_count(device_type)
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.
 
diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py
index 8938c14fc50a..3ee3a2bc91b7 100644
--- a/keras/src/backend/jax/distribution_lib_test.py
+++ b/keras/src/backend/jax/distribution_lib_test.py
@@ -29,8 +29,8 @@
 
 
 @pytest.mark.skipif(
-    backend.backend() != "jax",
-    reason="Backend specific test",
+    backend.backend() != "jax" or len(jax.devices()) != 8,
+    reason="Backend specific test and requires 8 devices",
 )
 class JaxDistributionLibTest(testing.TestCase):
     def _create_jax_layout(self, sharding):
@@ -42,6 +42,10 @@ def _create_jax_layout(self, sharding):
 
         return sharding
 
+    def test_get_device_count(self):
+        self.assertEqual(backend_dlib.get_device_count(), 8)
+        self.assertEqual(backend_dlib.get_device_count("cpu"), 8)
+
     def test_list_devices(self):
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py
index 2daef40a2ed8..abf2b79e2c62 100644
--- a/keras/src/distribution/distribution_lib.py
+++ b/keras/src/distribution/distribution_lib.py
@@ -39,6 +39,22 @@ def list_devices(device_type=None):
     return distribution_lib.list_devices(device_type)
 
 
+@keras_export("keras.distribution.get_device_count")
+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    return distribution_lib.get_device_count(device_type=device_type)
+
+
 @keras_export("keras.distribution.initialize")
 def initialize(job_addresses=None, num_processes=None, process_id=None):
     """Initialize the distribution system for multi-host/process setting.