diff --git a/src/fastertransformer/models/gptj/GptJ.cc b/src/fastertransformer/models/gptj/GptJ.cc
index 0382d8863..34d79a725 100644
--- a/src/fastertransformer/models/gptj/GptJ.cc
+++ b/src/fastertransformer/models/gptj/GptJ.cc
@@ -781,6 +781,6 @@ void GptJ::forward(std::unordered_map* output_tens
         invokeMaskPaddingTokens(masked_tokens_,
                                 input_tensors->at("input_lengths").getPtr(),  // not_tiled
-                                tiled_prompt_lengths_buf_,
+                                has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : (const int*)nullptr,
                                 max_cache_seq_len,
                                 max_input_length + max_prefix_prompt_length,
                                 0,
diff --git a/src/fastertransformer/models/gptneox/GptNeoX.cc b/src/fastertransformer/models/gptneox/GptNeoX.cc
index 2ce2dae7b..2bbc1390f 100644
--- a/src/fastertransformer/models/gptneox/GptNeoX.cc
+++ b/src/fastertransformer/models/gptneox/GptNeoX.cc
@@ -757,6 +757,6 @@ void GptNeoX::forward(std::unordered_map* output_t
         invokeMaskPaddingTokens(masked_tokens_,
                                 input_tensors->at("input_lengths").getPtr(),  // not_tiled
-                                tiled_prompt_lengths_buf_,
+                                has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : (const int*)nullptr,
                                 max_cache_seq_len,
                                 max_input_length + max_prefix_prompt_length,
                                 0,