104 changes: 59 additions & 45 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -67,7 +67,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
if (routing_bias.has_value()) {
TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
routing_bias.value()->dtype == dl_float32)
<< "routing_bias must be bfloat16 or float.";
TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
<< "routing_bias has incorrect shape.";
@@ -109,6 +111,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
args.routing_logits = routing_logits->data;
auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
auto btg_routing_bias_dtype = btg::Dtype::Fp32;
if (routing_bias_dtype == dl_bfloat16) {
btg_routing_bias_dtype = btg::Dtype::Bfloat16;
}
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
args.hidden_states = hidden_states->data;
args.gemm1_weights = gemm1_weights->data;
@@ -140,7 +146,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
Tensor permuted_idx_to_token_idx =
alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
Tensor expert_weights =
alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
Tensor expert_indexes =
alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
Tensor expert_count_histogram = alloc_tensor(
@@ -184,8 +190,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, use_routing_scales_on_input,
false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
static_cast<RoutingMethodType>(routing_method_type), stream);

// MoE kernel except routing
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -366,7 +373,9 @@ void trtllm_fp8_block_scale_moe_launcher(

auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
auto btg_routing_bias_dtype =
routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;

args.routing_logits = static_cast<float*>(routing_logits->data);
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
args.hidden_states = hidden_states->data;
@@ -398,8 +407,10 @@ void trtllm_fp8_block_scale_moe_launcher(
alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
Tensor permuted_idx_to_token_idx =
alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);

Tensor expert_weights =
alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
  // NOTE: the output type of the routing kernel is currently always bfloat16
Tensor expert_indexes =
alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
@@ -432,20 +443,21 @@ void trtllm_fp8_block_scale_moe_launcher(

tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t stream = get_stream(routing_logits->device);
routing_runner.run(static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
args.num_experts, args.top_k, args.n_group, args.topk_group,
args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
static_cast<int*>(expert_indexes->data),
static_cast<int*>(expert_count_histogram->data),
static_cast<int*>(total_num_padded_tokens->data),
static_cast<int*>(expanded_idx_to_permuted_idx->data),
nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, false, true,
static_cast<RoutingMethodType>(routing_method_type), stream);
routing_runner.run(
static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
args.local_num_experts, args.routed_scaling_factor, static_cast<int*>(expert_indexes->data),
static_cast<int*>(expert_count_histogram->data),
static_cast<int*>(total_num_padded_tokens->data),
static_cast<int*>(expanded_idx_to_permuted_idx->data),
nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
false /* use_routing_scales_on_input */, true /* use_deep_seek_fp8 */,
static_cast<RoutingMethodType>(routing_method_type), stream);

// MoE kernel except routing
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -671,7 +683,10 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
<< "routing_logits has incorrect shape.";
}
if (routing_bias.has_value()) {
TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
routing_bias.value()->dtype == dl_float32)
<< "routing_bias must be bfloat16 or float.";

TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
<< "routing_bias has incorrect shape.";
@@ -714,15 +729,14 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;

// setup args
// note: the assumption is that output data type is always Bfloat16 (the default)
auto routing_bias_dtype = dl_bfloat16;
if (routing_bias.has_value()) {
routing_bias_dtype = routing_bias.value()->dtype;
} else if (routing_logits.has_value()) {
routing_bias_dtype = routing_logits.value()->dtype;
}
args.mDtypeElt = dtype_act;
args.mDtypeExpW = routing_bias_dtype == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
// note: the assumption is that output data type is always Bfloat16 (the default)
auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
auto btg_routing_bias_dtype =
routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
  // We shouldn't use args.mDtypeExpW here: it indicates the output data type of the routing
  // kernel, which is currently always bfloat16, while the routing bias can now be fp32.
args.routing_logits = routing_logits.has_value() ? routing_logits.value()->data : nullptr;
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
args.hidden_states = hidden_states->data;
@@ -771,7 +785,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
Tensor permuted_idx_to_token_idx =
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states->device);
// Tensor expert_weights = alloc_tensor(
// {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states->device);
// {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states->device);
// Tensor expert_indexes = alloc_tensor(
// {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
int constexpr MAX_NUM_EXPERTS = 384;
@@ -815,21 +829,21 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(

tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t stream = get_stream(hidden_states->device);
routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts,
args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
args.local_num_experts, args.routed_scaling_factor,
static_cast<int*>(expert_indices->data),
static_cast<int*>(expert_count_histogram->data),
static_cast<int*>(total_num_padded_tokens->data),
static_cast<int*>(expanded_idx_to_permuted_idx->data),
nullptr, /*static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt,
false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
static_cast<RoutingMethodType>(routing_method_type), stream);
routing_runner.run(
args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
args.routed_scaling_factor, static_cast<int*>(expert_indices->data),
static_cast<int*>(expert_count_histogram->data),
static_cast<int*>(total_num_padded_tokens->data),
static_cast<int*>(expanded_idx_to_permuted_idx->data),
nullptr, /*static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
static_cast<RoutingMethodType>(routing_method_type), stream);

//
// FC13 (gemm1) + FC2 (gemm2)
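All three launchers in this file repeat the same mapping from the DLPack bias dtype to btg::Dtype. Below is a minimal sketch of a shared helper; the name to_btg_routing_bias_dtype and the stand-in enums are hypothetical, not part of this PR. The default of bfloat16 when no bias is present follows the launcher code above.

#include <stdexcept>

enum class BtgDtype { Bfloat16, Fp32 };    // stand-in for btg::Dtype
enum class DlDtype { Bfloat16, Float32 };  // stand-in for dl_bfloat16 / dl_float32

// No bias defaults to bfloat16; otherwise bf16 maps to bf16 and fp32 to fp32,
// matching the "routing_bias must be bfloat16 or float" checks above.
BtgDtype to_btg_routing_bias_dtype(bool has_bias, DlDtype bias_dtype) {
  if (!has_bias || bias_dtype == DlDtype::Bfloat16) return BtgDtype::Bfloat16;
  if (bias_dtype == DlDtype::Float32) return BtgDtype::Fp32;
  throw std::invalid_argument("routing_bias must be bfloat16 or float");
}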
52 changes: 26 additions & 26 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
@@ -79,8 +79,8 @@ __global__ void routingMainKernel(KernelParams params) {
expertSelected = laneIdx < params.mNumExpertsPerGroup;
}
auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;

auto biasVal =
expertSelected ? static_cast<float>(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat;
// initialize the mPtrExpertCounts
if (params.mPtrExpertCounts) {
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
@@ -496,24 +496,24 @@ void runImpl(Data& data, void* stream) {

// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);

if (data.mPtrPermutedIdxSize != nullptr) {
if (useSingleCluster) {
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesClusterKernel,
NumBlocksPerCluster, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/false, routingIndicesClusterKernel,
NumBlocksPerCluster, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);
} else if (data.mNumTokens <= maxTokensCoop) {
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);
} else {
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;

@@ -528,16 +528,16 @@ void runImpl(Data& data, void* stream) {
int const numBlocksOffsets =
std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesHistogramKernel,
numBlocksHistogram, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesOffsetsKernel,
numBlocksOffsets, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/false, routingIndicesHistogramKernel,
numBlocksHistogram, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets,
NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);
}
}
}
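LAUNCH_ROUTING_DEEPSEEK replaces LAUNCH_ROUTING_WITH_EXTRA_FLAG and drops the forceFloatInput flag: the kernels now read the bias in its storage type and widen to float at the point of use (see the static_cast<float> around mPtrRoutingBias above). The macro's internals are not shown in this diff; the following is a compilable sketch of that dispatch pattern under the assumption that it selects a template instantiation from the runtime bias dtype. All names ending in Sketch are hypothetical.

#include <cuda_bf16.h>
#include <cuda_runtime.h>

enum class Dtype { Bfloat16, Fp32 };  // stand-in for btg::Dtype

// Read the bias in its storage type; widen to float before mixing with scores.
template <typename BiasT>
__global__ void routingBiasAddSketch(const BiasT* bias, float* scores, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) scores[i] += static_cast<float>(bias[i]);
}

// One plausible expansion of the dtype dispatch: branch on the runtime bias
// dtype and launch the matching template instantiation.
void launchRoutingBiasAddSketch(Dtype dtypeBias, const void* bias, float* scores,
                                int n, cudaStream_t stream) {
  int const threads = 256;
  int const blocks = (n + threads - 1) / threads;
  if (dtypeBias == Dtype::Bfloat16) {
    routingBiasAddSketch<__nv_bfloat16><<<blocks, threads, 0, stream>>>(
        static_cast<const __nv_bfloat16*>(bias), scores, n);
  } else {
    routingBiasAddSketch<float><<<blocks, threads, 0, stream>>>(
        static_cast<const float*>(bias), scores, n);
  }
}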
11 changes: 8 additions & 3 deletions csrc/trtllm_fused_moe_runner.cu
@@ -55,13 +55,18 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) {
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
moe::dev::routing::routingDeepSeek::Data routingData;
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
routingData.mDtypeExpW =
btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16
routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32

routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32
routingData.mUsePdl = true;

// output:
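Runner::run gains a dtypeBias parameter immediately after dtypeElt, so every caller must now pass the routing-bias dtype explicitly (the launcher call sites above were updated accordingly). A minimal compilable mock of the widened signature, with unrelated arguments elided and all Sketch names hypothetical:

enum class Dtype { Bfloat16, Fp32 };  // stand-in for btg::Dtype

// Mock of the widened entry point: dtypeBias sits right after dtypeElt, so old
// call sites fail to compile until they pass the routing-bias dtype explicitly.
struct RunnerSketch {
  void run(Dtype dtypeElt, Dtype dtypeBias, bool useRoutingScalesOnInput,
           bool useDeepSeekFp8) { /* routing dispatch elided */ }
};

int main() {
  RunnerSketch runner;
  // fp32 routing bias on the DeepSeek FP8 block-scale path:
  runner.run(Dtype::Bfloat16, Dtype::Fp32,
             /*useRoutingScalesOnInput=*/false, /*useDeepSeekFp8=*/true);
  return 0;
}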