Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
Expand Down
3 changes: 3 additions & 0 deletions keras/api/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
Expand Down
14 changes: 14 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ def list_devices(device_type=None):
return [f"{device.platform}:{device.id}" for device in jax_devices]


def get_device_count(device_type=None):
"""Returns the number of available JAX devices.
Args:
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
If `None`, it defaults to counting "gpu" or "tpu" devices if
available, otherwise it counts "cpu" devices. It does not
return the sum of all device types.
Returns:
int: The total number of JAX devices for the specified type.
"""
device_type = device_type.lower() if device_type else None
return jax.device_count(device_type)


def distribute_variable(value, layout):
"""Create a distributed variable for JAX.

Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/jax/distribution_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@


@pytest.mark.skipif(
backend.backend() != "jax",
reason="Backend specific test",
backend.backend() != "jax" or len(jax.devices()) != 8,
reason="Backend specific test and requires 8 devices",
)
class JaxDistributionLibTest(testing.TestCase):
def _create_jax_layout(self, sharding):
Expand Down
15 changes: 15 additions & 0 deletions keras/src/distribution/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ def list_devices(device_type=None):
return distribution_lib.list_devices(device_type)


@keras_export("keras.distribution.get_device_count")
def get_device_count(device_type=None):
"""Returns the total number of available devices.
Args:
device_type: Optional device type to count (e.g., "cpu",
"gpu", "tpu"). If `None`, it counts all available
devices.
Returns:
int: The total number of available devices.
"""
return distribution_lib.get_device_count(device_type=device_type)


@keras_export("keras.distribution.initialize")
def initialize(job_addresses=None, num_processes=None, process_id=None):
"""Initialize the distribution system for multi-host/process setting.
Expand Down