217 changes: 203 additions & 14 deletions python/triton_kernels/bench/distributed.py
@@ -1,11 +1,13 @@
from enum import Enum
import os
import math
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
from typing import Tuple, Optional

import triton
import triton.language as tl
@@ -31,12 +33,26 @@
from bench_utils import quantize_weight


class CommKernelType(Enum):
TORCH = "torch"
TRITON = "triton"
FUSE = "fuse"


@dataclass
class ReduceScatterMetadata:
input_split_sizes: list[int]
ep_indx: torch.Tensor
EP: int = 1
TP: int = 1
comm: CommKernelType = CommKernelType.TORCH


TRITON_DTYPE_MAP = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32,
}


def _is_distributed_launch() -> bool:
@@ -88,9 +104,150 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor:
return x


def _reduce_ep_torch(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor],
dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype,
intermediate_dtype: torch.dtype) -> torch.Tensor:
world_size = metadata.TP * metadata.EP
n_tokens = metadata.ep_indx.size(dim)
other_dims = input_tensor.shape[1:]
output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype)
for i in range(world_size):
ep_rank = i // metadata.TP
mask = torch.any(metadata.ep_indx == ep_rank, dim=1)
if op == dist.ReduceOp.SUM:
output_tensor[mask] += output_list[i].to(intermediate_dtype)
else:
raise NotImplementedError(f"Reduce operation {op} is not implemented.")
return output_tensor.to(original_dtype)


@triton.jit
def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_E: tl.constexpr):
ep_idx = tl.program_id(0)
token_offsets = tl.arange(0, BLOCK_SIZE_M)
expert_offsets = tl.arange(0, BLOCK_SIZE_E)
num_tiles = tl.cdiv(n_tokens, BLOCK_SIZE_M)
base = 0

for tile in range(num_tiles):
offs_m = tile * BLOCK_SIZE_M + token_offsets
token_mask = offs_m < n_tokens
load_mask = token_mask[:, None] & (expert_offsets[None, :] < n_expts)
ep_values = tl.load(
ep_indx_ptr + offs_m[:, None] * n_expts + expert_offsets[None, :],
mask=load_mask,
other=-1,
)

row_has_ep = tl.reduce_or(ep_values == ep_idx, axis=1) & token_mask
row_has_ep_i32 = row_has_ep.to(tl.int32)
prefix = tl.cumsum(row_has_ep_i32, axis=0) - row_has_ep_i32
positions = tl.where(row_has_ep, base + prefix, -1)
tl.store(positions_ptr + ep_idx * n_tokens + offs_m, positions, mask=token_mask)
increment = tl.sum(row_has_ep_i32, axis=0)
base = base + increment


def _accumulate_ep_metadata(grid, kernel, args):
ret = {}
n_tokens, hidden_size = args["n_tokens"], args["hidden_size"]
TP, EP = args["TP"], args["EP"]
ret["name"] = f"{kernel.name} [n_tokens={n_tokens}, hidden_size={hidden_size}, TP={TP}, EP={EP}]"
ep_positions_ptr = args["ep_positions_ptr"]
output_tensor_ptr = args["output_tensor_ptr"]
input_ptrs = args["input_ptrs"]
ret["bytes"] = (ep_positions_ptr.element_size() * ep_positions_ptr.numel() +
output_tensor_ptr.element_size() * output_tensor_ptr.numel() +
sum([p.element_size() * p.numel() for p in input_ptrs]))
return ret


@triton.jit(launch_metadata=_accumulate_ep_metadata)
def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs, n_tokens, hidden_size,
TP: tl.constexpr, EP: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
original_dtype: tl.constexpr, intermediate_dtype: tl.constexpr):
offs_m = tl.program_id(0)
token_mask = offs_m < n_tokens
offs_n = tl.arange(0, BLOCK_SIZE_N)
feature_mask = offs_n < hidden_size
io_mask = token_mask[:] & feature_mask
output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype)
Comment on lines 170 to 175

[P0] Scalar mask indexed as vector in Triton kernel

Inside _accumulate_ep_triton_kernel the mask is computed as io_mask = token_mask[:] & feature_mask. token_mask is a scalar (offs_m < n_tokens), so subscripting it with [:] is invalid in Triton and causes compilation to fail before the kernel can launch. This prevents the custom reduce-scatter kernel from building. A scalar mask can be broadcast directly with token_mask & feature_mask, or a vector mask of the appropriate shape can be constructed instead.

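A minimal sketch of the broadcast-based fix suggested above (not the committed change; it assumes token_mask stays a scalar predicate):

# token_mask is a scalar predicate, so Triton broadcasts it against the vector feature mask
io_mask = token_mask & feature_mask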


for ep_idx in tl.static_range(EP):
position = tl.load(
ep_positions_ptr + ep_idx * n_tokens + offs_m,
mask=token_mask,
other=-1,
)
if position != -1:
row_offsets = position * hidden_size
col_offsets = row_offsets + offs_n

for tp_idx in tl.static_range(TP):
values = tl.load(
input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets,
mask=io_mask,
other=0.0,
).to(intermediate_dtype)
output += values

output = output.to(original_dtype)
tl.store(
output_tensor_ptr + offs_m * hidden_size + offs_n,
output,
mask=io_mask,
)


def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor],
dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype,
intermediate_dtype: torch.dtype) -> torch.Tensor:
if op != dist.ReduceOp.SUM:
raise NotImplementedError(f"Reduce operation {op} is not implemented.")
ep_indx = metadata.ep_indx.contiguous()
n_tokens = ep_indx.size(dim)
n_expts = ep_indx.size(1 - dim)
other_dims = input_tensor.shape[1:]
hidden_size = math.prod(other_dims)
output_tensor = input_tensor.new_empty((n_tokens, ) + other_dims, dtype=original_dtype)
positions = torch.empty((metadata.EP, n_tokens), dtype=torch.int32, device=input_tensor.device)
triton_original_dtype = TRITON_DTYPE_MAP.get(original_dtype, tl.float32)
triton_intermediate_dtype = TRITON_DTYPE_MAP.get(intermediate_dtype, tl.float32)
BLOCK_SIZE_M = 2048 if n_tokens >= 2048 else triton.next_power_of_2(n_tokens if n_tokens > 0 else 1)
BLOCK_SIZE_N = triton.next_power_of_2(hidden_size if hidden_size > 0 else 1)
BLOCK_SIZE_E = triton.next_power_of_2(n_expts if n_expts > 0 else 1)

# XXX(Keren): Do not over optimize this function for now, as the communication interface (all-to-all) is subject to change.
_prepare_ep_positions_kernel[(metadata.EP, )](
ep_indx,
positions,
n_tokens,
n_expts,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_E=BLOCK_SIZE_E,
num_warps=16 if n_tokens >= 2048 else 8,
) # type: ignore

_accumulate_ep_triton_kernel[(n_tokens, )](
positions,
output_tensor,
tuple(output_list),
Comment on lines +232 to +235

[P0] Triton kernel called with unsupported tuple argument

The new _reduce_ep_triton path launches _accumulate_ep_triton_kernel and passes tuple(output_list) as a single kernel argument. Triton kernel parameters must be individual tensors or constexpr literals; a Python tuple is not a valid runtime value and the launch will raise a TypeError before any computation runs. As written, the TRITON communication path introduced in this commit cannot execute at all. Consider passing the tensors as separate arguments or materializing a device-side pointer array and indexing that inside the kernel.

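A minimal sketch of the pointer-array workaround suggested above (not the committed change; the helper tensor name and the pointer cast are assumptions):

# Host side (sketch): collect the base addresses of all shards into one int64 device tensor
ptr_array = torch.tensor([t.data_ptr() for t in output_list], dtype=torch.int64, device=input_tensor.device)
# Device side (sketch): load the relevant base address and reinterpret it as a data pointer
# base_ptr = tl.load(ptr_array_ptr + ep_idx * TP + tp_idx).to(tl.pointer_type(original_dtype))
# values = tl.load(base_ptr + col_offsets, mask=io_mask, other=0.0).to(intermediate_dtype)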

n_tokens,
hidden_size,
TP=metadata.TP,
EP=metadata.EP,
BLOCK_SIZE_N=BLOCK_SIZE_N,
original_dtype=triton_original_dtype,
intermediate_dtype=triton_intermediate_dtype,
)

return output_tensor


def reduce_scatter(
input_tensor: torch.Tensor,
metadata: ReduceScatterMetadata = None,
metadata: Optional[ReduceScatterMetadata] = None,
dim: int = 0,
op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
) -> torch.Tensor:
@@ -109,18 +266,16 @@ def dtype_cast(dtype: torch.dtype) -> torch.dtype:
if metadata and metadata.input_split_sizes:
assert dim == 0, "metadata only works with dim=0"
input_list = list(input_tensor.split(metadata.input_split_sizes, dim=0))
# TODO(Keren): Implement a triton all-to-all kernel
output_list = all_to_all(input_list, dim=0)
n_tokens = metadata.ep_indx.size(dim)
other_dims = input_tensor.shape[1:]
output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype)
for i in range(world_size):
ep_rank = i // metadata.TP
mask = torch.any(metadata.ep_indx == ep_rank, dim=1)
if op == dist.ReduceOp.SUM:
output_tensor[mask] += output_list[i].to(intermediate_dtype)
else:
raise NotImplementedError(f"Reduce operation {op} is not implemented.")
return output_tensor.to(original_dtype)
if metadata.comm == CommKernelType.TORCH:
return _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype,
intermediate_dtype)
elif metadata.comm == CommKernelType.TRITON:
return _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype,
intermediate_dtype)
else:
raise NotImplementedError(f"CommKernelType {metadata.comm} is not implemented.")
else:
input_list = list(input_tensor.chunk(world_size, dim=dim))
shape = input_list[0].shape
@@ -339,7 +494,8 @@ def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_row
routing_data,
gather_indx,
scatter_indx,
ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP),
ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP,
comm=CommKernelType.TRITON),
)


@@ -436,6 +592,39 @@ def test_reduce_scatter_distributed_with_metadata(monkeypatch):
torch.testing.assert_close(result, torch.tensor([[1, 2], [1, 2]], dtype=torch.float32))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Triton kernel execution")
@pytest.mark.parametrize("TP, EP", [(4, 2), (1, 8), (2, 2)])
@pytest.mark.parametrize("n_tokens", [128, 1024, 30720])
@pytest.mark.parametrize("hidden_size", [1024, 5760])
@pytest.mark.parametrize("n_expt_act", [4])
def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act):
device = torch.device("cuda")
world_size = TP * EP

ep_indx = torch.randint(0, EP, (n_tokens, n_expt_act), device=device, dtype=torch.int32).contiguous()
original_dtype = torch.float16
intermediate_dtype = torch.float16
input_tensor = torch.zeros((n_tokens, hidden_size), device=device, dtype=original_dtype)
output_list = []
for rank in range(world_size):
ep_rank = rank // TP
mask = torch.any(ep_indx == ep_rank, dim=1)
random_values = torch.randn(mask.sum().item(), hidden_size, device=device, dtype=intermediate_dtype)
output_list.append(random_values)
metadata = ReduceScatterMetadata(
input_split_sizes=None,  # irrelevant in this test since all_to_all is skipped
ep_indx=ep_indx,
EP=EP,
TP=TP,
)

op = dist.ReduceOp.SUM
dim = 0
ret = _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype)
ref = _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype)
torch.testing.assert_close(ret, ref)


def test_routing_distributed_EP(monkeypatch):
# Test distributed routing with EP=1 (token_mask should be None)
monkeypatch.setenv("WORLD_SIZE", "2")
2 changes: 1 addition & 1 deletion third_party/proton/proton/proton.py
@@ -19,7 +19,7 @@ def parse_arguments():
choices=["shadow", "python"])
parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None)
parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree", "trace"])
parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "launch"])
parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"])
parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments')
args = parser.parse_args()
return args, args.target_args