Add keras.ops.array_split for Tensor Parallelism Support #21697
Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.
Code Review
This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##           master   #21697      +/-   ##
==========================================
- Coverage   82.63%   82.63%   -0.01%     
==========================================
  Files         577      577              
  Lines       59316    59407      +91     
  Branches     9300     9313      +13     
==========================================
+ Hits        49018    49091      +73     
- Misses       7910     7911       +1     
- Partials     2388     2405      +17     
I've added a few initial comments and questions during my first look.
To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?
Sorry for creating more work, but can you split this PR into 3 PRs?
- one for split_array.
- one for get_device_count.
- one for everything else for now.
Thanks! These are independent features that make sense on their own.
Additionally, I had a question in an earlier review:
On JAX, support for uneven shards is limited. Intermediate values can be sharded unevenly, but the outputs of a jitted function must be evenly sharded. Can you make sure this is not a blocker for this project?
[1] jax-ml/jax#26946 (comment)
[2] https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html
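For reference, here is a minimal sketch (not part of this PR) of the scenario described in [1] and [2], assuming a host with at least 3 JAX devices: a size-10 intermediate is constrained to an uneven sharding with jax.lax.with_sharding_constraint, while the jitted function returns a scalar so its output does not need to be unevenly sharded.

```python
# Minimal sketch, not from this PR: illustrates the uneven-sharding case
# referenced in [1]/[2]. Assumes a host with at least 3 JAX devices.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()[:3]
mesh = Mesh(devices, axis_names=("model",))
row_sharding = NamedSharding(mesh, P("model"))

@jax.jit
def f(x):
    # An intermediate of size 10 constrained over 3 devices shards
    # unevenly (4, 3, 3); per the comment above, this is allowed for
    # intermediate values.
    y = jax.lax.with_sharding_constraint(x * 2.0, row_sharding)
    # Returning a scalar sidesteps the limitation on unevenly sharded
    # *outputs* of a jitted function.
    return jnp.sum(y)

print(f(jnp.arange(10.0)))
```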
This PR introduces keras.ops.array_split, a new operation that serves as a fundamental building block for tensor parallelism in Keras.
While keras.ops.split already exists, it requires the tensor's dimension to be evenly divisible by the number of splits. array_split (mirroring np.array_split) removes this restriction, allowing for uneven splits.
This capability is crucial for tensor parallelism, where we often need to shard large weight tensors (e.g., Dense kernels or Embedding layers) across a number of devices, even when the tensor's dimension isn't perfectly divisible by the device count. For example, splitting a tensor of size 10 into 3 parts will result in sub-tensors of size [4, 3, 3].
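As a rough illustration of the intended semantics (the exact Keras signature may differ once the op lands), np.array_split already behaves this way, whereas an even split of 10 into 3 is rejected:

```python
# Sketch of the semantics this PR targets, using NumPy as the reference;
# per the PR description, keras.ops.array_split mirrors np.array_split.
import numpy as np

x = np.arange(10)

# np.array_split tolerates uneven splits: sizes 4, 3, 3.
parts = np.array_split(x, 3)
print([p.shape[0] for p in parts])  # [4, 3, 3]

# np.split (like keras.ops.split) requires an even division and
# raises ValueError when splitting 10 into 3.
try:
    np.split(x, 3)
except ValueError as e:
    print("even split rejected:", e)

# Proposed usage (hypothetical until this PR is merged):
# from keras import ops
# shards = ops.array_split(kernel, num_devices, axis=-1)
```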