
Conversation

Contributor

@buildwithsuhana buildwithsuhana commented Sep 26, 2025

This PR introduces keras.ops.array_split, a new operation that serves as a fundamental building block for tensor parallelism in Keras.

While keras.ops.split already exists, it requires the tensor's dimension to be evenly divisible by the number of splits. array_split (mirroring np.array_split) removes this restriction, allowing for uneven splits.

This capability is crucial for tensor parallelism, where we often need to shard large weight tensors (e.g., Dense kernels or Embedding layers) across a number of devices, even when the tensor's dimension isn't perfectly divisible by the device count. For example, splitting a tensor of size 10 into 3 parts will result in sub-tensors of size [4, 3, 3].
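
To illustrate the difference, here is a minimal usage sketch. It assumes the new op mirrors the np.array_split signature (array_split(x, indices_or_sections, axis=0)), as the PR description states; the exact call shown is illustrative rather than taken from the PR's tests.

```python
import numpy as np
from keras import ops

x = ops.arange(10)

# ops.split requires an even division and would error out for 3 sections here.
# ops.array_split relaxes that restriction, mirroring np.array_split.
parts = ops.array_split(x, 3)  # assumed signature: (x, indices_or_sections, axis=0)
print([int(p.shape[0]) for p in parts])  # [4, 3, 3]

# Reference behavior from NumPy, which array_split is described as mirroring:
print([len(p) for p in np.array_split(np.arange(10), 3)])  # [4, 3, 3]
```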

@gemini-code-assist
Contributor

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.

Highlights

  • Core Distributed Backend Abstraction: Introduced BaseDistributedBackend as an abstract interface for distributed operations and a get_distributed_backend factory function to provide a unified, backend-agnostic way to interact with JAX, TensorFlow, PyTorch, and NumPy distributed environments (a hypothetical sketch of such an interface follows this list).
  • High-Level Communication Primitives: Defined AllReduceKeras, AllGatherKeras, BroadcastKeras, and ScatterKeras classes, which serve as high-level wrappers for essential collective communication operations required for tensor parallelism.
  • Tensor Sharding Actions: Implemented StateActionKeras as an abstract base class for defining how tensors are transformed for distribution. Concrete implementations like SplitKeras handle tensor sharding, while GatherKeras and SumKeras define how to reconstruct original tensors from their distributed parts.
  • Sharding Plan Configuration: Introduced the ConfigKeras dataclass to store and manage model-wide sharding rules and output configurations, including a mechanism to dynamically create collective operations based on these rules.
  • Tensor Parallel Communicator: Added TensorParallelCommunicator to orchestrate complex communication patterns for tensor parallelism, including specific methods for handling forward and backward passes in column-parallel and row-parallel operations, along with gradient slicing logic.
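
The actual class and method signatures are not shown in this thread, so the following is a hypothetical Python sketch only: the method names (all_reduce, all_gather), their parameters, and the factory's keying on a backend name are assumptions, not the PR's real API.

```python
# Hypothetical sketch only: method names and signatures are assumptions,
# not the PR's actual API.
from abc import ABC, abstractmethod


class BaseDistributedBackend(ABC):
    """Backend-agnostic interface for collective communication."""

    @abstractmethod
    def all_reduce(self, tensor, op="sum"):
        """Reduce `tensor` across all devices and return the result everywhere."""

    @abstractmethod
    def all_gather(self, tensor, axis=0):
        """Concatenate each device's shard of `tensor` along `axis`."""


class JaxDistributedBackend(BaseDistributedBackend):
    def all_reduce(self, tensor, op="sum"):
        ...  # e.g. delegate to jax.lax.psum inside a shard_map/pmap context

    def all_gather(self, tensor, axis=0):
        ...  # e.g. delegate to jax.lax.all_gather


def get_distributed_backend(name="jax"):
    """Factory returning a distributed backend for the active Keras backend."""
    backends = {"jax": JaxDistributedBackend}
    return backends[name]()
```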

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.


codecov-commenter commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 66.25000% with 27 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.63%. Comparing base (c2bc6cf) to head (f4f723d).
⚠️ Report is 8 commits behind head on master.

Files with missing lines                           Patch %   Lines
keras/src/backend/openvino/numpy.py                 5.88%    16 Missing ⚠️
keras/src/ops/numpy.py                             78.04%    5 Missing and 4 partials ⚠️
keras/api/_tf_keras/keras/ops/__init__.py           0.00%    1 Missing ⚠️
keras/api/_tf_keras/keras/ops/numpy/__init__.py     0.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21697      +/-   ##
==========================================
- Coverage   82.63%   82.63%   -0.01%     
==========================================
  Files         577      577              
  Lines       59316    59407      +91     
  Branches     9300     9313      +13     
==========================================
+ Hits        49018    49091      +73     
- Misses       7910     7911       +1     
- Partials     2388     2405      +17     
Flag              Coverage Δ
keras             82.45% <66.25%> (-0.01%) ⬇️
keras-jax         63.32% <50.00%> (+<0.01%) ⬆️
keras-numpy       57.57% <50.00%> (+0.01%) ⬆️
keras-openvino    34.30% <43.75%> (+<0.01%) ⬆️
keras-tensorflow  64.12% <56.25%> (+0.01%) ⬆️
keras-torch       63.63% <51.25%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Collaborator

@JyotinderSingh JyotinderSingh left a comment


I've added a few initial comments and questions during my first look.

To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?

Collaborator

@hertschuh hertschuh left a comment


Sorry for creating more work, but can you split this PR into 3 PRs?

  • one for array_split.
  • one for get_device_count.
  • one for everything else for now.

Thanks! These are independent features that make sense on their own.

Additionally, I had a question in an earlier review:

On JAX, support for uneven shards is limited. Intermediate values can be sharded unevenly, but the outputs of a jitted function must be evenly sharded. Can you make sure this is not a blocker for this project?
[1] jax-ml/jax#26946 (comment)
[2] https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html
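
For context on reference [2], here is a minimal, hypothetical sketch of applying a sharding constraint to an intermediate value inside a jitted function; the mesh axis name "model" and the toy computation are illustrative and not taken from the PR.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative only: a 1D mesh over all local devices with a single "model" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))


@jax.jit
def f(x):
    # Ask the compiler to shard this intermediate across the "model" axis.
    # The concern raised above is that, unlike intermediates, the *outputs*
    # of a jitted function cannot be sharded unevenly.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("model")))
    return x * 2.0


y = f(jnp.ones((4 * devices.size,)))  # length chosen to divide evenly across devices
```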

@buildwithsuhana buildwithsuhana changed the title Core Data Structures & Communication Primitives for Tensor Parallel for Keras Add keras.ops.array_split for Tensor Parallelism Support Oct 28, 2025
@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Oct 30, 2025
@hertschuh hertschuh merged commit 1519bcc into keras-team:master Oct 30, 2025
11 checks passed
@google-ml-butler google-ml-butler bot removed the awaiting review and ready to pull (Ready to be merged into the codebase) labels Oct 30, 2025