From 8630eca23646270fa91dfd98f1865e5b702036fe Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 22 Sep 2025 18:17:04 -0400 Subject: [PATCH 1/9] Update Update Update Update Update Update Fix Update Update Update Update Lint Update Update Update Update Update Update Update Update Update Update Update Update Update Lint Update Update Update Lint Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Fix Update Fix Fix Update Revert Update Update Temporary update Update Update Update Update Update Update Update Update Fix Update Update Update Try Update Update Remove i64 Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update --- python/triton_kernels/bench/distributed.py | 221 +++++++++++++++++++-- third_party/proton/proton/proton.py | 2 +- 2 files changed, 209 insertions(+), 14 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index a7ae39f77ab3..fbb2fb6d3274 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -1,11 +1,13 @@ +from enum import Enum import os +import math import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Optional import triton import triton.language as tl @@ -31,12 +33,26 @@ from bench_utils import quantize_weight +class CommKernelType(Enum): + TORCH = "torch" + TRITON = "triton" + FUSE = "fuse" + + @dataclass class ReduceScatterMetadata: input_split_sizes: list[int] ep_indx: torch.Tensor EP: int = 1 TP: int = 1 + comm: CommKernelType = CommKernelType.TORCH + + +TRITON_DTYPE_MAP = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32, +} def _is_distributed_launch() -> bool: @@ -88,9 +104,155 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor: return x +def _reduce_ep_torch(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor], + world_size: int, dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, + intermediate_dtype: torch.dtype) -> torch.Tensor: + n_tokens = metadata.ep_indx.size(dim) + other_dims = input_tensor.shape[1:] + output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype) + for i in range(world_size): + ep_rank = i // metadata.TP + mask = torch.any(metadata.ep_indx == ep_rank, dim=1) + if op == dist.ReduceOp.SUM: + output_tensor[mask] += output_list[i].to(intermediate_dtype) + else: + raise NotImplementedError(f"Reduce operation {op} is not implemented.") + return output_tensor.to(original_dtype) + + +@triton.jit +def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_E: tl.constexpr): + ep_idx = tl.program_id(0) + token_offsets = tl.arange(0, BLOCK_SIZE_M) + expert_offsets = tl.arange(0, BLOCK_SIZE_E) + num_tiles = tl.cdiv(n_tokens, BLOCK_SIZE_M) + base = 0 + + for tile in range(num_tiles): + offs_m = tile * BLOCK_SIZE_M + token_offsets + token_mask = offs_m < n_tokens + load_mask = token_mask[:, None] & (expert_offsets[None, :] < n_expts) + ep_values = tl.load( + ep_indx_ptr + offs_m[:, None] * n_expts + expert_offsets[None, :], + mask=load_mask, + other=-1, + ) + + row_has_ep = tl.reduce_or(ep_values == ep_idx, axis=1) & token_mask + row_has_ep_i32 = row_has_ep.to(tl.int32) + prefix = 
tl.cumsum(row_has_ep_i32, axis=0) - row_has_ep_i32 + positions = tl.where(row_has_ep, base + prefix, -1) + tl.store( + positions_ptr + ep_idx * n_tokens + offs_m, + positions, + mask=token_mask + ) + increment = tl.sum(row_has_ep_i32, axis=0) + base = base + increment + + +def _accumulate_ep_metadata(grid, kernel, args): + ret = {} + n_tokens, hidden_size = args["n_tokens"], args["hidden_size"] + TP, EP = args["TP"], args["EP"] + ret["name"] = f"{kernel.name} [n_tokens={n_tokens}, hidden_size={hidden_size}, TP={TP}, EP={EP}]" + ep_positions_ptr = args["ep_positions_ptr"] + output_tensor_ptr = args["output_tensor_ptr"] + input_ptrs = args["input_ptrs"] + ret["bytes"] = (ep_positions_ptr.element_size() * ep_positions_ptr.numel() + + output_tensor_ptr.element_size() * output_tensor_ptr.numel() + + sum([p.element_size() * p.numel() for p in input_ptrs])) + return ret + + +@triton.jit(launch_metadata=_accumulate_ep_metadata) +def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs, n_tokens, hidden_size, + TP: tl.constexpr, EP: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + original_dtype: tl.constexpr, intermediate_dtype: tl.constexpr): + offs_m = tl.program_id(0) + token_mask = offs_m < n_tokens + offs_n = tl.arange(0, BLOCK_SIZE_N) + feature_mask = offs_n < hidden_size + io_mask = token_mask[:] & feature_mask + output = tl.zeros((BLOCK_SIZE_N,), dtype=intermediate_dtype) + + for ep_idx in tl.static_range(EP): + position = tl.load( + ep_positions_ptr + ep_idx * n_tokens + offs_m, + mask=token_mask, + other=-1, + ) + if position != -1: + row_offsets = position * hidden_size + col_offsets = row_offsets + offs_n + + for tp_idx in tl.static_range(TP): + values = tl.load( + input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets, + mask=io_mask, + other=0, + ).to(intermediate_dtype) + output += values + + output = output.to(original_dtype) + tl.store( + output_tensor_ptr + offs_m * hidden_size + offs_n, + output, + mask=io_mask, + ) + + +def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor], + world_size: int, dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, + intermediate_dtype: torch.dtype) -> torch.Tensor: + if op != dist.ReduceOp.SUM: + raise NotImplementedError(f"Reduce operation {op} is not implemented.") + ep_indx = metadata.ep_indx.contiguous() + n_tokens = ep_indx.size(dim) + n_expts = ep_indx.size(1 - dim) + other_dims = input_tensor.shape[1:] + hidden_size = math.prod(other_dims) + output_tensor = input_tensor.new_empty((n_tokens, ) + other_dims, dtype=original_dtype) + positions = torch.empty((metadata.EP, n_tokens), dtype=torch.int32, device=input_tensor.device) + triton_original_dtype = TRITON_DTYPE_MAP.get(original_dtype, tl.float32) + triton_intermediate_dtype = TRITON_DTYPE_MAP.get(intermediate_dtype, tl.float32) + BLOCK_SIZE_M = 2048 if n_tokens >= 2048 else triton.next_power_of_2(n_tokens if n_tokens > 0 else 1) + BLOCK_SIZE_N = triton.next_power_of_2(hidden_size if hidden_size > 0 else 1) + BLOCK_SIZE_E = triton.next_power_of_2(n_expts if n_expts > 0 else 1) + + # XXX(Keren): Do not over optimize this function for now, as the communication interface (all-to-all) is subject to change. 
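+    # Two-kernel scheme (descriptive note, not part of the upstream change): the first launch
+    # computes, for each EP rank, every token's destination row via an exclusive prefix sum over
+    # the tokens routed to that rank (-1 if the token is not routed there); the second launch then
+    # gathers those rows from all TP * EP peer buffers and accumulates them into the output tensor.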
+ _prepare_ep_positions_kernel[(metadata.EP,)]( + ep_indx, + positions, + n_tokens, + n_expts, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_E=BLOCK_SIZE_E, + num_warps=16 if n_tokens >= 2048 else 8, + ) # type: ignore + + _accumulate_ep_triton_kernel[(n_tokens,)]( + positions, + output_tensor, + tuple(output_list), + n_tokens, + hidden_size, + TP=metadata.TP, + EP=metadata.EP, + BLOCK_SIZE_N=BLOCK_SIZE_N, + original_dtype=triton_original_dtype, + intermediate_dtype=triton_intermediate_dtype, + ) + + return output_tensor + + + def reduce_scatter( input_tensor: torch.Tensor, - metadata: ReduceScatterMetadata = None, + metadata: Optional[ReduceScatterMetadata] = None, dim: int = 0, op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, ) -> torch.Tensor: @@ -109,18 +271,16 @@ def dtype_cast(dtype: torch.dtype) -> torch.dtype: if metadata and metadata.input_split_sizes: assert dim == 0, "metadata only works with dim=0" input_list = list(input_tensor.split(metadata.input_split_sizes, dim=0)) + # TODO(Keren): Implement a triton all-to-all kernel output_list = all_to_all(input_list, dim=0) - n_tokens = metadata.ep_indx.size(dim) - other_dims = input_tensor.shape[1:] - output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype) - for i in range(world_size): - ep_rank = i // metadata.TP - mask = torch.any(metadata.ep_indx == ep_rank, dim=1) - if op == dist.ReduceOp.SUM: - output_tensor[mask] += output_list[i].to(intermediate_dtype) - else: - raise NotImplementedError(f"Reduce operation {op} is not implemented.") - return output_tensor.to(original_dtype) + if metadata.comm == CommKernelType.TORCH: + return _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + intermediate_dtype) + elif metadata.comm == CommKernelType.TRITON: + return _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + intermediate_dtype) + else: + raise NotImplementedError(f"CommKernelType {metadata.comm} is not implemented.") else: input_list = list(input_tensor.chunk(world_size, dim=dim)) shape = input_list[0].shape @@ -436,6 +596,41 @@ def test_reduce_scatter_distributed_with_metadata(monkeypatch): torch.testing.assert_close(result, torch.tensor([[1, 2], [1, 2]], dtype=torch.float32)) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Triton kernel execution") +@pytest.mark.parametrize("TP, EP", [(4, 2), (1, 8), (2, 2)]) +@pytest.mark.parametrize("n_tokens", [128, 1024, 30720]) +@pytest.mark.parametrize("hidden_size", [1024, 5760]) +@pytest.mark.parametrize("n_expt_act", [4]) +def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): + device = torch.device("cuda") + world_size = TP * EP + + ep_indx = torch.randint(0, EP, (n_tokens, n_expt_act), device=device, dtype=torch.int32).contiguous() + original_dtype = torch.float16 + intermediate_dtype = torch.float16 + input_tensor = torch.zeros((n_tokens, hidden_size), device=device, dtype=original_dtype) + output_list = [] + for rank in range(world_size): + ep_rank = rank // TP + mask = torch.any(ep_indx == ep_rank, dim=1) + random_values = torch.randn(mask.sum().item(), hidden_size, device=device, dtype=intermediate_dtype) + output_list.append(random_values) + metadata = ReduceScatterMetadata( + input_split_sizes=None, # it doesn't matter in this test since we skipped all_to_all + ep_indx=ep_indx, + EP=EP, + TP=TP, + ) + + op = dist.ReduceOp.SUM + dim = 0 + ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, + dim, op, 
original_dtype, intermediate_dtype) + ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, + dim, op, original_dtype, intermediate_dtype) + torch.testing.assert_close(ret, ref) + + def test_routing_distributed_EP(monkeypatch): # Test distributed routing with EP=1 (token_mask should be None) monkeypatch.setenv("WORLD_SIZE", "2") diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index dae38fd66616..80d3a971e92c 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -19,7 +19,7 @@ def parse_arguments(): choices=["shadow", "python"]) parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None) parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree", "trace"]) - parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "launch"]) + parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') args = parser.parse_args() return args, args.target_args From 6468eaed2695c465c883e8cd6b13eb884beb842d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 09:17:12 -0400 Subject: [PATCH 2/9] Lint --- python/triton_kernels/bench/distributed.py | 29 ++++++++-------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index fbb2fb6d3274..00dcb1025557 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -121,8 +121,8 @@ def _reduce_ep_torch(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor @triton.jit -def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_E: tl.constexpr): +def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_E: tl.constexpr): ep_idx = tl.program_id(0) token_offsets = tl.arange(0, BLOCK_SIZE_M) expert_offsets = tl.arange(0, BLOCK_SIZE_E) @@ -143,11 +143,7 @@ def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, row_has_ep_i32 = row_has_ep.to(tl.int32) prefix = tl.cumsum(row_has_ep_i32, axis=0) - row_has_ep_i32 positions = tl.where(row_has_ep, base + prefix, -1) - tl.store( - positions_ptr + ep_idx * n_tokens + offs_m, - positions, - mask=token_mask - ) + tl.store(positions_ptr + ep_idx * n_tokens + offs_m, positions, mask=token_mask) increment = tl.sum(row_has_ep_i32, axis=0) base = base + increment @@ -168,15 +164,14 @@ def _accumulate_ep_metadata(grid, kernel, args): @triton.jit(launch_metadata=_accumulate_ep_metadata) def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs, n_tokens, hidden_size, - TP: tl.constexpr, EP: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, + TP: tl.constexpr, EP: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, original_dtype: tl.constexpr, intermediate_dtype: tl.constexpr): offs_m = tl.program_id(0) token_mask = offs_m < n_tokens offs_n = tl.arange(0, BLOCK_SIZE_N) feature_mask = offs_n < hidden_size io_mask = token_mask[:] & feature_mask - output = tl.zeros((BLOCK_SIZE_N,), dtype=intermediate_dtype) + output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype) for ep_idx in tl.static_range(EP): position = tl.load( @@ -223,7 +218,7 @@ def _reduce_ep_triton(metadata: 
ReduceScatterMetadata, input_tensor: torch.Tenso BLOCK_SIZE_E = triton.next_power_of_2(n_expts if n_expts > 0 else 1) # XXX(Keren): Do not over optimize this function for now, as the communication interface (all-to-all) is subject to change. - _prepare_ep_positions_kernel[(metadata.EP,)]( + _prepare_ep_positions_kernel[(metadata.EP, )]( ep_indx, positions, n_tokens, @@ -233,7 +228,7 @@ def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tenso num_warps=16 if n_tokens >= 2048 else 8, ) # type: ignore - _accumulate_ep_triton_kernel[(n_tokens,)]( + _accumulate_ep_triton_kernel[(n_tokens, )]( positions, output_tensor, tuple(output_list), @@ -249,7 +244,6 @@ def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tenso return output_tensor - def reduce_scatter( input_tensor: torch.Tensor, metadata: Optional[ReduceScatterMetadata] = None, @@ -623,11 +617,10 @@ def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): ) op = dist.ReduceOp.SUM - dim = 0 - ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, - dim, op, original_dtype, intermediate_dtype) - ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, - dim, op, original_dtype, intermediate_dtype) + dim = 0 + ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + intermediate_dtype) + ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) torch.testing.assert_close(ret, ref) From af3af647bd63706747bcf91e61e42d5778c2776d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 09:21:33 -0400 Subject: [PATCH 3/9] Update --- python/triton_kernels/bench/distributed.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 00dcb1025557..383eeeee78fe 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -105,8 +105,9 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor: def _reduce_ep_torch(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor], - world_size: int, dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, + dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, intermediate_dtype: torch.dtype) -> torch.Tensor: + world_size = metadata.TP * metadata.EP n_tokens = metadata.ep_indx.size(dim) other_dims = input_tensor.shape[1:] output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype) @@ -200,7 +201,7 @@ def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor], - world_size: int, dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, + dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype, intermediate_dtype: torch.dtype) -> torch.Tensor: if op != dist.ReduceOp.SUM: raise NotImplementedError(f"Reduce operation {op} is not implemented.") @@ -268,10 +269,10 @@ def dtype_cast(dtype: torch.dtype) -> torch.dtype: # TODO(Keren): Implement a triton all-to-all kernel output_list = all_to_all(input_list, dim=0) if metadata.comm == CommKernelType.TORCH: - return _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + return _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, 
original_dtype, intermediate_dtype) elif metadata.comm == CommKernelType.TRITON: - return _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + return _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype) else: raise NotImplementedError(f"CommKernelType {metadata.comm} is not implemented.") @@ -493,7 +494,7 @@ def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_row routing_data, gather_indx, scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP), + ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP, comm=CommKernelType.TRITON), ) @@ -620,7 +621,8 @@ def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): dim = 0 ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) - ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) + ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + intermediate_dtype) torch.testing.assert_close(ret, ref) From e27fa0fc9a8b06f107e361fdb5e9eac24d45f8e9 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 09:21:55 -0400 Subject: [PATCH 4/9] Update --- python/triton_kernels/bench/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 383eeeee78fe..634df6bec05f 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -494,7 +494,8 @@ def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_row routing_data, gather_indx, scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP, comm=CommKernelType.TRITON), + ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP, + comm=CommKernelType.TRITON), ) @@ -621,8 +622,7 @@ def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): dim = 0 ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) - ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, - intermediate_dtype) + ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) torch.testing.assert_close(ret, ref) From 6f4386ac9a1be12d2cf4d788aca99b36db097e02 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 09:25:44 -0400 Subject: [PATCH 5/9] Update --- python/triton_kernels/bench/distributed.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 634df6bec05f..30a95b808c6b 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -620,9 +620,10 @@ def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): op = dist.ReduceOp.SUM dim = 0 - ret = _reduce_ep_triton(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, + ret = _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype) - ref = _reduce_ep_torch(metadata, input_tensor, output_list, world_size, dim, op, original_dtype, intermediate_dtype) + ref = 
_reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype, + intermediate_dtype) torch.testing.assert_close(ret, ref) From 1ac60faee363ed4c96d34733c2977064dc49625e Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 09:30:33 -0400 Subject: [PATCH 6/9] Lint --- python/triton_kernels/bench/distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 30a95b808c6b..e77ea1883438 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -620,10 +620,8 @@ def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act): op = dist.ReduceOp.SUM dim = 0 - ret = _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, - intermediate_dtype) - ref = _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype, - intermediate_dtype) + ret = _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype) + ref = _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype) torch.testing.assert_close(ret, ref) From badc20744e75eb67f56cfc108da84b8fb0d6ebae Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 15:13:06 -0400 Subject: [PATCH 7/9] Update --- python/triton_kernels/bench/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index e77ea1883438..6836da7c7b83 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -188,7 +188,7 @@ def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs values = tl.load( input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets, mask=io_mask, - other=0, + other=0.0, ).to(intermediate_dtype) output += values From aa560de8e7a3b489ed49c30ad75ae711f16a4f53 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 21:04:00 -0400 Subject: [PATCH 8/9] Update --- python/triton_kernels/bench/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 6836da7c7b83..a823f697fa9e 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -171,7 +171,7 @@ def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs token_mask = offs_m < n_tokens offs_n = tl.arange(0, BLOCK_SIZE_N) feature_mask = offs_n < hidden_size - io_mask = token_mask[:] & feature_mask + io_mask = token_mask & feature_mask output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype) for ep_idx in tl.static_range(EP): From 1f1b34a6c2642929b4a2315053d283022c384d8a Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 26 Sep 2025 21:06:05 -0400 Subject: [PATCH 9/9] Fix --- python/triton_kernels/bench/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index a823f697fa9e..a62eabb27c5c 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -214,9 +214,9 @@ def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tenso positions = torch.empty((metadata.EP, n_tokens), dtype=torch.int32, device=input_tensor.device) triton_original_dtype = 
TRITON_DTYPE_MAP.get(original_dtype, tl.float32) triton_intermediate_dtype = TRITON_DTYPE_MAP.get(intermediate_dtype, tl.float32) - BLOCK_SIZE_M = 2048 if n_tokens >= 2048 else triton.next_power_of_2(n_tokens if n_tokens > 0 else 1) - BLOCK_SIZE_N = triton.next_power_of_2(hidden_size if hidden_size > 0 else 1) - BLOCK_SIZE_E = triton.next_power_of_2(n_expts if n_expts > 0 else 1) + BLOCK_SIZE_M = 2048 if n_tokens >= 2048 else triton.next_power_of_2(n_tokens) + BLOCK_SIZE_N = triton.next_power_of_2(hidden_size) + BLOCK_SIZE_E = triton.next_power_of_2(n_expts) # XXX(Keren): Do not over optimize this function for now, as the communication interface (all-to-all) is subject to change. _prepare_ep_positions_kernel[(metadata.EP, )](