[MULTI-GPU] Optimize reduce_scatter (except all-to-all) using custom triton kernels #8300
@@ -1,11 +1,13 @@
from enum import Enum
import os
import math
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
from typing import Tuple, Optional

import triton
import triton.language as tl

@@ -31,12 +33,26 @@
from bench_utils import quantize_weight


class CommKernelType(Enum):
    TORCH = "torch"
    TRITON = "triton"
    FUSE = "fuse"


@dataclass
class ReduceScatterMetadata:
    input_split_sizes: list[int]
    ep_indx: torch.Tensor
    EP: int = 1
    TP: int = 1
    comm: CommKernelType = CommKernelType.TORCH


TRITON_DTYPE_MAP = {
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float32: tl.float32,
}


def _is_distributed_launch() -> bool:

@@ -88,9 +104,150 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor:
    return x


def _reduce_ep_torch(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor],
                     dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype,
                     intermediate_dtype: torch.dtype) -> torch.Tensor:
    world_size = metadata.TP * metadata.EP
    n_tokens = metadata.ep_indx.size(dim)
    other_dims = input_tensor.shape[1:]
    output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype)
    for i in range(world_size):
        ep_rank = i // metadata.TP
        mask = torch.any(metadata.ep_indx == ep_rank, dim=1)
        if op == dist.ReduceOp.SUM:
            output_tensor[mask] += output_list[i].to(intermediate_dtype)
        else:
            raise NotImplementedError(f"Reduce operation {op} is not implemented.")
    return output_tensor.to(original_dtype)


@triton.jit
def _prepare_ep_positions_kernel(ep_indx_ptr, positions_ptr, n_tokens, n_expts, BLOCK_SIZE_M: tl.constexpr,
                                 BLOCK_SIZE_E: tl.constexpr):
    ep_idx = tl.program_id(0)
    token_offsets = tl.arange(0, BLOCK_SIZE_M)
    expert_offsets = tl.arange(0, BLOCK_SIZE_E)
    num_tiles = tl.cdiv(n_tokens, BLOCK_SIZE_M)
    base = 0

    for tile in range(num_tiles):
        offs_m = tile * BLOCK_SIZE_M + token_offsets
        token_mask = offs_m < n_tokens
        load_mask = token_mask[:, None] & (expert_offsets[None, :] < n_expts)
        ep_values = tl.load(
            ep_indx_ptr + offs_m[:, None] * n_expts + expert_offsets[None, :],
            mask=load_mask,
            other=-1,
        )

        row_has_ep = tl.reduce_or(ep_values == ep_idx, axis=1) & token_mask
        row_has_ep_i32 = row_has_ep.to(tl.int32)
        prefix = tl.cumsum(row_has_ep_i32, axis=0) - row_has_ep_i32
        positions = tl.where(row_has_ep, base + prefix, -1)
        tl.store(positions_ptr + ep_idx * n_tokens + offs_m, positions, mask=token_mask)
        increment = tl.sum(row_has_ep_i32, axis=0)
        base = base + increment

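For review purposes, here is a minimal PyTorch sketch of the mapping `_prepare_ep_positions_kernel` is intended to produce: `positions[e, t]` holds the compacted row index of token `t` among the tokens routed to EP rank `e` (in token order), or -1 if the token is not routed there. The helper name below is hypothetical and not part of this diff.

```python
import torch

def ep_positions_reference(ep_indx: torch.Tensor, EP: int) -> torch.Tensor:
    # Hypothetical reference (not in the PR): per-EP-rank compacted token positions.
    n_tokens = ep_indx.size(0)
    positions = torch.full((EP, n_tokens), -1, dtype=torch.int32, device=ep_indx.device)
    for ep_rank in range(EP):
        mask = (ep_indx == ep_rank).any(dim=1)   # tokens routed to this EP rank
        count = int(mask.sum().item())
        positions[ep_rank, mask] = torch.arange(count, dtype=torch.int32, device=ep_indx.device)
    return positions
```
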
def _accumulate_ep_metadata(grid, kernel, args):
    ret = {}
    n_tokens, hidden_size = args["n_tokens"], args["hidden_size"]
    TP, EP = args["TP"], args["EP"]
    ret["name"] = f"{kernel.name} [n_tokens={n_tokens}, hidden_size={hidden_size}, TP={TP}, EP={EP}]"
    ep_positions_ptr = args["ep_positions_ptr"]
    output_tensor_ptr = args["output_tensor_ptr"]
    input_ptrs = args["input_ptrs"]
    ret["bytes"] = (ep_positions_ptr.element_size() * ep_positions_ptr.numel() +
                    output_tensor_ptr.element_size() * output_tensor_ptr.numel() +
                    sum([p.element_size() * p.numel() for p in input_ptrs]))
    return ret


@triton.jit(launch_metadata=_accumulate_ep_metadata)
def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs, n_tokens, hidden_size,
                                 TP: tl.constexpr, EP: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                                 original_dtype: tl.constexpr, intermediate_dtype: tl.constexpr):
    offs_m = tl.program_id(0)
    token_mask = offs_m < n_tokens
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    feature_mask = offs_n < hidden_size
    io_mask = token_mask[:] & feature_mask
    output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype)

    for ep_idx in tl.static_range(EP):
        position = tl.load(
            ep_positions_ptr + ep_idx * n_tokens + offs_m,
            mask=token_mask,
            other=-1,
        )
        if position != -1:
            row_offsets = position * hidden_size
            col_offsets = row_offsets + offs_n

            for tp_idx in tl.static_range(TP):
                values = tl.load(
                    input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets,
                    mask=io_mask,
                    other=0.0,
                ).to(intermediate_dtype)
                output += values

    output = output.to(original_dtype)
    tl.store(
        output_tensor_ptr + offs_m * hidden_size + offs_n,
        output,
        mask=io_mask,
    )


def _reduce_ep_triton(metadata: ReduceScatterMetadata, input_tensor: torch.Tensor, output_list: list[torch.Tensor],
                      dim: int, op: dist.ReduceOp.RedOpType, original_dtype: torch.dtype,
                      intermediate_dtype: torch.dtype) -> torch.Tensor:
    if op != dist.ReduceOp.SUM:
        raise NotImplementedError(f"Reduce operation {op} is not implemented.")
    ep_indx = metadata.ep_indx.contiguous()
    n_tokens = ep_indx.size(dim)
    n_expts = ep_indx.size(1 - dim)
    other_dims = input_tensor.shape[1:]
    hidden_size = math.prod(other_dims)
    output_tensor = input_tensor.new_empty((n_tokens, ) + other_dims, dtype=original_dtype)
    positions = torch.empty((metadata.EP, n_tokens), dtype=torch.int32, device=input_tensor.device)
    triton_original_dtype = TRITON_DTYPE_MAP.get(original_dtype, tl.float32)
    triton_intermediate_dtype = TRITON_DTYPE_MAP.get(intermediate_dtype, tl.float32)
    BLOCK_SIZE_M = 2048 if n_tokens >= 2048 else triton.next_power_of_2(n_tokens if n_tokens > 0 else 1)
    BLOCK_SIZE_N = triton.next_power_of_2(hidden_size if hidden_size > 0 else 1)
    BLOCK_SIZE_E = triton.next_power_of_2(n_expts if n_expts > 0 else 1)

    # XXX(Keren): Do not over optimize this function for now, as the communication interface (all-to-all) is subject to change.
    _prepare_ep_positions_kernel[(metadata.EP, )](
        ep_indx,
        positions,
        n_tokens,
        n_expts,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_E=BLOCK_SIZE_E,
        num_warps=16 if n_tokens >= 2048 else 8,
    )  # type: ignore

    _accumulate_ep_triton_kernel[(n_tokens, )](
        positions,
        output_tensor,
        tuple(output_list),
Comment on lines +232 to +235:
[P0] Triton kernel called with unsupported tuple argument
The new …
        n_tokens,
        hidden_size,
        TP=metadata.TP,
        EP=metadata.EP,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        original_dtype=triton_original_dtype,
        intermediate_dtype=triton_intermediate_dtype,
    )

    return output_tensor

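Regarding the tuple-argument issue flagged in the inline comment above: one possible workaround is a pointer table, i.e. pass a device tensor of raw data pointers instead of a Python tuple and cast each entry back to a typed pointer inside the kernel. A minimal sketch under the assumption that all per-rank inputs share one known dtype (fp16 here); `make_input_ptr_table` is a hypothetical helper, not part of this PR.

```python
import torch

def make_input_ptr_table(output_list: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical host-side packing (not in the PR): one int64 address per rank's tensor.
    return torch.tensor([t.data_ptr() for t in output_list],
                        dtype=torch.int64, device=output_list[0].device)

# Kernel-side sketch: replace the `input_ptrs[...]` tuple indexing with a load from the
# pointer table, then cast the integer address back to a typed pointer (dtype assumed fp16):
#
#     base = tl.load(input_ptr_table + ep_idx * TP + tp_idx).to(tl.pointer_type(tl.float16))
#     values = tl.load(base + col_offsets, mask=io_mask, other=0.0).to(intermediate_dtype)
```
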
def reduce_scatter(
    input_tensor: torch.Tensor,
    metadata: ReduceScatterMetadata = None,
    metadata: Optional[ReduceScatterMetadata] = None,
    dim: int = 0,
    op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
) -> torch.Tensor:

@@ -109,18 +266,16 @@ def dtype_cast(dtype: torch.dtype) -> torch.dtype:
    if metadata and metadata.input_split_sizes:
        assert dim == 0, "metadata only works with dim=0"
        input_list = list(input_tensor.split(metadata.input_split_sizes, dim=0))
        # TODO(Keren): Implement a triton all-to-all kernel
        output_list = all_to_all(input_list, dim=0)
        n_tokens = metadata.ep_indx.size(dim)
        other_dims = input_tensor.shape[1:]
        output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype)
        for i in range(world_size):
            ep_rank = i // metadata.TP
            mask = torch.any(metadata.ep_indx == ep_rank, dim=1)
            if op == dist.ReduceOp.SUM:
                output_tensor[mask] += output_list[i].to(intermediate_dtype)
            else:
                raise NotImplementedError(f"Reduce operation {op} is not implemented.")
        return output_tensor.to(original_dtype)
        if metadata.comm == CommKernelType.TORCH:
            return _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype,
                                    intermediate_dtype)
        elif metadata.comm == CommKernelType.TRITON:
            return _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype,
                                     intermediate_dtype)
        else:
            raise NotImplementedError(f"CommKernelType {metadata.comm} is not implemented.")
    else:
        input_list = list(input_tensor.chunk(world_size, dim=dim))
        shape = input_list[0].shape

@@ -339,7 +494,8 @@ def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_row
        routing_data,
        gather_indx,
        scatter_indx,
        ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP),
        ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP,
                              comm=CommKernelType.TRITON),
    )


@@ -436,6 +592,39 @@ def test_reduce_scatter_distributed_with_metadata(monkeypatch):
    torch.testing.assert_close(result, torch.tensor([[1, 2], [1, 2]], dtype=torch.float32))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Triton kernel execution")
@pytest.mark.parametrize("TP, EP", [(4, 2), (1, 8), (2, 2)])
@pytest.mark.parametrize("n_tokens", [128, 1024, 30720])
@pytest.mark.parametrize("hidden_size", [1024, 5760])
@pytest.mark.parametrize("n_expt_act", [4])
def test_reduce_ep(TP, EP, n_tokens, hidden_size, n_expt_act):
    device = torch.device("cuda")
    world_size = TP * EP

    ep_indx = torch.randint(0, EP, (n_tokens, n_expt_act), device=device, dtype=torch.int32).contiguous()
    original_dtype = torch.float16
    intermediate_dtype = torch.float16
    input_tensor = torch.zeros((n_tokens, hidden_size), device=device, dtype=original_dtype)
    output_list = []
    for rank in range(world_size):
        ep_rank = rank // TP
        mask = torch.any(ep_indx == ep_rank, dim=1)
        random_values = torch.randn(mask.sum().item(), hidden_size, device=device, dtype=intermediate_dtype)
        output_list.append(random_values)
    metadata = ReduceScatterMetadata(
        input_split_sizes=None,  # it doesn't matter in this test since we skipped all_to_all
        ep_indx=ep_indx,
        EP=EP,
        TP=TP,
    )

    op = dist.ReduceOp.SUM
    dim = 0
    ret = _reduce_ep_triton(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype)
    ref = _reduce_ep_torch(metadata, input_tensor, output_list, dim, op, original_dtype, intermediate_dtype)
    torch.testing.assert_close(ret, ref)


def test_routing_distributed_EP(monkeypatch):
    # Test distributed routing with EP=1 (token_mask should be None)
    monkeypatch.setenv("WORLD_SIZE", "2")

[P0] Scalar mask indexed as vector in Triton kernel

Inside `_accumulate_ep_triton_kernel`, the mask is computed as `io_mask = token_mask[:] & feature_mask`. `token_mask` is a scalar (`offs_m < n_tokens`), so subscripting it with `[:]` is invalid in Triton and causes compilation to fail before the kernel can launch. This prevents the custom reduce-scatter kernel from building. A scalar mask can be broadcast directly with `token_mask & feature_mask`, or a vector mask of the appropriate shape can be constructed.
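
A minimal sketch of the fix suggested above, assuming the rest of `_accumulate_ep_triton_kernel` stays as in this diff: drop the `[:]` subscript and let the scalar token predicate broadcast against the vector feature mask.

```python
@triton.jit(launch_metadata=_accumulate_ep_metadata)
def _accumulate_ep_triton_kernel(ep_positions_ptr, output_tensor_ptr, input_ptrs, n_tokens, hidden_size,
                                 TP: tl.constexpr, EP: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                                 original_dtype: tl.constexpr, intermediate_dtype: tl.constexpr):
    offs_m = tl.program_id(0)
    token_mask = offs_m < n_tokens       # scalar predicate: is this program's token in range?
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    feature_mask = offs_n < hidden_size  # vector mask over the hidden dimension
    io_mask = token_mask & feature_mask  # scalar & vector broadcasts to a vector mask
    ...                                  # remainder of the kernel unchanged
```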