From 00b1eba645a11b47243fd8343476e8169ce27af0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 30 Aug 2023 06:36:05 +0000 Subject: [PATCH 1/2] seq len supported up to 8K --- .../kernels/unfused_attention_kernels.cu | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d0fb0a197..f24328054 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -660,7 +660,10 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0; dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32); - if (block.x > 2048 && block.x <= 4096) { + if (block.x > 4096 && block.x <= 8192) { + LAUNCH_MAKSED_SOFTMAX_(8); + } + else if (block.x > 2048 && block.x <= 4096) { LAUNCH_MAKSED_SOFTMAX(4) } else if (block.x > 1024) { @@ -697,8 +700,11 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, float>& param, cudaSt bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0; dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32); - - if (block.x > 2048 && block.x <= 4096) { + + if (block.x > 4096 && block.x <= 8192) { + LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8); + } + else if (block.x > 2048 && block.x <= 4096) { LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4); } else if (block.x > 1024) { @@ -730,7 +736,10 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0; dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32); - if (block.x > 2048 && block.x <= 4096) { + if (block.x > 4096 && block.x <= 8192) { + LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8); + } + else if (block.x > 2048 && block.x <= 4096) { LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4); } else if (block.x > 1024) { From 4a2c57aa93054cdb294082251fe31144e750af7a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 30 Aug 2023 06:39:15 +0000 Subject: [PATCH 2/2] fix typo --- src/fastertransformer/kernels/unfused_attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index f24328054..5625ad93b 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -661,7 +661,7 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32); if (block.x > 4096 && block.x <= 8192) { - LAUNCH_MAKSED_SOFTMAX_(8); + LAUNCH_MAKSED_SOFTMAX(8); } else if (block.x > 2048 && block.x <= 4096) { LAUNCH_MAKSED_SOFTMAX(4)