[release/2.10] [Upstream cherry-pick] Add partitioned scatter approach with optimizations #2894
From pytorch#168073
It has been observed that heavily contended atomics can cause poor performance. To solve this problem while minimizing kernel overhead, this PR proposes an FX pass that replaces the index_put operation with an alternative scatter approach.
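A rough sketch of what such a graph rewrite can look like with `torch.fx` (the matching condition and the `partitioned_scatter` callable, which would implement the partitioned approach described under Algorithm below, are assumptions for illustration; the actual pass lives in inductor and also consults the heuristics discussed later):

```python
import torch
from torch.fx import GraphModule

def replace_index_put(gm: GraphModule, partitioned_scatter) -> GraphModule:
    # Rewrite accumulating index_put calls to the partitioned scatter;
    # partitioned_scatter is assumed to accept the same
    # (self, indices, values, accumulate) arguments.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.index_put.default:
            accumulate = node.args[3] if len(node.args) > 3 else node.kwargs.get("accumulate", False)
            if accumulate:
                node.target = partitioned_scatter
    gm.graph.lint()
    gm.recompile()
    return gm
```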
Algorithm: instead of scattering directly into the output buffer with contended atomic adds, scatter into an expanded buffer holding one copy of the output per partition (spreading the atomic updates across partitions), then reduce the partitions back into the output.
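A minimal eager-mode sketch of the idea (the function name, round-robin partition assignment, and fixed partition count are assumptions for illustration, not the generated inductor code):

```python
import torch

def partitioned_index_add(output, indices, values, num_partitions=8):
    # Expanded buffer: one copy of the output per partition.
    n, D = output.shape
    expanded = values.new_zeros(num_partitions, n, D)
    # Assign each source row to a partition (round-robin here), so rows
    # that collide on the same output index hit different buffer copies.
    part = torch.arange(indices.numel(), device=indices.device) % num_partitions
    expanded.index_put_((part, indices), values, accumulate=True)
    # Reduce the partitions back into the real output.
    output += expanded.sum(dim=0)
    return output
```

This matches `output.index_add_(0, indices, values)` up to floating-point reduction order while spreading the atomic adds across `num_partitions` copies of the output.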
This reduces atomic contention at the cost of memory usage. To limit that cost, we have built heuristics around the total number of partitions for the expanded buffer, as well as a cap on how large these expanded tensors can be (currently 10% of GPU memory).
Note the heuristic cannot be perfect, as we do not know the true index data at compile time. In real-world models the indices will have duplicates and not be uniformly distributed, which increases atomic contention; currently this cannot be modelled, so we have to estimate contention from the input and output buffer sizes.
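The actual heuristic is not reproduced in this description; purely as a hypothetical sketch of its shape (all names and the contention estimate are assumptions, only the 10% memory cap comes from the PR):

```python
def choose_num_partitions(num_indices, out_rows, out_row_bytes,
                          free_gpu_bytes, mem_frac=0.10, max_partitions=32):
    # Hypothetical: more input rows per output row means more expected
    # collisions, so favour more partitions...
    estimated_contention = max(1, num_indices // max(out_rows, 1))
    # ...but never let the expanded buffer (partitions * output size)
    # exceed mem_frac of GPU memory (the 10% cap mentioned above).
    by_memory = int(mem_frac * free_gpu_bytes // max(out_rows * out_row_bytes, 1))
    return max(1, min(estimated_contention, max_partitions, by_memory))
```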
Benchmark code: https://gist.github.com/jataylo/dd3a6353ad2859efd65fa87b28aa3ebd
This code executes 3 index_add ops into 3 separate buffers, with the following shapes:
```
N = 1000000
D = 100
n = 501
values  = float32 [N, D]
indices = int64   [N]
output  = float32 [n, D]
```
For each run we modify the range of randint to simulate various levels of atomic contention.
We gathered two sets of results: one with partitioned_scatter_enabled=True, the other with partitioned_scatter_enabled=False.
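The linked gist has the real harness; a hedged sketch of the sweep (timing loop and ranges are illustrative), where shrinking the `randint` range concentrates updates on fewer output rows and so raises contention:

```python
import time
import torch

N, D, n = 1_000_000, 100, 501
values = torch.randn(N, D, device="cuda")
out = torch.zeros(n, D, device="cuda")

for index_range in (501, 50, 5, 1):  # fewer distinct indices -> more contention
    indices = torch.randint(0, index_range, (N,), device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        out.index_add_(0, indices, values)
    torch.cuda.synchronize()
    print(f"randint range {index_range}: {(time.perf_counter() - start) / 10:.4f}s per op")
```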
MI300 and H100 result tables: see pytorch#168073.
We can see this could benefit H100 in worst-case contention but would degrade perf in the best case; the atomic add cost on MI300 is heavier, so the approach is more beneficial there.
On MI300 we can see a mixed bag of e2e model improvements:
https://hud.pytorch.org/benchmark/v3/dashboard/compiler_inductor?renderGroupId=main&time.start=2025-11-05T00%3A00%3A00.000Z&time.end=2025-12-04T02%3A00%3A00.000Z&filters.repo=pytorch%2Fpytorch&filters.benchmarkName=compiler&filters.mode=training&filters.dtype=amp&filters.deviceName=rocm+%28mi300x%29&filters.device=rocm&filters.suite=all&filters.compiler=default&lcommit.commit=38c42c575d342a7ea6f4a555bf845071e03b5f35&lcommit.workflow_id=19635538449&lcommit.date=2025-11-24T14%3A00%3A00Z&lcommit.branch=refs%2Ftags%2Fciflow%2Finductor-perf-test-nightly-rocm-mi300%2F168073&rcommit.commit=fedb7f15d177a259bf25c94e888137e0a9a69a81&rcommit.workflow_id=19856622912&rcommit.date=2025-12-02T12%3A00%3A00Z&rcommit.branch=refs%2Ftags%2Fciflow%2Finductor-perf-test-nightly-rocm-mi300%2F168073&lbranch=refs%2Ftags%2Fciflow%2Finductor-perf-test-nightly-rocm-mi300%2F168073&rbranch=refs%2Ftags%2Fciflow%2Finductor-perf-test-nightly-rocm-mi300%2F168073&maxSampling=110
Due to the mixed bag of results, we will initially enable this as a non-default feature; CI passed with it enabled here:
https://hud.pytorch.org/pytorch/pytorch/pull/168073?sha=fedb7f15d177a259bf25c94e888137e0a9a69a81
Note there are follow-up improvements to make after this lands.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @dllehr-amd @chenyang78