diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu
index 5741611644..e7c64af188 100644
--- a/csrc/trtllm_fused_moe_kernel_launcher.cu
+++ b/csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -67,7 +67,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
   if (routing_bias.has_value()) {
-    TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
+    TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
+                   routing_bias.value()->dtype == dl_float32)
+        << "routing_bias must be bfloat16 or float.";
     TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
     TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
         << "routing_bias has incorrect shape.";
@@ -109,6 +111,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   args.routing_logits = routing_logits->data;
   auto const routing_bias_dtype =
       routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
+  auto btg_routing_bias_dtype = btg::Dtype::Fp32;
+  if (routing_bias_dtype == dl_bfloat16) {
+    btg_routing_bias_dtype = btg::Dtype::Bfloat16;
+  }
   args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
   args.hidden_states = hidden_states->data;
   args.gemm1_weights = gemm1_weights->data;
@@ -140,7 +146,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   Tensor permuted_idx_to_token_idx =
       alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
   Tensor expert_weights =
-      alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
+      alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
   Tensor expert_indexes =
       alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
   Tensor expert_count_histogram = alloc_tensor(
@@ -184,8 +190,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
       static_cast<int32_t*>(num_tokens_per_expert->data),
       static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
       static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, use_routing_scales_on_input,
-      false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
+      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
+      use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
+      static_cast<RoutingMethodType>(routing_method_type), stream);

   // MoE kernel except routing
   TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
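
Note on the hunks above: the launcher maps the dlpack dtype onto btg::Dtype by defaulting to Fp32 and only switching to Bfloat16 on an exact match; any other dtype was already rejected by the TVM_FFI_ICHECK at the top. A minimal stand-alone sketch of that mapping (the enums below are placeholders for the real btg::Dtype and dlpack tags, not the actual types):

    #include <cassert>

    // Placeholder mirrors of btg::Dtype and the dlpack dtype tags used above.
    enum class Dtype { Bfloat16, Fp32 };
    enum class DLDtype { kBfloat16, kFloat32 };

    // Default to Fp32; only an exact bfloat16 match takes the bf16 path.
    Dtype toBtgBiasDtype(DLDtype bias) {
      return bias == DLDtype::kBfloat16 ? Dtype::Bfloat16 : Dtype::Fp32;
    }

    int main() {
      assert(toBtgBiasDtype(DLDtype::kFloat32) == Dtype::Fp32);
      assert(toBtgBiasDtype(DLDtype::kBfloat16) == Dtype::Bfloat16);
      return 0;
    }
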
@@ -366,7 +373,9 @@ void trtllm_fp8_block_scale_moe_launcher(
   auto const routing_bias_dtype =
       routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
-  args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+  auto btg_routing_bias_dtype =
+      routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+
   args.routing_logits = static_cast<float*>(routing_logits->data);
   args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
   args.hidden_states = hidden_states->data;
@@ -398,8 +407,10 @@ void trtllm_fp8_block_scale_moe_launcher(
       alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
   Tensor permuted_idx_to_token_idx =
       alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
+
   Tensor expert_weights =
-      alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
+      alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
+  // NOTE: the output type of the routing kernel is currently always bfloat16
   Tensor expert_indexes =
       alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
   int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
@@ -432,20 +443,21 @@ void trtllm_fp8_block_scale_moe_launcher(
   tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
   cudaStream_t stream = get_stream(routing_logits->device);
-  routing_runner.run(static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
-                     args.num_experts, args.top_k, args.n_group, args.topk_group,
-                     args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
-                     static_cast<int32_t*>(expert_indexes->data),
-                     static_cast<int32_t*>(expert_count_histogram->data),
-                     static_cast<int32_t*>(total_num_padded_tokens->data),
-                     static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
-                     nullptr /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data)*/,
-                     static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
-                     static_cast<int32_t*>(num_tokens_per_expert->data),
-                     static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
-                     static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-                     static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, false, true,
-                     static_cast<RoutingMethodType>(routing_method_type), stream);
+  routing_runner.run(
+      static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
+      args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
+      args.local_num_experts, args.routed_scaling_factor, static_cast<int*>(expert_indexes->data),
+      static_cast<int32_t*>(expert_count_histogram->data),
+      static_cast<int32_t*>(total_num_padded_tokens->data),
+      static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
+      nullptr /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data)*/,
+      static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
+      static_cast<int32_t*>(num_tokens_per_expert->data),
+      static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
+      static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
+      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
+      false /* use_routing_scales_on_input */, true /* use_deep_seek_fp8 */,
+      static_cast<RoutingMethodType>(routing_method_type), stream);

   // MoE kernel except routing
   TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
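
The behavioral point shared by both fp8 launchers above: expert_weights is now always allocated as bfloat16, even when the routing bias is fp32, because the routing kernel's output dtype is fixed while its bias dtype is not. A toy illustration of why the two dtypes must be tracked separately (names hypothetical):

    #include <cstdio>

    enum class Dtype { Bfloat16, Fp32 };

    struct RoutingConfig {
      Dtype biasDtype;  // dtype of the routing bias: bf16 or fp32 after this patch
      Dtype expWDtype;  // dtype of the expert-weights output: still always bf16
    };

    int main() {
      // Before the patch both were derived from routing_bias_dtype, so an fp32
      // bias would also have switched the expert_weights allocation to fp32.
      RoutingConfig cfg{Dtype::Fp32, Dtype::Bfloat16};
      std::printf("bias fp32: %d, expW bf16: %d\n",
                  cfg.biasDtype == Dtype::Fp32, cfg.expWDtype == Dtype::Bfloat16);
      return 0;
    }
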
@@ -671,7 +683,10 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
         << "routing_logits has incorrect shape.";
   }
   if (routing_bias.has_value()) {
-    TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
+    TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
+                   routing_bias.value()->dtype == dl_float32)
+        << "routing_bias must be bfloat16 or float.";
+
     TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
     TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
         << "routing_bias has incorrect shape.";
@@ -714,15 +729,14 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;

   // setup args
-  // note: the assumption is that output data type is always Bfloat16 (the default)
-  auto routing_bias_dtype = dl_bfloat16;
-  if (routing_bias.has_value()) {
-    routing_bias_dtype = routing_bias.value()->dtype;
-  } else if (routing_logits.has_value()) {
-    routing_bias_dtype = routing_logits.value()->dtype;
-  }
   args.mDtypeElt = dtype_act;
-  args.mDtypeExpW = routing_bias_dtype == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
+  // note: the assumption is that the output data type is always Bfloat16 (the default)
+  auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
+  auto btg_routing_bias_dtype =
+      routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+  // We shouldn't use args.mDtypeExpW since it indicates the output data type of the routing
+  // kernel, which is currently always bfloat16 for the routing kernel, while the data type of
+  // the routing bias can now be fp32.
   args.routing_logits = routing_logits.has_value() ? routing_logits.value()->data : nullptr;
   args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
   args.hidden_states = hidden_states->data;
@@ -771,7 +785,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   Tensor permuted_idx_to_token_idx =
       alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states->device);
   // Tensor expert_weights = alloc_tensor(
-  //     {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states->device);
+  //     {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states->device);
   // Tensor expert_indexes = alloc_tensor(
   //     {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
   int constexpr MAX_NUM_EXPERTS = 384;
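
The fp4 launcher now derives the bias dtype exactly as the fp8 paths do. A hedged sketch of the validation rule all three launchers share after this patch (a hypothetical helper, not code from the patch):

    #include <stdexcept>

    enum class DLDtype { kBfloat16, kFloat32, kFloat8E4M3 };

    // Mirrors the TVM_FFI_ICHECK added in all three launchers: the routing
    // bias, when present, must be bfloat16 or float32.
    void checkRoutingBiasDtype(DLDtype dtype) {
      if (dtype != DLDtype::kBfloat16 && dtype != DLDtype::kFloat32) {
        throw std::invalid_argument("routing_bias must be bfloat16 or float.");
      }
    }

    int main() {
      checkRoutingBiasDtype(DLDtype::kFloat32);   // accepted
      checkRoutingBiasDtype(DLDtype::kBfloat16);  // accepted
      return 0;                                   // kFloat8E4M3 would throw
    }
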
@@ -815,21 +829,21 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
   cudaStream_t stream = get_stream(hidden_states->device);
-  routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts,
-                     args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
-                     args.local_num_experts, args.routed_scaling_factor,
-                     static_cast<int32_t*>(expert_indices->data),
-                     static_cast<int32_t*>(expert_count_histogram->data),
-                     static_cast<int32_t*>(total_num_padded_tokens->data),
-                     static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
-                     nullptr, /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data),*/
-                     static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
-                     static_cast<int32_t*>(num_tokens_per_expert->data),
-                     static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
-                     static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-                     static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt,
-                     false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
-                     static_cast<RoutingMethodType>(routing_method_type), stream);
+  routing_runner.run(
+      args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
+      args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
+      args.routed_scaling_factor, static_cast<int32_t*>(expert_indices->data),
+      static_cast<int32_t*>(expert_count_histogram->data),
+      static_cast<int32_t*>(total_num_padded_tokens->data),
+      static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
+      nullptr, /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data),*/
+      static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
+      static_cast<int32_t*>(num_tokens_per_expert->data),
+      static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
+      static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
+      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
+      false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
+      static_cast<RoutingMethodType>(routing_method_type), stream);

   //
   // FC13 (gemm1) + FC2 (gemm2)
diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu
index 4ea1ba178e..426b62408c 100644
--- a/csrc/trtllm_fused_moe_routing_deepseek.cu
+++ b/csrc/trtllm_fused_moe_routing_deepseek.cu
@@ -79,8 +79,8 @@ __global__ void routingMainKernel(KernelParams params) {
     expertSelected = laneIdx < params.mNumExpertsPerGroup;
   }
   auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
-  auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;
-
+  auto biasVal =
+      expertSelected ? static_cast<float>(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat;
   // initialize the mPtrExpertCounts
   if (params.mPtrExpertCounts) {
     int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
@@ -496,24 +496,24 @@ void runImpl(Data& data, void* stream) {
   // Maximum number of tokens supported by the kernel using a cooperative launch.
   int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;
-  LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
-                                 /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
-                                 /*smemSize=*/0,  // No dynamic smem
-                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+  LAUNCH_ROUTING_DEEPSEEK(data,
+                          /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
+                          /*smemSize=*/0,  // No dynamic smem
+                          stream, data.mNumExpertGroups > 1);

   if (data.mPtrPermutedIdxSize != nullptr) {
     if (useSingleCluster) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
-                                     /*coopLaunch=*/false, routingIndicesClusterKernel,
-                                     NumBlocksPerCluster, NumThreads,
-                                     /*smemSize=*/0,  // No dynamic smem
-                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+      LAUNCH_ROUTING_DEEPSEEK(data,
+                              /*coopLaunch=*/false, routingIndicesClusterKernel,
+                              NumBlocksPerCluster, NumThreads,
+                              /*smemSize=*/0,  // No dynamic smem
+                              stream, data.mNumExpertGroups > 1);
    } else if (data.mNumTokens <= maxTokensCoop) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
-                                     /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
-                                     NumThreads,
-                                     /*smemSize=*/0,  // No dynamic smem
-                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+      LAUNCH_ROUTING_DEEPSEEK(data,
+                              /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
+                              NumThreads,
+                              /*smemSize=*/0,  // No dynamic smem
+                              stream, data.mNumExpertGroups > 1);
    } else {
      const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
@@ -528,16 +528,16 @@ void runImpl(Data& data, void* stream) {
       int const numBlocksOffsets =
           std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
-                                     /*coopLaunch=*/false, routingIndicesHistogramKernel,
-                                     numBlocksHistogram, NumThreads,
-                                     /*smemSize=*/0,  // No dynamic smem
-                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
-                                     /*coopLaunch=*/false, routingIndicesOffsetsKernel,
-                                     numBlocksOffsets, NumThreads,
-                                     /*smemSize=*/0,  // No dynamic smem
-                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+      LAUNCH_ROUTING_DEEPSEEK(data,
+                              /*coopLaunch=*/false, routingIndicesHistogramKernel,
+                              numBlocksHistogram, NumThreads,
+                              /*smemSize=*/0,  // No dynamic smem
+                              stream, data.mNumExpertGroups > 1);
+      LAUNCH_ROUTING_DEEPSEEK(data,
+                              /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets,
+                              NumThreads,
+                              /*smemSize=*/0,  // No dynamic smem
+                              stream, data.mNumExpertGroups > 1);
     }
   }
 }
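
In the kernel change above, the bias is read through BiasT and promoted to float before it enters any score arithmetic, so bf16 and fp32 biases share one code path. A stand-alone sketch of that promotion pattern (a stub bf16 type stands in for __nv_bfloat16, and the combine is simplified to an add):

    #include <cstdio>

    // Stub standing in for __nv_bfloat16: stores a float, converts on cast.
    struct bf16 {
      float v;
      explicit operator float() const { return v; }
    };

    // Whatever BiasT is, the math runs in float, with a float sentinel for
    // masked-out experts -- analogous to biasVal/invalidScoreFloat above.
    template <typename BiasT>
    float biasedScore(float score, BiasT bias, bool expertSelected, float invalidScoreFloat) {
      return expertSelected ? score + static_cast<float>(bias) : invalidScoreFloat;
    }

    int main() {
      std::printf("%f\n", biasedScore(0.5f, bf16{0.25f}, true, -1e9f));  // 0.750000
      std::printf("%f\n", biasedScore(0.5f, 0.25f, false, -1e9f));       // -1e9
      return 0;
    }
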
diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu
index 931431fa2f..e1cc5c20c6 100644
--- a/csrc/trtllm_fused_moe_runner.cu
+++ b/csrc/trtllm_fused_moe_runner.cu
@@ -55,13 +55,18 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
                  int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
                  int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
                  int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
-                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
-                 bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) {
+                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
+                 bool useRoutingScalesOnInput, bool useDeepSeekFp8,
+                 RoutingMethodType routingMethodType, cudaStream_t stream) {
   if (routingMethodType == RoutingMethodType::DeepSeekV3) {
     FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
     FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
     moe::dev::routing::routingDeepSeek::Data routingData;
-    routingData.mDtypeExpW = btg::Dtype::Bfloat16;
+    routingData.mDtypeExpW =
+        btg::Dtype::Bfloat16;  // for DeepSeek, the expW is currently always bfloat16
+    routingData.mDtypeBias = dtypeBias;  // for DeepSeek, the bias can be bfloat16 or fp32
+
+    routingData.mDtypeScore = btg::Dtype::Fp32;  // for DeepSeek, the score is currently always fp32
     routingData.mUsePdl = true;

     // output:
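
The runner now tracks three dtypes independently for DeepSeek routing. A minimal mirror of the defaults set above (the real fields live on routingDeepSeek::Data; this stand-alone struct is just for reference):

    #include <cassert>

    enum class Dtype { Bfloat16, Fp32 };

    // Mirror of the three dtype knobs on routingDeepSeek::Data after this patch.
    struct RoutingDtypes {
      Dtype expW{Dtype::Bfloat16};  // expert-weights output: currently always bf16
      Dtype bias{Dtype::Bfloat16};  // routing bias: bf16 or fp32, set from dtypeBias
      Dtype score{Dtype::Fp32};     // scores: currently always fp32
    };

    int main() {
      RoutingDtypes d;
      d.bias = Dtype::Fp32;  // an fp32 bias no longer forces anything else to fp32
      assert(d.expW == Dtype::Bfloat16 && d.score == Dtype::Fp32);
      return 0;
    }
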
diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h
index 0eb426b0e0..ff701e4d72 100644
--- a/include/flashinfer/trtllm/fused_moe/DevKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h
@@ -123,6 +123,55 @@ namespace moe::dev {
     FLASHINFER_WARN("Unsupported dtypeExpW");                                                     \
   }

+#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads,  \
+                                                smemSize, stream, extraFlag)                      \
+  if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 &&                \
+      data.mDtypeExpW == tg::Dtype::Fp32) {                                                       \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, float, extraFlag), kernel, numBlocks,   \
+               numThreads, smemSize, stream);                                                     \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 &&         \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) {                                            \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, extraFlag), kernel,      \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 &&     \
+             data.mDtypeExpW == tg::Dtype::Fp32) {                                                \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, extraFlag), kernel,      \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 &&     \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) {                                            \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, extraFlag),      \
+               kernel, numBlocks, numThreads, smemSize, stream);                                  \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 &&     \
+             data.mDtypeExpW == tg::Dtype::Fp32) {                                                \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, extraFlag), kernel,      \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 &&     \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) {                                            \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, extraFlag),      \
+               kernel, numBlocks, numThreads, smemSize, stream);                                  \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Fp32) {                                                \
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, extraFlag),      \
+               kernel, numBlocks, numThreads, smemSize, stream);                                  \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) {                                            \
+    LAUNCH_PDL(data, coopLaunch,                                                                  \
+               LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, extraFlag), kernel,        \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else {                                                                                        \
+    FLASHINFER_WARN("Unsupported dtype combination");                                             \
+  }
+
+#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \
+                                extraFlag)                                                         \
+  if (extraFlag) {                                                                                 \
+    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads,       \
+                                            smemSize, stream, true);                               \
+  } else {                                                                                         \
+    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads,       \
+                                            smemSize, stream, false);                              \
+  }
+
 #define LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
                                        stream, extraFlag, forceFloatInput)                        \
   if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) {                                          \
diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
index e182fc114d..cc000ed6d6 100644
--- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
@@ -146,7 +146,8 @@ namespace routingDeepSeek {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 struct Data : public DataBase {
   tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
-
+  tg::Dtype mDtypeBias{tg::Dtype::Bfloat16};
+  tg::Dtype mDtypeScore{tg::Dtype::Fp32};
   //
   // Grouped Gemm Launch Config Buffers
   //
@@ -160,9 +161,10 @@ struct Data : public DataBase {
   bool mUseRoutingSoftmax;
 };

-template <typename InputT_, typename OutputT_, bool UseGroups_>
+template <typename InputT_, typename BiasT_, typename OutputT_, bool UseGroups_>
 struct KernelParams : public KernelParamsBase {
   using InputT = InputT_;
+  using BiasT = BiasT_;
   using OutputT = OutputT_;
   static constexpr bool UseGroups = UseGroups_;
@@ -173,7 +175,7 @@ struct KernelParams : public KernelParamsBase {
   // Note: this variable(mPtrExpertWeightsFull) might need to be added back for the low-latency
   // kernels for MoE in tllm-gen in the future
-  OutputT const* mPtrRoutingBias = nullptr;
+  BiasT const* mPtrRoutingBias = nullptr;

   int32_t mNumExpertGroups = 0;
   int32_t mNumExpertsPerGroup = 0;
@@ -189,7 +191,7 @@ struct KernelParams : public KernelParamsBase {
     params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*)data.mPtrExpertIdx;
     // params.mPtrExpertWeightsFull = static_cast<OutputT*>(data.mPtrExpertWeightsFull);
-    params.mPtrRoutingBias = static_cast<OutputT const*>(data.mPtrRoutingBias);
+    params.mPtrRoutingBias = static_cast<BiasT const*>(data.mPtrRoutingBias);

     params.mNumExpertGroups = data.mNumExpertGroups;
     params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups;
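
LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG above is an exhaustive runtime-to-compile-time dispatch over the (score, bias, expW) dtype triple. A function-based sketch of the same idiom, reduced to stubs (the real code uses a macro because each branch launches a templated __global__ kernel):

    #include <cstdio>

    enum class Dtype { Bfloat16, Fp32 };
    struct bf16 { unsigned short bits; };  // stand-in for __nv_bfloat16

    // Stand-in for the templated kernel launch; just reports its instantiation.
    template <typename ScoreT, typename BiasT, typename ExpWT>
    void launchKernel() {
      std::printf("kernel<%zu, %zu, %zu> launched\n", sizeof(ScoreT), sizeof(BiasT), sizeof(ExpWT));
    }

    // Nested runtime->compile-time dispatch, one axis per level -- the same job
    // the eight-branch macro performs in DevKernel.h.
    template <typename ScoreT, typename BiasT>
    void dispatchExpW(Dtype expW) {
      if (expW == Dtype::Fp32) launchKernel<ScoreT, BiasT, float>();
      else                     launchKernel<ScoreT, BiasT, bf16>();
    }

    template <typename ScoreT>
    void dispatchBias(Dtype bias, Dtype expW) {
      if (bias == Dtype::Fp32) dispatchExpW<ScoreT, float>(expW);
      else                     dispatchExpW<ScoreT, bf16>(expW);
    }

    void dispatch(Dtype score, Dtype bias, Dtype expW) {
      if (score == Dtype::Fp32) dispatchBias<float>(bias, expW);
      else                      dispatchBias<bf16>(bias, expW);
    }

    int main() {
      dispatch(Dtype::Fp32, Dtype::Bfloat16, Dtype::Bfloat16);  // the DeepSeek default
      return 0;
    }
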
diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h
index 5f066468e6..bff48fdb3c 100644
--- a/include/flashinfer/trtllm/fused_moe/runner.h
+++ b/include/flashinfer/trtllm/fused_moe/runner.h
@@ -113,8 +113,9 @@ class Runner {
            int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
            int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
            int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
-           batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput,
-           bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);
+           batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
+           bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
+           cudaStream_t stream);

  private:
   int32_t mTileTokensDim{8};
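
One consequence of the header change worth noting: dtypeBias lands between dtypeElt and the two bools, and since btg::Dtype is (to my understanding) a scoped enum, any stale pre-patch call site fails to compile rather than silently reinterpreting arguments. A tiny stand-alone illustration of that property (signatures abbreviated and hypothetical):

    #include <cstdio>

    enum class Dtype { Bfloat16, Fp32 };  // stand-in for btg::Dtype (a scoped enum)

    // Abbreviated stand-in for the new run() parameter tail.
    void run(Dtype dtypeElt, Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8) {
      std::printf("deep-seek fp8: %d\n", useDeepSeekFp8);
    }

    int main() {
      // A stale pre-patch call such as run(Dtype::Fp32, false, true) does not
      // compile: a scoped enum is neither convertible to nor from bool.
      run(Dtype::Fp32, Dtype::Bfloat16, false, true);
      return 0;
    }
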