diff --git a/src/fastertransformer/utils/cuda_fp8_utils.h b/src/fastertransformer/utils/cuda_fp8_utils.h index 5b681171c..6b3b737c2 100644 --- a/src/fastertransformer/utils/cuda_fp8_utils.h +++ b/src/fastertransformer/utils/cuda_fp8_utils.h @@ -20,6 +20,7 @@ #include #include #include +#include // #define FP8_MHA #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 diff --git a/src/fastertransformer/utils/cuda_utils.h b/src/fastertransformer/utils/cuda_utils.h index 331b6f297..b8a02ab5f 100644 --- a/src/fastertransformer/utils/cuda_utils.h +++ b/src/fastertransformer/utils/cuda_utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #ifdef SPARSITY_ENABLED #include #endif