diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu
index d0fb0a197..5625ad93b 100644
--- a/src/fastertransformer/kernels/unfused_attention_kernels.cu
+++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu
@@ -660,7 +660,10 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream
     bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
     dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
 
-    if (block.x > 2048 && block.x <= 4096) {
+    if (block.x > 4096 && block.x <= 8192) {
+        LAUNCH_MAKSED_SOFTMAX(8);
+    }
+    else if (block.x > 2048 && block.x <= 4096) {
         LAUNCH_MAKSED_SOFTMAX(4)
     }
     else if (block.x > 1024) {
@@ -697,8 +700,11 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, float>& param, cudaSt
     bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
     dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
-
-    if (block.x > 2048 && block.x <= 4096) {
+
+    if (block.x > 4096 && block.x <= 8192) {
+        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8);
+    }
+    else if (block.x > 2048 && block.x <= 4096) {
         LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
     }
     else if (block.x > 1024) {
@@ -730,7 +736,10 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param
     bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
     dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
 
-    if (block.x > 2048 && block.x <= 4096) {
+    if (block.x > 4096 && block.x <= 8192) {
+        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8);
+    }
+    else if (block.x > 2048 && block.x <= 4096) {
         LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
     }
     else if (block.x > 1024) {
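
For reviewers, here is a minimal, self-contained sketch of the dispatch pattern this patch extends, not the FasterTransformer kernel itself: CUDA caps a block at 1024 threads, so once the padded `k_length` exceeds 4096, the launcher must fold 8 elements into each thread (`ITEMS_PER_THREAD == 8`) rather than fall through with no matching branch. The names `toy_softmax` and `dispatch_toy_softmax` below are illustrative stand-ins; the packed half2/bfloat162 path selected by `is_half2` and the optimized reductions of the real kernels behind `LAUNCH_MAKSED_SOFTMAX` are omitted.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Simplified row-wise masked softmax: one block per row, each thread owns
// ITEMS_PER_THREAD elements strided across the row.
template<int ITEMS_PER_THREAD>
__global__ void toy_softmax(float* data, int k_length)
{
    float* row = data + blockIdx.x * k_length;

    float local[ITEMS_PER_THREAD];
    float local_max = -INFINITY;
    for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
        int idx   = threadIdx.x + i * blockDim.x;
        local[i]  = (idx < k_length) ? row[idx] : -INFINITY;
        local_max = fmaxf(local_max, local[i]);
    }

    // Block-wide max and sum via shared memory (serial reduction kept simple).
    __shared__ float s_buf[1024];
    __shared__ float s_max, s_sum;
    s_buf[threadIdx.x] = local_max;
    __syncthreads();
    if (threadIdx.x == 0) {
        float m = -INFINITY;
        for (int i = 0; i < blockDim.x; ++i) m = fmaxf(m, s_buf[i]);
        s_max = m;
    }
    __syncthreads();

    float local_sum = 0.f;
    for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
        int idx  = threadIdx.x + i * blockDim.x;
        local[i] = (idx < k_length) ? expf(local[i] - s_max) : 0.f;
        local_sum += local[i];
    }
    s_buf[threadIdx.x] = local_sum;
    __syncthreads();
    if (threadIdx.x == 0) {
        float s = 0.f;
        for (int i = 0; i < blockDim.x; ++i) s += s_buf[i];
        s_sum = s;
    }
    __syncthreads();

    for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
        int idx = threadIdx.x + i * blockDim.x;
        if (idx < k_length) row[idx] = local[i] / s_sum;
    }
}

// Host-side dispatch mirroring the branch structure of the patched
// invokeMaskedSoftmax: pad k_length to a multiple of 32, then divide the
// thread count by ITEMS_PER_THREAD so it never exceeds the 1024-thread cap.
void dispatch_toy_softmax(float* d_data, int rows, int k_length, cudaStream_t stream)
{
    dim3 grid(rows);
    dim3 block((k_length + 31) / 32 * 32);
    if (block.x > 4096 && block.x <= 8192) {         // branch added by this patch
        block.x /= 8;
        toy_softmax<8><<<grid, block, 0, stream>>>(d_data, k_length);
    } else if (block.x > 2048 && block.x <= 4096) {
        block.x /= 4;
        toy_softmax<4><<<grid, block, 0, stream>>>(d_data, k_length);
    } else if (block.x > 1024) {
        block.x /= 2;
        toy_softmax<2><<<grid, block, 0, stream>>>(d_data, k_length);
    } else {
        toy_softmax<1><<<grid, block, 0, stream>>>(d_data, k_length);
    }
}
```

With this structure, the new `4096 < block.x <= 8192` branch keeps the launched thread count at `block.x / 8 <= 1024`, which is why sequence lengths up to 8192 become dispatchable without touching the per-branch kernels below it.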