- 
                Notifications
    You must be signed in to change notification settings 
- Fork 19.6k
Add keras.ops.array_split for Tensor Parallelism Support #21697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
            hertschuh
  merged 70 commits into
  keras-team:master
from
buildwithsuhana:Tensor_parallel_keras
  
      
      
   
  Oct 30, 2025 
      
    
  
     Merged
                    Changes from 57 commits
      Commits
    
    
            Show all changes
          
          
            70 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      a27367a
              
                Added tensor parallel for keras (Part 1/3)
              
              
                buildwithsuhana 488cd8f
              
                Removed unnecessary lines
              
              
                buildwithsuhana 71ddd1a
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana bc4e4e2
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana d4200b5
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana 21f89a2
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana 299bd45
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana da625e1
              
                Fixes suggested by Gemini
              
              
                buildwithsuhana c233b8c
              
                Fixing the failing test
              
              
                buildwithsuhana 7b8d733
              
                Fixing the failing test
              
              
                buildwithsuhana f825cd3
              
                Fixing test
              
              
                buildwithsuhana 3725180
              
                Adding tests for distributed_backends
              
              
                buildwithsuhana a6c8a96
              
                Modifications for failing tests
              
              
                buildwithsuhana 3fabfde
              
                Modified for failing test
              
              
                buildwithsuhana b133752
              
                Modified for failing test
              
              
                buildwithsuhana 83c2e3f
              
                Modified for failing test
              
              
                buildwithsuhana 3f3be6b
              
                added debuggers
              
              
                buildwithsuhana be325ab
              
                removed debuggers
              
              
                buildwithsuhana e1282ac
              
                Merge branch 'keras-team:master' into Tensor_parallel_keras
              
              
                buildwithsuhana fc11aaa
              
                Removed the tensorflow, numpy and torch backends
              
              
                buildwithsuhana ef6e2a0
              
                Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
              
              
                buildwithsuhana bea6ffa
              
                Refactoring the code
              
              
                buildwithsuhana 4e00245
              
                Refactoring the code
              
              
                buildwithsuhana 2f973b0
              
                refactoring
              
              
                buildwithsuhana bdb2b84
              
                Adding necessary docstrings
              
              
                buildwithsuhana d77fa71
              
                Merge branch 'keras-team:master' into Tensor_parallel_keras
              
              
                buildwithsuhana b9990b0
              
                Removing redundancies
              
              
                buildwithsuhana 0aeee6f
              
                Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
              
              
                buildwithsuhana f784956
              
                Modifying tests
              
              
                buildwithsuhana 8895a78
              
                Reformatting
              
              
                buildwithsuhana fe97f3b
              
                Reformatting the code
              
              
                buildwithsuhana 77f01aa
              
                Fixing failing tests
              
              
                buildwithsuhana 7080328
              
                fixes
              
              
                buildwithsuhana af711fd
              
                Fixing tests
              
              
                buildwithsuhana 97dde17
              
                formatting
              
              
                buildwithsuhana f322a97
              
                fixing test
              
              
                buildwithsuhana 5269ac9
              
                fixing test
              
              
                buildwithsuhana b9f36e9
              
                Removing redundant lines
              
              
                buildwithsuhana 555e5c9
              
                Refactoring to remove communications.py and state_action_keras.py
              
              
                buildwithsuhana b80d264
              
                formatting the files
              
              
                buildwithsuhana 93b1738
              
                fixing skip issues
              
              
                buildwithsuhana b7b2b9b
              
                fixing test
              
              
                buildwithsuhana f6c1142
              
                fixing test
              
              
                buildwithsuhana 669c799
              
                refactoring to remove distributed backend wrapper
              
              
                buildwithsuhana cd20b9f
              
                fixing test
              
              
                buildwithsuhana cd0049f
              
                making distrubed backend more jax friendly
              
              
                buildwithsuhana d1e4c69
              
                Fixing comments
              
              
                buildwithsuhana 86e0557
              
                Fixing comments
              
              
                buildwithsuhana 6c3883f
              
                Fixing comments
              
              
                buildwithsuhana 3e31e1e
              
                fixes
              
              
                buildwithsuhana c99601e
              
                Refactor
              
              
                buildwithsuhana dbae56d
              
                refactoring to resolve comments
              
              
                buildwithsuhana 2fc0f0e
              
                fixes
              
              
                buildwithsuhana 174093c
              
                fixes
              
              
                buildwithsuhana 7d18b0a
              
                fix
              
              
                buildwithsuhana f570925
              
                fix
              
              
                buildwithsuhana 9e7f873
              
                removing get_best_devices
              
              
                buildwithsuhana 5136091
              
                fixing comments
              
              
                buildwithsuhana 8f40c53
              
                Merge branch 'master' into Tensor_parallel_keras
              
              
                buildwithsuhana 08b8abe
              
                fixing merge conflict
              
              
                buildwithsuhana 3a408da
              
                Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
              
              
                buildwithsuhana eb796ea
              
                modifying variable name
              
              
                buildwithsuhana 15e1709
              
                Fixes
              
              
                buildwithsuhana 911b96e
              
                fix
              
              
                buildwithsuhana bd2f19f
              
                fix
              
              
                buildwithsuhana 71d079f
              
                splitting into 3 PRs
              
              
                buildwithsuhana 7789084
              
                Modified array_split implementation in openvino, tensorflow and torch
              
              
                buildwithsuhana 162e6c3
              
                formatting the array split function
              
              
                buildwithsuhana d47e3e6
              
                adding test for uneven array split
              
              
                buildwithsuhana f4f723d
              
                fixing test
              
              
                buildwithsuhana File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| import collections | ||
|  | ||
| from keras.src import ops | ||
|  | ||
|  | ||
| class Split: | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| """Splits a tensor into shards along a specified dimension. | ||
| This is an internal utility used by a higher-level distribution API. | ||
| It implements sharding by slicing a tensor along one of its axes. | ||
| It handles cases where the dimension size is not perfectly divisible by the | ||
| number of workers by distributing the remainder elements one by one to the | ||
| first few workers. | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| """ | ||
|  | ||
| def __init__(self, device_count, dim, sharding_type="auto"): | ||
| """Initializes the Split action. | ||
| Args: | ||
| device_count: The total number of workers/shards. | ||
| dim: The dimension along which to split the tensor. If -1, the | ||
| last dimension is used. | ||
| sharding_type: If `dim` is -1, this can be 'row' (dim=0) or | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| 'column' (dim=1) to infer the split axis for 2D tensors. | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| Defaults to "auto". | ||
| """ | ||
| self.device_count = device_count | ||
| self.dim = dim | ||
| self.sharding_type = sharding_type | ||
|  | ||
| if dim == -1 and sharding_type != "auto": | ||
| if sharding_type == "row": | ||
| self.dim = 0 | ||
| elif sharding_type == "column": | ||
| self.dim = 1 | ||
|  | ||
| def __call__(self, tensor, rank): | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| """Splits the tensor and returns the shard corresponding to the rank. | ||
| This method calculates the correct slice of the tensor for a given | ||
| worker rank, handling uneven distributions gracefully. | ||
| Args: | ||
| tensor: The full tensor to be sharded. | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| rank: The rank of the worker for which to get the shard. | ||
|         
                  buildwithsuhana marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| Returns: | ||
| A tensor shard corresponding to the given rank. | ||
| """ | ||
| if self.dim == -1: | ||
| dim = ops.ndim(tensor) - 1 | ||
| else: | ||
| dim = self.dim | ||
|  | ||
| total_size = tensor.shape[dim] | ||
| split_size = total_size // self.device_count | ||
| remainder = total_size % self.device_count | ||
|  | ||
| start_idx = rank * split_size + min(rank, remainder) | ||
| end_idx = start_idx + split_size + (1 if rank < remainder else 0) | ||
|  | ||
| slices = [slice(None)] * ops.ndim(tensor) | ||
| slices[dim] = slice(start_idx, end_idx) | ||
| return tensor[tuple(slices)] | ||
|  | ||
|  | ||
| LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"]) | ||
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.