diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu
index c812739c76..003a23a5f6 100644
--- a/csrc/flashinfer_xqa_binding.cu
+++ b/csrc/flashinfer_xqa_binding.cu
@@ -16,12 +16,32 @@
 #include "tvm_ffi_utils.h"
 
-void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
-                 double qScale, TensorView output,
+#if MLA_WRAPPER
+void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
+#if PAGED_KV_CACHE_LAYOUT == 1
+                     TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+                     TensorView pool,
+#endif
+                     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
+                     int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
+                     TensorView scratch);
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);
+
+#else
+
+void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
+                 int64_t slidingWinSize, double qScale, TensorView output,
 #if LOW_PREC_OUTPUT
                  TensorView rcpOutScale,
 #endif
-                 TensorView q, TensorView attentionSinks, TensorView pool,
+                 TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+                 TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+                 TensorView pool,
+#endif
                  TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
                  int64_t batchSize, TensorView kvCacheScale,
 #if SPEC_DEC
@@ -30,3 +50,5 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW
                  TensorView semaphores, TensorView scratch);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);
+
+#endif
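For context on the binding change above: the non-MLA wrapper now takes a leading run_sm90_fp8_mha switch, accepts attentionSinks as an optional tensor, and receives the paged KV cache either as separate K/V views (PAGED_KV_CACHE_LAYOUT == 1, the vLLM layout) or as a single pool tensor. A minimal caller-side sketch of how those arguments group together, using std::optional and a placeholder TensorRef in place of the real TensorView (both stand-ins are illustrative assumptions, not part of this patch):

#include <optional>

// Illustration only: TensorRef stands in for the binding's TensorView type.
struct TensorRef { void* data = nullptr; };

// Arguments that vary with the KV-cache layout selected at compile time, plus the
// now-optional attention-sink tensor.
struct XqaKvArgs {
#if PAGED_KV_CACHE_LAYOUT == 1
  TensorRef kCacheVLLM;  // separate key pages (vLLM layout)
  TensorRef vCacheVLLM;  // separate value pages
#else
  TensorRef pool;        // single pooled KV-cache tensor
#endif
  std::optional<TensorRef> attentionSinks;  // std::nullopt when no sink tensor is provided
};

In the actual export the optional sink tensor is passed right after q, followed by the cache view(s) selected by the same preprocessor switch.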
diff --git a/csrc/xqa/gmma.cuh b/csrc/xqa/gmma.cuh
new file mode 100644
index 0000000000..d1b2547fcd
--- /dev/null
+++ b/csrc/xqa/gmma.cuh
@@ -0,0 +1,145 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#include "utils.cuh"
+#ifndef __CUDACC__
+#include
+#endif
+#include
+#include
+
+namespace gmma {
+
+enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 };
+
+struct MatDesc {
+  uint64_t addr : 16;
+  uint64_t dimKOffset : 16;
+  uint64_t dimMNOffset : 16;
+  uint64_t pad0 : 1;
+  uint64_t baseOffset : 3;
+  uint64_t pad1 : 10;
+  SwizzleMode swizzle : 2;
+
+  enum class Raw : uint64_t {};
+
+  [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const {
+    MatDesc ret = *this;
+    ret.addr = encode(__cvta_generic_to_shared(data));
+    return ret;
+  }
+
+  static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; }
+
+  __device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); }
+
+  __device__ inline Raw const& raw() const {
+    static_assert(sizeof(MatDesc) == 8);
+    return reinterpret_cast<Raw const&>(*this);
+  }
+
+  static __device__ inline MatDesc fromRaw(Raw const& raw) {
+    return reinterpret_cast<MatDesc const&>(raw);
+  }
+};
+
+static_assert(sizeof(MatDesc) == 8);
+
+[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) {
+  assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
+  MatDesc::Raw ret = base;
+  auto& u32x2 = reinterpret_cast<uint32_t (&)[2]>(ret);
+  u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
+  return ret;
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, void const* patternStartAddr,
+                                      SwizzleMode swizzleMode) {
+  uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
+  uint32_t const baseAlign = [&]() -> uint32_t {
+    switch (swizzleMode) {
+      case SwizzleMode::kNONE:
+        return 1;
+      case SwizzleMode::k128:
+        return 1024;
+      case SwizzleMode::k64:
+        return 512;
+      case SwizzleMode::k32:
+        return 256;
+    }
+    asm volatile("trap;\n");
+    return 0;
+  }();
+  uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
+  return MatDesc{
+      /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
+      /*dimKOffset=*/MatDesc::encode(dimKByteOffset),
+      /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
+      /*pad0=*/0,
+      /*baseOffset=*/baseOffset,
+      /*pad1=*/0,
+      /*swizzle=*/swizzleMode,
+  };
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, SwizzleMode swizzleMode) {
+  return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
+}
+
+inline constexpr uint32_t instM = 64;
+
+template <typename MathElem>
+inline constexpr uint32_t instK = 32 / sizeof(MathElem);
+
+inline constexpr uint32_t instNBase = 8;
+
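The MatDesc bit-field above models the 64-bit shared-memory matrix descriptor that the wgmma.mma_async specializations in gmma_impl.cuh take for their A and B operands: encode() drops the low 4 bits and keeps a 14-bit count of 16-byte units, and makeMatDesc() fills the start address, the two byte offsets (dimKByteOffset, dimMNByteOffset), the base offset derived from the swizzle pattern start, and the swizzle mode. Below is a minimal host-side sketch of the same packing written with explicit shifts rather than bit-fields; the helper names (encode14, packMatDesc) and the LSB-first field order are assumptions for illustration, not part of the patch.

#include <cstdint>

// Host-side sketch of the descriptor packing, assuming the usual little-endian,
// LSB-first bit-field layout for MatDesc. encode14() keeps bits [4, 18) of a
// shared-memory byte address/offset, i.e. a 14-bit count of 16-byte units.
constexpr uint64_t encode14(uint32_t bytes) { return (bytes & 0x3FFFFu) >> 4; }

constexpr uint64_t packMatDesc(uint32_t smemAddrBytes, uint32_t dimKBytes, uint32_t dimMNBytes,
                               uint32_t baseOffset, uint64_t swizzleMode) {
  return encode14(smemAddrBytes)        // bits  0..15 : addr
         | encode14(dimKBytes) << 16    // bits 16..31 : dimKOffset
         | encode14(dimMNBytes) << 32   // bits 32..47 : dimMNOffset
         | uint64_t(baseOffset) << 49   // bits 49..51 : baseOffset (bit 48 is pad0)
         | swizzleMode << 62;           // bits 62..63 : SwizzleMode
}

static_assert(packMatDesc(0, 0, 0, 0, 1) == (uint64_t{1} << 62),
              "swizzle mode occupies the two most significant bits");

On a little-endian target this should match the value that MatDesc::raw() reinterprets from the struct.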
+// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
+// acc is used as both input and output.
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA,
+                               MatDesc::Raw descB, bool accHasVal);
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2],
+                               uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);
+
+__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); }
+
+__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); }
+
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void wait_group() {
+  asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups));
+}
+
+template
+constexpr SwizzleMode getSwizzleMode(Array2D const&) {
+  constexpr auto rowBytes = Array2D::rowBytes;
+  if constexpr (!swizzle) {
+    return SwizzleMode::kNONE;
+  }
+  if constexpr (rowBytes % 128 == 0) {
+    return SwizzleMode::k128;
+  } else if constexpr (rowBytes == 64) {
+    return SwizzleMode::k64;
+  } else {
+    static_assert(rowBytes == 32);
+    return SwizzleMode::k32;
+  }
+}
+} // namespace gmma
+
+#include "gmma_impl.cuh"
diff --git a/csrc/xqa/gmma_impl.cuh b/csrc/xqa/gmma_impl.cuh
new file mode 100644
index 0000000000..b9515ddea9
--- /dev/null
+++ b/csrc/xqa/gmma_impl.cuh
@@ -0,0 +1,4971 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#ifndef __CUDACC__
+#include
+#endif
+#include
+#include
+
+namespace gmma {
+// cog template.
Do code generation with: pip install cogapp; cog -r $filename + +// clang-format off +/*[[[cog +import cog +reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)]) +acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2) +acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)]) +ptx_eol = "\\n" +n_list = [8, 16, 24, 32, 64, 128, 256] +for n in n_list: + cog.outl(f''' +template<> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA, +MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} + +template<> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], uint32_t +const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') + +for n in n_list: + for transA in [0, 1]: + for transB in [0, 1]: + for t,s in [('half', 'f16'), ('__nv_bfloat16', 'bf16')]: + cog.outl(f''' +template<> +__device__ inline void mma_async_shmA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA, +MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') + if transA == 0: + cog.outl(f''' +template<> +__device__ inline void mma_async_regA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], uint32_t +const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) 
{{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') +]]]*/ +// clang-format on + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + 
bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, 
%7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" 
// b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), 
"r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), 
"+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + 
"+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + 
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, 
%116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + 
"%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), 
"r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : 
"l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + 
asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 
1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>(float 
(&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : 
"l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), 
"+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, 
%13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + 
MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), 
"+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), 
"+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), 
"+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), 
"+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), 
"+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 
1>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), 
"+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), 
"+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, 
%20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), 
"+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), 
"+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, 
%45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), 
"+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + 
"+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), 
"+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, 
%36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), 
"n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), 
"+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + 
"+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + 
"+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), 
"+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + 
"+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + 
"+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), 
"+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), 
"+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, 
%91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, 
%69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, 
%17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), 
"n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + 
"+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + 
"+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), 
"+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), 
"+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), 
"+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), 
"+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), 
"+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + 
"+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), 
"+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + 
"+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, 
%124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +//[[[end]]] +} // namespace gmma diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index c896017780..4e276bf0b4 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -92,7 +92,7 @@ __constant__ constexpr uint32_t cacheVTileSeqLen = 32; #if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 constexpr uint32_t preferedKHeadPartBytes = 64; __constant__ constexpr uint32_t cacheVTileSeqLen = 32; -#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 +#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 constexpr uint32_t preferedKHeadPartBytes = 128; __constant__ 
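The m64n256k16 specializations above bind 128 "+f" accumulator operands per thread, which is where the float[32][2][2] shape of acc comes from: N/8 = 32 eight-column core blocks, each contributing a 2x2 patch per thread across the 128-thread warp group. The trailing immediate bound to "n"(true)/"n"(false) is wgmma's scale-d flag, which is how accHasVal switches between accumulate and overwrite. Below is a minimal index-mapping sketch assuming the standard PTX wgmma f32 accumulator layout; the helper name and parameters are illustrative and not part of gmma.cuh.

// Illustrative only: where acc[n][i][j] of the m64n256k16 f32 accumulator lands in the
// 64x256 output tile, assuming the documented PTX wgmma layout. warpRank is the warp's
// index within the warp group (0..3), lane is the lane id (0..31).
__device__ inline void accCoordSketch(uint32_t warpRank, uint32_t lane, uint32_t n, uint32_t i,
                                      uint32_t j, uint32_t& row, uint32_t& col) {
  row = 16 * warpRank + 8 * i + lane / 4;  // each warp owns 16 consecutive rows; i picks the 8-row half
  col = 8 * n + 2 * (lane % 4) + j;        // n picks an 8-column core block; j the adjacent column pair
}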
constexpr uint32_t cacheVTileSeqLen = 64; #else @@ -476,7 +476,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy col + actualQSeqLen < nbValidCols ? true : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart)); - acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY; + acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax; } } } @@ -2659,7 +2659,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if LOW_PREC_OUTPUT float const* rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, GMemCacheHead* pool, + InputHead const* q, float const* attentionSinks, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, +#endif KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, @@ -2691,7 +2696,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); #if USE_PAGED_KV_CACHE uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; +#else KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; +#endif cudaLaunchKernelEx(&launchCfg, kernel_mha, #if SPEC_DEC qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, @@ -2709,11 +2719,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if SPEC_DEC mask, #endif - attentionSinks, cacheList, -#if BEAM_WIDTH > 1 - beamSearchParams, -#endif - batchSize, kvCacheScale, semaphores, scratch); + attentionSinks, cacheList, batchSize, kvCacheScale, semaphores, scratch); #else KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; #ifndef NDEBUG diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index 77d8a2fd2f..d50c081b6a 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -135,7 +135,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if LOW_PREC_OUTPUT float const* rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, GMemCacheHead* pool, + InputHead const* q, float const* attentionSinks, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, +#endif KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, @@ -186,15 +191,39 @@ void launchHopperF8MHA( #endif uint32_t* semaphores, void* scratch, cudaStream_t stream); +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, + uint32_t slidingWinSize, float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + InputHead const* q, float const* attentionSinks, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, +#endif + KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, +#if SPEC_DEC + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream); + void launchMLA( cudaDeviceProp const& prop, uint32_t 
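The launcher and header changes above thread two paged-KV layouts through the same entry points: layout 0 keeps a single page pool with a page list shaped KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (K and V split on the "2" axis), while layout 1 (PAGED_KV_CACHE_LAYOUT == 1, the vLLM-style path) takes separate kCacheVLLM/vCacheVLLM base pointers and a flat [batchSize][maxNbPagesPerSeq] page list. A hedged sketch of the resulting base-pointer choice follows; the helper, its parameters, and the page stride are assumptions for illustration, not the kernels' actual indexing.

#include <cstdint>
// Sketch only: choose the pool and page base for one KV page under the two layouts.
// Head stands in for GMemCacheHead; a page is assumed to hold [nbKHeads][tokensPerPage] heads.
template <typename Head>
__host__ __device__ inline Head* pageBaseSketch(bool isK, bool layoutIsVLLM, Head* pool,
                                                Head* kCacheVLLM, Head* vCacheVLLM,
                                                int64_t pageIdx, uint32_t nbKHeads,
                                                uint32_t tokensPerPage) {
  Head* base = layoutIsVLLM ? (isK ? kCacheVLLM : vCacheVLLM)  // layout 1: separate K/V pools
                            : pool;                            // layout 0: one pool, K/V via the page list
  return base + pageIdx * int64_t(nbKHeads) * tokensPerPage;
}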
inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, #if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else GMemCacheHead* pool, // global pool of pages +#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: - // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or + // [batchSize][maxNbPagesPerSeq] (Layout 1) #else GMemKVCacheHead* kvCacheData, #endif @@ -203,6 +232,24 @@ void launchMLA( // Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, cudaStream_t stream); +void launchMLAFlashInfer( + uint32_t multiProcessorCount, + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + float qScale, OutputHead* output, InputHead const* q, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or + // [batchSize][maxNbPagesPerSeq] (Layout 1) + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. + uint32_t* semaphores, void* scratch, cudaStream_t stream); + #if STATIC_NB_K_HEADS constexpr uint32_t nbKHeads = NB_K_HEADS; diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu new file mode 100644 index 0000000000..acbda1fa88 --- /dev/null +++ b/csrc/xqa/mha_sm90.cu @@ -0,0 +1,3276 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include "cuda_hint.cuh" +#include "defines.h" +#if !(IS_MLA) +#include "barriers.cuh" +#include "utils.cuh" +#include "utils.h" + +#if SPEC_DEC +#define Q_HEADS_PER_CTA 64 +#include "specDec.h" +#endif + +#ifndef GENERATE_CUBIN +#include + +#include "hostUtils.h" +#include "tensorMap.h" +#endif +#include "gmma.cuh" +#include "mha.h" +#include "mhaUtils.cuh" +#include "mha_stdheaders.cuh" +#include "tma.h" + +#define DBG_PRINT 0 + +#ifdef SPEC_Q_SEQ_LEN +static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN is only supported for SPEC_DEC"); +constexpr uint32_t specDecQLen = SPEC_Q_SEQ_LEN; +static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is too large"); +#define SWAP_AB 1 +#else +#define SWAP_AB (!SPEC_DEC) +#endif + +#define IS_SUPPORTED_F16_CASE \ + (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT) + +inline constexpr bool swapAB = SWAP_AB; + +#pragma region Config + +static_assert((inputElemSize == cacheElemSize && mha::is_same_v) || + inputElemSize > cacheElemSize); +using MathElem = + mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v), + InputElem, CacheElem>; + +constexpr uint32_t gmmaWarpsPerGrp = 4; +constexpr uint32_t gmmaWarpGrpSize = warp_size * gmmaWarpsPerGrp; +constexpr uint32_t gemm0NbGmmaGrps = 1; +constexpr uint32_t gemm0NbThrds = gmmaWarpGrpSize * gemm0NbGmmaGrps; +constexpr uint32_t gemm0NbWarps = gmmaWarpsPerGrp * gemm0NbGmmaGrps; +#if SPEC_DEC && !SWAP_AB +inline constexpr uint32_t ctaNbQHeads = Q_HEADS_PER_CTA; +inline constexpr uint32_t inputTokensPerCta = ctaNbQHeads / headGrpSize; +constexpr uint32_t ctaNbValidQHeads = ctaNbQHeads; +#elif SPEC_DEC && SWAP_AB +inline constexpr uint32_t inputTokensPerCta = specDecQLen; +inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * inputTokensPerCta; +inline constexpr uint32_t ctaNbQHeads = []() { + static_assert(ctaNbValidQHeads <= 32, "ctaNbValidQHeads cannot exceed 32"); + if constexpr (ctaNbValidQHeads <= 8) { + return 8; + } + if constexpr (ctaNbValidQHeads <= 16) { + return 16; + } + return 32; +}(); +#else +inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * beamWidth; +inline constexpr uint32_t ctaNbQHeads = roundUp(ctaNbValidQHeads, swapAB ? 
8U : 64U); +inline constexpr uint32_t inputTokensPerCta = 1; +#endif +constexpr uint32_t gemm0WarpGrpTileNbTokens = 64; +inline constexpr uint32_t gemm0CtaTileNbTokens = gemm0WarpGrpTileNbTokens * gemm0NbGmmaGrps; +constexpr uint32_t gemm1NbGmmaGrps = 1; +constexpr uint32_t gemm1NbThrds = gmmaWarpGrpSize * gemm1NbGmmaGrps; +constexpr uint32_t gemm1NbWarps = gmmaWarpsPerGrp * gemm1NbGmmaGrps; +constexpr uint32_t gemm1CtaTileNbTokens = gemm0CtaTileNbTokens; +constexpr uint32_t mathHeadBytes = sizeof(Vec); +constexpr uint32_t nbIOWarps = 4; +constexpr uint32_t nbIOThrds = warp_size * nbIOWarps; +constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning +constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; +constexpr uint32_t nbWarps = gemm0NbWarps + gemm1NbWarps + nbIOWarps; + +constexpr uint32_t cacheHeadPartBytes = mha::min(paddedCacheHeadBytes, 128U); +constexpr uint32_t cacheHeadNbParts = + exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future +constexpr uint32_t cacheHeadPartElems = exactDiv(headElems, cacheHeadNbParts); +constexpr uint32_t swizzleBytes = cacheHeadPartBytes; +static_assert(swizzleBytes == 128 || swizzleBytes == 64 || swizzleBytes == 32); + +constexpr bool needInputCvt = + inputElemSize > cacheElemSize&& mha::is_same_v; +constexpr bool needCacheCvt = inputElemSize > cacheElemSize&& mha::is_same_v; +static_assert(needInputCvt || needCacheCvt || mha::is_same_v); + +using ShmQWiseVec = Vec; + +constexpr uint32_t qPartBytes = mha::min(mathHeadBytes, 128U); +constexpr uint32_t nbQParts = exactDiv(mathHeadBytes, qPartBytes); +constexpr uint32_t grainsPerQPart = exactDiv(qPartBytes, grainBytes); + +constexpr uint32_t xPartBytes = mha::min(cacheElemSize * gemm0CtaTileNbTokens, 128U); +constexpr uint32_t nbXParts = exactDiv(cacheElemSize * gemm0CtaTileNbTokens, xPartBytes); +constexpr uint32_t grainsPerXPart = exactDiv(xPartBytes, grainBytes); +constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize); + +constexpr uint32_t grainsPerIOHead = exactDiv(ioHeadBytes, grainBytes); +constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); + +#if USE_BEAM_SEARCH +constexpr uint32_t beamSearchGemm0CtaTileNbTokens = exactDiv(gemm0CtaTileNbTokens, beamWidth); +#endif + +using PaddedOutHead = PaddedInputHead; + +#pragma endregion Config + +struct alignas(128) SharedMem { + using KBuffer = Array2D; + static constexpr uint32_t nbKBuf = 2; + KBuffer k[nbKBuf]; // as is loaded from global mem. + using XBuffer = Vec, nbXParts>; + static constexpr uint32_t nbXBuf = + 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens + ? 
1 + : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); + using VBuffer = + Vec, + cacheHeadNbParts>; +#if !SWAP_AB + using VTBuffer = + Array2D; +#endif + static constexpr uint32_t nbVBuf = 2; +#if CACHE_ELEM_ENUM == 0 + using OutSwizzleBuf = Array2D; +#elif CACHE_ELEM_ENUM == 2 + using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; +#endif + static_assert(nbXBuf == nbVBuf); + + union ReusedXVOutSwizzleBuf { + struct XV { + XBuffer x; + VBuffer v; +#if !SWAP_AB + VTBuffer vt; +#endif + // @fixme: also put xColMax and xColSum here + } xv; + + OutSwizzleBuf outSwizzle; + } reusedXVOutSwizzleBuf[nbXBuf]; + + static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV), + "need to use split output to avoid excessive shared memory usage"); + + __device__ inline XBuffer& xBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.x; } + + __device__ inline VBuffer& vBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.v; } +#if !SWAP_AB + __device__ inline VTBuffer& vtBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.vt; } +#endif + __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i) { + return reusedXVOutSwizzleBuf[i].outSwizzle; + } + + using QBuffer = Vec, nbQParts>; + QBuffer q; // For gmma math. Conversion done if needed. + + // @fixme: move these into reusedXVOutSwizzleBuf +#if SWAP_AB + ShmQWiseVec xColMax[nbXBuf]; + ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps]; +#else + ShmQWiseVec xRowMax[nbXBuf]; + ShmQWiseVec xRowSum[nbXBuf]; +#endif + + ShmQWiseVec gemm0CurrentSeqMax; + // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register + // storage will be 8x duplicate for swapAB and 4x duplicate for non-swapAB. + ShmQWiseVec gemm1AccColMax; + ShmQWiseVec gemm1AccColSum; + +#if USE_PAGED_KV_CACHE + static constexpr uint32_t nbPagesPerTile = + gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; + Vec pages[2]; // one for K and one for V +#endif + + // mem barriers + + CtaBarrierPair qBar; + CtaBarrierPair kBar[nbKBuf]; + CtaBarrierPair vBar[nbVBuf]; +#if !SWAP_AB + CtaBarrierPair vtBar[nbVBuf]; +#endif + CtaBarrierPair xBar[nbXBuf]; + + // used internally in the gemm0 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm0WarpGrpBar; + + // used internally in the gemm1 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm1WarpGrpBar; + + bool isLastCta; +}; + +CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem); +#ifdef __CUDA_ARCH__ +static_assert(smemSize < kMAX_SMEM_SIZE); +#endif + +constexpr uint32_t nbQLdWarps = needInputCvt ? 
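SharedMem packs the K, V, X and output-swizzle buffers (partly through the ReusedXVOutSwizzleBuf union) into one aligned block, and smemSize = sizeof(SharedMem) is checked against kMAX_SMEM_SIZE above; since this exceeds the default static shared-memory limit, the host side has to opt the kernel into a larger dynamic allocation before launching. A minimal host-side sketch of that opt-in, assuming the usual runtime-API flow; the function and variable names are placeholders, not the launcher in this diff.

#include <cuda_runtime.h>
// Host-side sketch (assumed flow): request enough dynamic shared memory for SharedMem.
inline cudaError_t optInToLargeSmemSketch(const void* kernelFunc, size_t smemBytes) {
  // The launch would then pass smemBytes as the dynamic shared-memory size, e.g. through
  // the cudaLaunchConfig_t handed to cudaLaunchKernelEx.
  return cudaFuncSetAttribute(kernelFunc, cudaFuncAttributeMaxDynamicSharedMemorySize,
                              int(smemBytes));
}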
nbIOWarps - 2 : 1; +constexpr uint32_t nbQLdThrds = warp_size * nbQLdWarps; + +#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 +template +struct F16QToF8Converter { + static_assert(inputElemSize == 2); + using F16Vec = Vec; +#if CACHE_ELEM_ENUM == 0 + using ShmVec = F16Vec; +#elif CACHE_ELEM_ENUM == 2 + using F8Vec = Vec; + using ShmVec = F8Vec; +#endif + + static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); + static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize; +#if !(SPEC_DEC) + static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * beamWidth; +#else + static_assert(beamWidth == 1); + static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta; +#endif + static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds); + + using RegData = Vec; + + static __device__ RegData load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search and spec dec*/, + uint32_t nbTokens); + static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data); +}; +#endif // CACHE_ELEM_ENUM + +struct KVTilePartLoader { + static constexpr uint32_t nbParts = cacheHeadNbParts; + static constexpr uint32_t partElems = exactDiv(headElems, nbParts); + +#if USE_PAGED_KV_CACHE + static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || + tokensPerPage % gemm0CtaTileNbTokens == 0); + static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; +#endif + + uint32_t const nbKHeads; + KVCacheList const& cacheList; + uint32_t const idxReq; + uint32_t const idxHeadGrp; + + CUtensorMap const& tensorMap; +#if USE_PAGED_KV_CACHE + uint32_t const nbPages; // for bound check + Vec& pages; + uint32_t idxTileRef; // idxTile used to load the pages +#endif + uint32_t const baseOffset; + + __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, uint32_t idxReq, + uint32_t idxHeadGrp, CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages, Vec& pageBuf +#endif + ); + // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache + template + __device__ void loadData( + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar); + + __device__ void loadPages(uint32_t idxTile); + __device__ GMemKVCacheHead& getHead(uint32_t pos); +}; + +using GmmaAccCoreMat = Array2D; +template +using GmmaAcc = + Array2D; + +inline constexpr uint32_t gemm0M = (swapAB ? gemm0CtaTileNbTokens : ctaNbQHeads); +inline constexpr uint32_t gemm0N = (swapAB ? ctaNbQHeads : gemm0CtaTileNbTokens); + +using Gemm0Acc = GmmaAcc; + +#if SWAP_AB +using RegColWiseVec = Vec, Gemm0Acc::cols>; +using UniformNeedRescaleMask = Vec; +using RegSeqWiseVec = RegColWiseVec; +#else +using RegRowWiseVec = Vec, Gemm0Acc::rows>; +using UniformNeedRescaleMask = + Vec; +using RegSeqWiseVec = RegRowWiseVec; +#endif + +#if SPEC_DEC + +__device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? params.qSeqLen + : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq]; +} + +__device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? 
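getInputSeqLen and getInputTokOffset (the latter continues just below) treat qCuSeqLens as a cumulative prefix-sum array, so request i occupies query tokens [cu[i], cu[i+1]), and fall back to a uniform qSeqLen per request when the pointer is null. A small host-side sketch of how such an array is typically built; the helper is illustrative and not part of the diff.

#include <cstdint>
#include <vector>
// Sketch: build the prefix-sum form expected by qCuSeqLens from per-request query lengths.
inline std::vector<uint32_t> buildCuSeqLensSketch(std::vector<uint32_t> const& qSeqLens) {
  std::vector<uint32_t> cu(qSeqLens.size() + 1, 0);
  for (size_t i = 0; i < qSeqLens.size(); ++i) cu[i + 1] = cu[i] + qSeqLens[i];
  return cu;  // then getInputSeqLen(i) == cu[i+1] - cu[i] and getInputTokOffset(i) == cu[i]
}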
params.qSeqLen * idxReq : params.qCuSeqLens[idxReq]; +} + +struct SpecDec { + static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static inline constexpr uint32_t ctaMaxQSeqLen = (ctaNbQHeads / headGrpSize); + using TileMaskRow = Vec; + + __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t idxInputSubSeq, + uint32_t seqLen) + : params(params), idxInputSubSeq(idxInputSubSeq), seqLen(seqLen) { + inputSeqLen = getInputSeqLen(params, idxReq); + baseOffset = divUp(params.qSeqLen, 32U) * + (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * idxInputSubSeq); + } + + __device__ inline uint32_t unmaskedSeqLen() const { return seqLen - inputSeqLen; } + + __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const { + return tileSize * (idxTile + 1) > unmaskedSeqLen() && + ctaMaxQSeqLen * idxInputSubSeq + idxQTokInCta < inputSeqLen && params.mask != nullptr; + } + + __device__ inline int32_t maskColBeg(uint32_t idxTile) const { + int32_t const convergedSeqLen = int32_t(unmaskedSeqLen()); + return static_cast(exactDiv(tileSize, 32) * idxTile) - + static_cast(divUp(convergedSeqLen, 32)); + } + + __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const { + assert(needMask(idxTile, idxQTokInCta)); + constexpr uint32_t nbOrigElems = TileMaskRow::size + 1; + Vec orig; + + int32_t const cols = divUp(params.qSeqLen, 32); + uint32_t const rowOffset = baseOffset + idxQTokInCta * cols; + int32_t const colBeg = maskColBeg(idxTile); +#pragma unroll + for (int32_t i = 0; i < int32_t(nbOrigElems); i++) { + int32_t const idx = colBeg + i; + orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? ~0U : 0U); + } + TileMaskRow mask; + uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32; +#pragma unroll + for (uint32_t i = 0; i < TileMaskRow::size; i++) { + asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" + : "=r"(mask[i]) + : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift)); + } + return mask; + } + + SpecDecParams const& params; + uint32_t const idxInputSubSeq; + uint32_t const seqLen; + uint32_t inputSeqLen; + uint32_t baseOffset; +}; + +__device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank); +#endif + +#if SWAP_AB +__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); +__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, + uint32_t validRowEnd); +__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax); +__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); +__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); +#else +__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); +__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); +__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& colMax); +__device__ RegRowWiseVec computeWarpRowSum(Gemm0Acc& src); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + 
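loadTileMaskRow above realigns the packed spec-dec mask whenever unmaskedSeqLen() is not a multiple of 32: each output word is the right funnel shift of two adjacent 32-bit words, done with shf.r.clamp.b32. The same operation exists as the __funnelshift_rc device intrinsic; the wrapper below is only an equivalence sketch, not a proposed change.

// Equivalence sketch for the shf.r.clamp.b32 in loadTileMaskRow: take the low 32 bits of
// the 64-bit value {hi:lo} shifted right by min(shift, 32).
__device__ inline uint32_t funnelShiftRightClampSketch(uint32_t lo, uint32_t hi, uint32_t shift) {
  return __funnelshift_rc(lo, hi, shift);
}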
CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec); +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec); +#endif + +using RegMatAFrag = Array2D, 1, 2>; +constexpr uint32_t gemm1NbGmmaInstK = exactDiv(gemm1CtaTileNbTokens, gmma::instK); + +#if SWAP_AB +constexpr uint32_t gemm1NbGmmaInstM = exactDiv(headElems, gmma::instM); +__device__ Vec loadVTileTransposed(uint32_t warpRank, uint32_t lane, + SharedMem::VBuffer const& smemV, + uint32_t idxGmmaInstK); +using Gemm1Acc = GmmaAcc; +__device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, + ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], + ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar); +template +__device__ void finalizeAndWriteOut_sync( + uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum, + ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, + uint32_t nbKHeads = 0 /* only for final result in spec dec. */); +#else +__device__ void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src); +using Gemm1Acc = GmmaAcc; +__device__ void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const(&shmXRowSum), + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum); +template +__device__ void finalizeAndWriteOut_sync( + uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, + float xvoScale, ShmQWiseVec const& accColSum, + uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, + uint32_t ctaNbValidTokens); +#endif + +inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds) { + auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds); + assert(val <= 32); + return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 
16 : 32))); +} + +template +inline constexpr uint32_t ropeNbPairsPerThrd = ropeNbPairsPerThrdImpl(nbThrds); + +template +__device__ Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid); +template +__device__ mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin); +template +__device__ void storeRotatedPairsForKV( + GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid); +template +__device__ void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid); + +class ScratchMem { + public: + struct alignas(8) SumMax { + float sum; + float max; + }; + + using ColWiseVec = Vec; + + HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) + : mScratch{static_cast(scratch)} { + uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; + Segmenter segmenter; + constexpr uint32_t alignment = sizeof(Vec); + mRowSumMax = segmenter.template newSeg(nbChunks, alignment); + mTokens = segmenter.template newSeg>(nbChunks, alignment); + } + + HOST_DEVICE_FUNC TinyPtr rowSumMax() const { return makePtr(mRowSumMax); } + + HOST_DEVICE_FUNC TinyPtr> tokens() const { + return makePtr>(mTokens); + } + + private: + template + HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const { + return TinyPtr{mScratch, offset}.template cast(); + } + + private: + mha::byte* mScratch; + // offsets + uint32_t mRowSumMax; + uint32_t mTokens; +}; + +struct MultiBlockSMem { + using ColWiseVec = ScratchMem::ColWiseVec; + static constexpr uint32_t nbBuf = useSpecDec ? 2 : 4; + static constexpr uint32_t nbIOWarps = nbBuf; + using Elem = InputElem; + using Head = Vec; + Vec, nbBuf> tokens; + Vec rowSumMax; + Vec barriers; +}; + +#ifndef NDEBUG +namespace dbg { +template +__device__ void printAcc(CtaBarrier& warpGrpBar, uint32_t warpRank, + Array2D const& acc) { + for (int m = 0; m < nbGmmaInstM; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int n = 0; n < nbGmmaInstNBase; n++) { + for (uint32_t i = 0; i < 4; i++) { + if (laneId() == b * 4 + i) { + printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); + } + __syncwarp(); + } + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + } + warpGrpBar.arrive_and_wait(); + } + } +} + +__device__ void printShmColWiseVec(ShmQWiseVec const& vec) { + for (uint32_t i = 0; i < vec.size; i++) { + printf("%f, ", vec[i]); + } + printf("\n"); +} + +template +__device__ void printArray2D(Array2D const& src) { + for (uint32_t i = 0; i < rows; i++) { + for (uint32_t j = 0; j < cols; j++) { + T const val = src.template at(i, j); + for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) { + printf("%f, ", float(reinterpret_cast(&val)[k])); + } + } + printf("\n"); + } +} +} // namespace dbg +#endif + +CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = + XQAKernelType::kHOPPER_WARP_SPECIALIZED; + +CUBIN_EXPORT __global__ +#ifdef NDEBUG +#if !OPTIMIZE_FOR_LATENCY +__launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 
3 : 2) +#else +__launch_bounds__(128 * 3) +#endif +#else + __launch_bounds__(128 * 3, 1) +#endif + void kernel_mha( + uint32_t const nbKHeads, +#if SLIDING_WINDOW + uint32_t const slidingWinSize, +#endif + float const qScale, + OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] +#if LOW_PREC_OUTPUT + float const* const rcpOutScale, +#endif +#if USE_INPUT_KV + IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], +#if ROPE_STYLE != 0 + Vec const* __restrict__ const ropeCosSin, // [maxNbPosEmb] +#endif +#else + IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], +#endif + float const* attentionSinks, // [headGrpSize] + KVCacheList const cacheList, +#if USE_BEAM_SEARCH + BeamSearchParams const beamSearchParams, +#endif + uint32_t const batchSize, + float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and + // V cache. Used only for int8/fp8 KV cache. +#if PAGED_KV_CACHE_LAYOUT == 1 + __grid_constant__ CUtensorMap const tensorMapVLLMK, + __grid_constant__ CUtensorMap const tensorMapVLLMV, +#else + __grid_constant__ CUtensorMap const tensorMap, +#endif +#if SPEC_DEC + SpecDecParams const specDecParams, +#endif + uint32_t* __restrict__ const semaphores = + nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] + void* __restrict__ const scratch = nullptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \ + (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 + uint32_t const idxReq = blockIdx.z / nbKHeads; +#if SPEC_DEC + uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq); + uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1); + uint32_t const nbInputSeqSplit = gridDim.x; + assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta)); +#else + uint32_t const reqInputTokBeg = idxReq; + uint32_t const reqInputTokEnd = idxReq + 1; + constexpr uint32_t nbInputSeqSplit = 1; + assert(gridDim.x == nbInputSeqSplit); +#endif + uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request + assert(gridDim.z == nbKHeads * batchSize); + uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; +#if SPEC_DEC + uint32_t const idxInputSubSeq = blockIdx.x; + uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg; + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + uint32_t const ctaNbValidTokens = + mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset); + + if (ctaTokOffset >= inputSeqLen) { + return; + } +#else + uint32_t const idxInputSubSeq = 0; + uint32_t const inputSeqLen = 1; + uint32_t const ctaTokOffset = 0; + uint32_t const ctaNbValidTokens = 1; +#endif +#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE + // get the actual start position depending on ctaTokOffset, which is the draft token position per + // CTA + uint32_t const tok0SeqLen = cacheSeqLen - inputSeqLen + 1 + ctaTokOffset; + int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize); + uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg); +#elif SLIDING_WINDOW + bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); + // if SPEC_DEC && SLIDING_WINDOW && IS_SPEC_DEC_TREE, it should not do sliding + assert(!SPEC_DEC || !rtIsReallySliding); + uint32_t const nbTotalSkipTokens = rtIsReallySliding ? 
cacheSeqLen - slidingWinSize : 0; +#else + constexpr bool rtIsReallySliding = false; + constexpr uint32_t nbTotalSkipTokens = 0; +#endif + uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize; + uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize; + +#if USE_BEAM_SEARCH + uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq); + uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0; + uint32_t const nbDivergentKTiles = + useKVCache + ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) + : 0; + uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles; + uint32_t const nbVTiles = nbKTiles; +#else + uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0; + // uint32_t const nbKTiles = nbTiles; + // uint32_t const nbVTiles = nbTiles; + uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles; +#endif + uint32_t const maxNbSubSeq = gridDim.y; + uint32_t const idxSubSeq = blockIdx.y; + bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles); + uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq; + uint32_t const idxVTileInit = idxKTileInit; + uint32_t const nbSubSeq = + isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; + static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); + assert(isMultiBlockMode == (nbSubSeq > 1)); + if (idxSubSeq >= nbSubSeq) { + return; + } + uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; + auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); + auto const wid = warpIdx.z * 4 + warpIdx.x; +#if PAGED_KV_CACHE_LAYOUT == 1 + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMapVLLMK); + tma::prefetchTensorMap(tensorMapVLLMV); + } +#else + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMap); + } +#endif + extern __shared__ char smemByteBuf[]; + assert(dynamicSmemSize() >= sizeof(SharedMem)); + SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); + + constexpr uint32_t nbBuffers = 2; + static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && + nbBuffers == SharedMem::nbXBuf); + if (wid < nbBuffers) { + if (warpElectSync()) { + smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); + smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); +#if !SWAP_AB + smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); +#endif + smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); + } + } else if (wid == nbBuffers) { + if (warpElectSync()) { + smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds); + init(&smem.gemm0WarpGrpBar, gemm0NbThrds); + init(&smem.gemm1WarpGrpBar, gemm1NbThrds); + } + } + __syncthreads(); + +#if USE_PAGED_KV_CACHE + uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); +#endif + + constexpr bool isKVCacheQuantized = (cacheElemSize < 2); + assert(idxKTileInit < nbTiles); + uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq); + assert(nbIters >= 1); + + constexpr uint32_t gmmaInstK = gmma::instK; + constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes); + + if (warpIdx.z == 0) { +#if SPEC_DEC + SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; +#endif + + // QK gemm + constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); + using Acc = GmmaAcc; + + unused(smem.qBar.consumed.arrive()); + for 
(auto& b : smem.kBar) { + unused(b.consumed.arrive()); + } + + float const qkScale = + qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * + rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. + uint32_t const warpRank = warpIdx.x; + + // init once per sequence. It also works as global colMax across iterations. + if (threadIdx.x < ctaNbQHeads) { + smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; + } + smem.gemm0WarpGrpBar.arrive_and_wait(); + + smem.qBar.produced.arrive_and_wait(); +#if DBG_PRINT + if (threadIdx.x == 0) { + printf("q:\n"); + dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); + } +#endif + + auto const matDescQBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) + .raw(); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + assert(idxKTile < nbTiles); + Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. + gmma::fence(); + static_assert(cacheHeadNbParts == nbQParts); +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBuf = smem.k[idxKBuf]; + auto& kBar = smem.kBar[idxKBuf]; + static_assert(SharedMem::KBuffer::rows % 8 == 0); + auto const matDescKBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw(); + assert(matDescKBase == gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw()); + arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); + // if (threadIdx.x == 0) { + // printf("************* part %u *******\n", idxPart); + // printf("q:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); + // printf("k:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); + // } + constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); +#pragma unroll + for (uint32_t k = 0; k < nbGmmaInstK; k++) { + bool const accHasVal = (idxPart != 0 || k != 0); + auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); +#pragma unroll + for (uint32_t m = 0; m < nbGmmaInstM; m++) { + auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); +#if SWAP_AB + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescK, matDescQ, accHasVal); +#else + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescQ, matDescK, accHasVal); +#endif + } + } + gmma::commit_group(); + //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let + // tile0_softmax + // wait for + // k loading of tile1 and may harm perf for short-seq cases. 
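+ // One possible shape of that overlap (hypothetical, not part of this change):
+ //   Acc acc2[2];                                  // double-buffered accumulators
+ //   issue the QK GMMA for tile i+1 into acc2[(i+1) % 2] without waiting;
+ //   run softmax on acc2[i % 2] while that GMMA proceeds;
+ //   gmma::wait_group<0>() only once X of tile i+1 is actually needed.
+ // As noted above, this makes tile i's softmax depend on tile i+1's K being resident, which
+ // can hurt short-sequence cases.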
+ gmma::wait_group<0>(); + unused(kBar.consumed.arrive()); + } +#if !defined(NDEBUG) && DBG_PRINT + dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); +#endif + // apply qkScale + acc = acc * qkScale; + // apply mask +#if SPEC_DEC + warpGrpApplyMask(acc, specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + tok0WinBeg, +#endif + cacheSeqLen, idxKTile, warpRank); +#else + bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); + bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); + bool const isLastTile = (idxKTile + 1 == nbTiles); + bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; + if (needMaskLeading || needMaskTrailing) { + uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; + uint32_t const validTokenEnd = (needMaskTrailing ? cacheSeqLen % tileSize : tileSize); + if (validTokenBeg > 0 || validTokenEnd < tileSize) { +#if SWAP_AB + warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); +#else + warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); +#endif + } + } +#endif + // update colMax in shared mem and get a register copy +#if SWAP_AB + RegColWiseVec const colMax = + computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, colMax); +#else + RegRowWiseVec const rowMax = + computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, rowMax); +#endif + + // @fixme: may need fp32->fp8->fp32 before doing sum. +#if SWAP_AB + RegColWiseVec const warpColSum = computeWarpColSum(acc); +#else + RegRowWiseVec const rowSum = computeWarpRowSum(acc); +#endif + + // map 1 to fp8_max before conversion to fp8 + acc = acc * kE4M3_MAX; + + uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; + auto& xBar = smem.xBar[idxXBuf]; + // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. +#if SWAP_AB + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + // store colMax and warpColSum + auto const lane = laneId(); + if (lane < 4) { + auto& xColMax = smem.xColMax[idxXBuf]; + auto& xColSum = smem.xColSum[idxXBuf][warpRank]; +#pragma unroll + for (uint32_t n = 0; n < colMax.size; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + if (warpRank == 0) { + xColMax[8 * n + 2 * lane + j] = colMax[n][j]; + } + xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; + } + } + } +#else + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); + storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); +#endif + + __syncwarp(); + // the release semantics of arrive does not work for async consumers like gmma. additional + // fence is needed. 
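+ // The producer-side pattern here is roughly:
+ //   store the X tile to shared memory            (generic proxy)
+ //   fence.proxy.async.shared::cta                 (make those stores visible to the async proxy)
+ //   xBar.produced.arrive()                        (signal the gemm1 warp group)
+ // The gemm1 warps read X through GMMA matrix descriptors, i.e. via the async proxy, so the
+ // barrier arrive alone would not order the shared-memory stores for them.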
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(xBar.produced.arrive()); + } + unused(smem.qBar.consumed.arrive()); + } else if (warpIdx.z == 1) { + // XV GEMM + for (auto& b : smem.vBar) { + unused(b.consumed.arrive()); + } +#if !SWAP_AB + for (auto& b : smem.vtBar) { + unused(b.consumed.arrive()); + } +#endif + for (auto& b : smem.xBar) { + unused(b.consumed.arrive()); + } + + if (threadIdx.x < smem.gemm1AccColMax.size) { + auto const idx = threadIdx.x; + smem.gemm1AccColMax[idx] = safeInitRowMax; + smem.gemm1AccColSum[idx] = 0; + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + + uint32_t const warpRank = warpIdx.x; + + constexpr float xScale = 1.f / kE4M3_MAX; +#if LOW_PREC_OUTPUT + float const oScale = rcpOutScale[0]; +#else + constexpr float oScale = 1.F; +#endif + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; + + Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. + gmma::fence(); + + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; + auto const idxVBuf = idxIter % SharedMem::nbVBuf; + auto const idxXBuf = idxVBuf; + auto& vBar = smem.vBar[idxVBuf]; + arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); + auto const& vBuf = smem.vBuf(idxVBuf); +#if !SWAP_AB + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); + vtBar.consumed.arrive_and_wait(); + transposeVTile(warpRank, laneId(), vtBuf, vBuf); + vBar.consumed.arrive(); + vtBar.produced.arrive(); +#endif + auto& xBar = smem.xBar[idxXBuf]; + xBar.produced.arrive_and_wait(); +#if !defined(NDEBUG) && DBG_PRINT +#if SWAP_AB + if (threadIdx.x == 0) { + printf("colMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColMax[idxXBuf][i]); + } + printf("\n"); + printf("colSum:\n"); + for (int n = 0; n < 4; n++) { + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColSum[idxXBuf][n][i]); + } + printf("\n"); + } + printf("\n"); + printf("X:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + for (int j = 0; j < gemm0CtaTileNbTokens; j++) { + auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); + auto const e = reinterpret_cast&>( + smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( + i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; + printf("%.2f, ", float(e)); + if (j % 16 == 15) { + printf("| "); + } + } + printf("\n\n"); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#else + if (blockIdx.y == 1 && threadIdx.x == 0) { + printf("rowMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowMax[idxXBuf][i]); + } + printf("\n"); + printf("rowSum:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowSum[idxXBuf][i]); + } + printf("\n"); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#endif +#endif + +#if SWAP_AB + // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc + // instead. 
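+ // Rescale step of the online softmax, conceptually: when the running column max grows from
+ // mOld to mNew,
+ //   scale     = exp(mOld - mNew)
+ //   accColSum = accColSum * scale + tileColSum
+ //   acc       = acc * scale        (the new tile's X*V is then accumulated on top)
+ // The helper below updates the shared gemm1AccColMax/gemm1AccColSum and the register
+ // accumulator in one synchronized pass; the exact scaling details live in its definition.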
+ rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum, + smem.gemm1WarpGrpBar); +#else + rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum); +#endif + auto& xBuf = smem.xBuf(idxXBuf); + + auto const descXBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) + .raw(); +#if CACHE_ELEM_ENUM == 0 + auto const descVBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) + .raw(); +#endif +#if SWAP_AB +//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in +// loadVTileTransposed. +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) { +#if CACHE_ELEM_ENUM == 2 + Vec const fragA = + loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); +#if !defined(NDEBUG) && DBG_PRINT + if (threadIdx.x == 0) { + printf("fragA:\nidxInstK == %u\n", idxInstK); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + for (int m = 0; m < 2; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + if (laneId() == 0) { + printf(" warpRank = %u\n", warpRank); + } + __syncwarp(); + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int c = 0; c < 2; c++) { + for (int d = 0; d < 4; d++) { + if (laneId() == b * 4 + d) { + for (int e = 0; e < 4; e++) { + auto const& elem4 = + reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(fragA[m](0, c)(a, 0)); + printf("%.2f, ", float(elem4[e])); + } + } + __syncwarp(); + } + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0 && a == 0) { + printf("----------------------\n"); + } + __syncwarp(); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + } + } +#endif +#endif + BoundedVal const kOffsetInGrains{grainsPerInstK * + idxInstK}; + auto const descX = + addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + 0, kOffsetInGrains.template mod().get())); +#if CACHE_ELEM_ENUM == 2 + gmma::fence(); +#endif +#pragma unroll + for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) { +#if CACHE_ELEM_ENUM == 0 + auto const descV = + addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); + gmma::mma_async_shmA( + reinterpret_cast( + acc(idxInstM, 0)), + descV, descX, true); +#elif CACHE_ELEM_ENUM == 2 + gmma::mma_async_regA( + reinterpret_cast( + acc(idxInstM, 0)), + reinterpret_cast(fragA[idxInstM]), descX, true); +#endif + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of + // gmma. 
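+ // Hypothetical form of the deferral mentioned above (not implemented): commit each group and
+ // wait only for the previous one, e.g. gmma::wait_group<1>() here, so one group stays in
+ // flight; that requires fragA and the smem buffers it references to remain valid until the
+ // deferred wait completes.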
+ gmma::wait_group<0>(); + } +#else + auto const descVTBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VTBuffer{})) + .raw(); + vtBar.produced.arrive_and_wait(); +// if (idxIter == 1 && threadIdx.x == 0) { +// printf("vtBuf:\n"); +// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); +// } +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) { + BoundedVal const kOffsetInGrains{grainsPerInstK * k}; + auto const descX = + addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + gmma::instM * m, + kOffsetInGrains.template mod().get())); + auto const descVT = + addAddr(descVTBase, + &vtBuf(0, kOffsetInGrains.template mod().get())); + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + descX, descVT, true); + } + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of gmma. + gmma::wait_group<0>(); +#endif + if (idxIter == nbIters - 1) { + // gmma::wait_group should have already synchronized threads, so this may be unnecessary. + smem.gemm1WarpGrpBar.arrive_and_wait(); + assert(idxXBuf == idxVBuf); + if (isMultiBlockMode) { + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + // save row max/sum + static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size); + if (threadIdx.x < ctaNbValidQHeads) { + float const colMax = smem.gemm1AccColMax[threadIdx.x]; + float const colSum = smem.gemm1AccColSum[threadIdx.x]; + ScratchMem::SumMax sumMax; + sumMax.sum = colSum; + sumMax.max = colMax; + (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = + sumMax; + } + // compute scratch ptr for output writing + IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, nullptr); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, 1, ctaNbValidTokens); +#endif + } else { + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) { + attentionSinksVec = + reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, + smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, attentionSinksVec, nbKHeads); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); +#endif + } + } + unused(xBar.consumed.arrive()); +#if SWAP_AB + unused(vBar.consumed.arrive()); +#else + unused(vtBar.consumed.arrive()); +#endif + } + } else { + // IO warps + static_assert(beamWidth == 1); +#if ENABLE_PDL + preExit(); +#endif +#if ENABLE_PDL == 1 + acqBulk(); +#endif + assert(warpIdx.z == 2); + uint32_t const newTokenPos = cacheSeqLen - 1; + if (warpIdx.x < nbQLdWarps) { + // load Q. 
Use register to load fp16 data and store fp8 to shared mem. + // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead. + using QCvt = F16QToF8Converter; + static_assert(beamWidth == 1); +#if USE_INPUT_KV + TinyPtr const qData{ + qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq}; + constexpr bool isNeox = (ROPE_STYLE == 1); + constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U)); + uint32_t const lane = laneId(); + uint32_t const idxThrd = warpIdx.x * warp_size + lane; + uint32_t const idxThrdGrp = + (thrdsPerHead % 32 == 0 ? makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) + : idxThrd / thrdsPerHead); + constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); + uint32_t const tid = idxThrd % thrdsPerHead; + smem.qBar.consumed.arrive_and_wait(); +#if ROPE_STYLE != 0 + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, tid); +#endif +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#pragma unroll + for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) { + uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; + if (idxHead >= headGrpSize) { + break; + } +#if ROPE_STYLE == 0 + auto const rotatedPairs = + loadHead(qData[idxHead], tid); +#else + auto const pairs = loadHead(qData[idxHead], tid); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); +#endif + storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); + } +#else + TinyPtr const qData{ + q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; +#if ENABLE_PDL == 2 + acqBulk(); +#endif + auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); + + smem.qBar.consumed.arrive_and_wait(); + QCvt::store(threadIdx.x, smem.q, f16QData); +#endif + // the release semantics of arrive does not work for async consumers like gmma. additional + // fence is needed. 
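+ // Same producer-side rule as for the X buffer above: the fp8 Q stores (generic proxy) must be
+ // fenced against the async proxy before qBar.produced is signalled, because the GMMA warps
+ // read smem.q through matrix descriptors.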
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(smem.qBar.produced.arrive()); + } else if (warpIdx.x == nbQLdWarps) { // load k + KVTilePartLoader kTilePartLoader{true, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, +#else + tensorMap, +#endif + nbPages, smem.pages[0] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + kTilePartLoader.loadPages(idxKTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = + (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputKHeadOffset = + headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inKHead = qkv[inputKHeadOffset]; + uint32_t const lane = laneId(); + float const rcpKScale = 1.F / kvCacheScale[0]; +#if ROPE_STYLE == 0 + constexpr bool isNeox = false; + auto const pairs = + loadHead(inKHead, lane) * rcpKScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + convertedPairs, lane); +#else + constexpr bool isNeox = (ROPE_STYLE == 1); + auto const pairs = loadHead(inKHead, lane) * rcpKScale; + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, lane); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + rotatedPairs, lane); +#endif + static_assert(inputSeqLen == 1); + __syncwarp(); +#endif + } +#endif + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBar = smem.kBar[idxKBuf]; + kBar.consumed.arrive_and_wait(); + if (warpElectSync()) { + kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); + } + __syncwarp(); + } + } + } else if (warpIdx.x == nbQLdWarps + 1) { // load v + KVTilePartLoader vTileLoader{false, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMV, +#else + tensorMap, +#endif + nbPages, smem.pages[1] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; + vTileLoader.loadPages(idxVTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = + (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputVHeadOffset = + (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inVHead = qkv[inputVHeadOffset]; + uint32_t const lane = laneId(); + float const rcpVScale = 1.F / kvCacheScale[0]; + constexpr bool isNeox = false; + auto const pairs = + loadHead(inVHead, lane) 
* rcpVScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + static_assert(SPEC_DEC == 0); + storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), + convertedPairs, lane); + __syncwarp(); +#endif + } +#endif + + uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; + auto& vBar = smem.vBar[idxVBuf]; + vBar.consumed.arrive_and_wait(); + if (warpElectSync()) { +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); + } + } + __syncwarp(); + } + } + } + __syncthreads(); + uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; + uint32_t const tid = + threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; + assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); + if (tid < nbBarriers) { + (&smem.qBar.produced)[tid].~CtaBarrier(); + } + if (!isMultiBlockMode) { + return; + } + bool& smemIsLastCta = smem.isLastCta; + if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) { + uint32_t const lastOld = nbSubSeq - 1; + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t old; + uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; + auto const pSemaphore = &semaphores[idxSemaphore]; + asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" + : "=r"(old) + : "l"(pSemaphore), "r"(lastOld)); + smemIsLastCta = (old == lastOld); + } + { + assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); +#ifndef __CUDACC_RTC__ + assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); +#endif + auto& smem = *reinterpret_cast(&smemByteBuf[0]); + assert(blockDim.x >= MultiBlockSMem::nbBuf); + constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; + + static_assert(nbWarps >= MultiBlockSMem::nbBuf); + if (wid < MultiBlockSMem::nbBuf) { + if (warpElectSync()) { + smem.barriers[wid].initialize(isHeadPadded ? 
warp_size : 1U, nbMathWarps * warp_size); + smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); + } + } + __syncthreads(); + + if (!smemIsLastCta) { + return; + } + if (wid < nbMathWarps) { + constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); + using Acc = Vec; + + struct HeadState { + Acc acc; + float sum; + float max; + }; + + Vec states{}; + for (auto& s : states.data) { + s.max = safeInitRowMax; + } + uint32_t const lane = laneId(); + for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) { + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; + } + HeadState& state = states[i]; + auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; + auto const data = convert(reinterpret_cast&>( + smem.tokens[idxBuf][idxHead][Acc::size * lane])); + if (sumMax.max > state.max) { + float const scale = expf(state.max - sumMax.max); + state.max = sumMax.max; + state.sum = state.sum * scale + sumMax.sum; + state.acc = state.acc * scale + data * sumMax.sum; + } else { + float const scale = expf(sumMax.max - state.max); + state.sum = state.sum + sumMax.sum * scale; + state.acc = state.acc + data * (sumMax.sum * scale); + } + } + unused(bar.consumed.arrive()); + } + // Add the attention sinks. + if (attentionSinks != nullptr) { + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = + expf(attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - + states[i].max); + states[i].sum += sink; + } + } + __syncthreads(); + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + auto const dst = &output[outOffset]; + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken >= ctaNbValidTokens) { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const idxDstHead = idxHead + idxToken * tokenPad; +#else + uint32_t const idxDstHead = idxHead; +#endif + auto const& s = states[i]; + auto const outData = convert(s.acc * (1.f / s.sum)); + if (Acc::size * lane < validElemsPerHead) { + reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = + outData; + } + } + } else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) { + static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const initIdxBlock = wid - nbMathWarps; + // each warp loads data for a block + for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; + idxBlock += MultiBlockSMem::nbIOWarps) { + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + auto const lane = laneId(); +#pragma unroll + for (uint32_t iter = 0; iter < divUp(ctaNbValidQHeads, 
warp_size); iter++) { + uint32_t const i = iter * warp_size + lane; + if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) { + break; + } + ldgsts::copyAsync( + &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); + } + ldgsts::barArrive(bar.produced, false); + if constexpr (isHeadPadded) { + static_assert(grainsPerPaddedInputHead <= warp_size); + constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) { + uint32_t const idxHead = + headsPerIter * i + + BoundedVal{lane}.template divBy().get(); + uint32_t const idxGrain = + BoundedVal{lane}.template mod().get(); + if (i < nbWholeIters || idxHead < ctaNbValidQHeads) { + constexpr uint32_t nbElemsPerGrain = + exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); + auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; + auto const src = + idxGrain < grainsPerIOHead + ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] + : nullptr; + ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? grainBytes : 0U); + } + } + ldgsts::barArrive(bar.produced, true); + } else { + if (warpElectSync()) { + tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], + sizeof(smem.tokens[idxBuf]), bar.produced); + arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); + } + } + } + __syncthreads(); + uint32_t const idxBar = tid - warp_size * nbMathWarps; + if (idxBar < MultiBlockSMem::nbBuf * 2) { + reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); + } + } + } +#else +#if GENERATE_CUBIN + static_assert("This kernel is for Hopper only"); +#else + asm volatile("trap;\n"); +#endif +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 +} + +#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 +template +__device__ inline typename F16QToF8Converter::RegData +F16QToF8Converter::load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search only*/, + uint32_t nbTokens) { +#if !(SPEC_DEC) + assert(nbTokens == 1); + nbTokens = 1; +#endif + typename F16QToF8Converter::RegData dst; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; + static_assert(beamWidth == 1); +#else + uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; +#endif + bool isGrainInBound = true; + if constexpr (isHeadPadded) { + uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; + offsetInGrains = + offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; + isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); + } +#if SPEC_DEC + isGrainInBound = isGrainInBound && (idxToken < nbTokens); +#endif + LdGrain const srcGrain = + isGrainInBound ? 
src.template cast()[offsetInGrains] : LdGrain{}; + static_assert(inputElemSize == 2); + auto const& fp16Data = + reinterpret_cast const&>(srcGrain); + dst[iter] = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead + ? fp16Data + : mha::decay_t{}; + } + return dst; +} + +template +__device__ inline void F16QToF8Converter::store( + uint32_t tid, SharedMem::QBuffer& dst, + F16QToF8Converter::RegData const& data) { +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; + } +#if CACHE_ELEM_ENUM == 0 + static_assert(inputElemSize == cacheElemSize); + ShmVec const& shmData = data[iter]; + uint32_t const r = idxGrain / grainsPerPaddedInputHead; + BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; + + dst[c.template divBy().get()].template at( + r, c.template mod().get()) = reinterpret_cast(shmData); +#else + auto const& fp16Data = data[iter]; + ShmVec shmData; +#pragma unroll + for (uint32_t i = 0; i < fp16Data.size; i++) { + shmData[i] = CacheElem{fp16Data[i]}; + } + uint32_t const dstIdxGrain = idxGrain / 2; + uint32_t const dstIdxHalfGrain = idxGrain % 2; + constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); + uint32_t const r = dstIdxGrain / grainsPerCacheHead; + BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; + reinterpret_cast&>( + dst[c.template divBy().get()].template at( + r, c.template mod().get()))[dstIdxHalfGrain] = shmData; +#endif + } +} +#endif + +__device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, + uint32_t idxReq, uint32_t idxHeadGrp, + CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages, + Vec& pageBuf +#endif + ) + : nbKHeads{nbKHeads}, + cacheList{cacheList}, + idxReq{idxReq}, + idxHeadGrp{idxHeadGrp}, + tensorMap{tensorMap} +#if USE_PAGED_KV_CACHE + , + nbPages{nbPages}, + pages{pageBuf} +#if PAGED_KV_CACHE_LAYOUT == 1 + , + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} +#else + , + baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} +#endif +#else + , + baseOffset{(idxReq * beamWidth) * 2 + (isK ? 
0 : 1)} +#endif +{ +} + +// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache +template +__device__ inline void KVTilePartLoader::loadData( + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar) { + static_assert(nbTokens == gemm0CtaTileNbTokens); +#if USE_PAGED_KV_CACHE + assert(idxTile == idxTileRef); + if constexpr (nbTokens < tokensPerPage) { + assert(nbPagesPerTile == 1); + uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t)pages[0]}, bar); +#else + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t)pages[0]}, bar); +#endif + } else { +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); +#else + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); +#endif + } + } +#else + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); +#endif +} + +__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { +#if USE_PAGED_KV_CACHE + uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage + ? nbPagesPerTile * idxTile + : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { + uint32_t const idxPage = idxPageBeg + i; + auto const page = + idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; + if (warpElectSync()) { + pages[i] = page; + } + } + idxTileRef = idxTile; + __syncwarp(); +#endif +} + +__device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) { + constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + // Raise a runtime error indicating not implemented + assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); + __trap(); +#else + uint32_t const idxTile = pos / nbTokens; + assert(idxTile == idxTileRef); + uint32_t const offset = pos % tokensPerPage; + return cacheList + .pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + + offset]; +#endif +#else + // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] + return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; +#endif +} + +#if SWAP_AB +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static_assert(SPEC_Q_SEQ_LEN <= sizeof(MaskType) * 8, "not implemented"); + + assert(cacheSeqLen >= SPEC_Q_SEQ_LEN); + uint32_t const maskStartRow = cacheSeqLen - SPEC_Q_SEQ_LEN; + uint32_t const tileStartRow = tileSize * idxTile; + if (tileStartRow + tileSize < maskStartRow) { + return; + } + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; + +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + 
idxInQuad) + j; + uint32_t const maskCol = col / headGrpSize; + MaskType const bit_mask = (1ULL << (maskCol + 1)) - 1; + +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const globalRow = tileStartRow + row; + if (globalRow >= cacheSeqLen) { + acc(m, n)(i, j) = safeInitRowMax; + continue; + } + if (globalRow >= maskStartRow) { + uint32_t const maskRow = globalRow - maskStartRow; + if ((bit_mask >> maskRow) == 0) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } + } +} +#endif // SPEC_DEC + +// smemColMax is persistent across multiple iterations +__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, + ShmQWiseVec& smemColMax, + Gemm0Acc const& src) { + auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colMax[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + auto& x = colMax[n][j]; + x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + + uint32_t const lane = laneId(); + if (lane < 4) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); + } + } + } + warpGrpBar.arrive_and_wait(); + uint32_t const idxInQuad = lane % 4; + +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); + colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; + } + } + warpGrpBar.arrive_and_wait(); + return colMax; +} + +__device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + smemVec)[i * nbThrdsPerInstNBase + idx]; + } + return ret; +} + +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, + uint32_t bound) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + +__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& 
acc, uint32_t validRowBeg, + uint32_t validRowEnd) { + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; + if (row >= validRowBeg && row < validRowEnd) { + continue; + } +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) { +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float const maxVal = colMax[n][j]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) { + auto colSum = RegColWiseVec::filled(Vec::filled(0)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + auto& x = colSum[n][j]; + x += __shfl_xor_sync(~0U, x, xorMask); + } + } + } + return colSum; +} + +__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { +#if CACHE_ELEM_ENUM == 0 + using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; + F16Acc f16Acc; + reinterpret_cast&>(f16Acc) = + convert(reinterpret_cast const&>(acc)); + static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); + uint32_t const idxHalf = lane / 16; + uint32_t const idxInHalf = lane % 16; + uint32_t const idxOctInsideHalf = idxInHalf / 8; + uint32_t const idxRowInsideOct = lane % 8; + uint32_t const warpBaseC = 16 * warpRank; + auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair { + uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; + uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; + return {accR, accC}; + }; + auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); + uint32_t const idxPart = 0; + uint32_t const dstR = accC * 8 + idxRowInsideOct; + uint32_t const dstC = + exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); + assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); + return &smemX[idxPart].template at(dstR, dstC); + }; + auto const getAccData = [&](uint32_t idxAccCoreMat) { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + return f16Acc(accR, accC); + }; + + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) { + auto const dstAddr = 
getDstAddr(iter * 2 + idxHalf); + Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; + stmatrix(dstAddr, reinterpret_cast(data)); + } + if constexpr (Gemm0Acc::size % 2 != 0) { + auto const dstAddr = lane < 16 ? getDstAddr(Gemm0Acc::size - 1) : nullptr; + stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); + } +#elif CACHE_ELEM_ENUM == 2 + using F8Acc = Array2D; + F8Acc f8Acc; +#pragma unroll + for (uint32_t i = 0; i < acc.rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < acc.cols; j++) { + auto const& core = acc(i, j); + static_assert(mha::is_same_v); + Vec const f8Data = { + __nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), + __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; + f8Acc(i, j) = reinterpret_cast(f8Data); + } + } + + if constexpr (F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) { + LdGrain* dst = nullptr; + if (F8Acc::size == 4 || lane < 8 * F8Acc::size) { + uint32_t const idxCore = lane / 8; + uint32_t const srcRow = idxCore / F8Acc::cols; + uint32_t const srcCol = idxCore % F8Acc::cols; + uint32_t const dstCoreRow = lane % 8; + uint32_t const dstRow = srcCol * 8 + dstCoreRow; + BoundedVal const dstCol{ + srcRow * 4 + warpRank}; + dst = &smemX[dstCol.template divBy().get()].template at( + dstRow, dstCol.template mod().get()); + } + barConsumed.arrive_and_wait(); + stmatrix(dst, reinterpret_cast const&>(f8Acc)); + } else { + // we need to use loops + assert(false); + trap(); + } +#endif +} + +#else + +__device__ inline RegRowWiseVec warpRowWiseReduce(RegRowWiseVec const& init, Gemm0Acc const& src, + float (*op)(float, float)) { + RegRowWiseVec vec = init; +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + // @fixme: check if compiler is reordering these op to hide latency. 
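+ // Each accumulator row is spread across the 4 lanes of a quad (lane bits 0-1), so this
+ // per-thread pass only produces partials; the butterfly below (xorMask 2, then 1) combines
+ // them, leaving all four lanes of the quad with the same row-wise result.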
+ vec[m][i] = op(vec[m][i], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + auto& x = vec[m][i]; + x = op(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + return vec; +} + +__device__ inline RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, + ShmQWiseVec& smemRowMax, + Gemm0Acc const& src) { + assert(warpRank < 4); + RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); + RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); + + storeShmRowWiseVec(warpRank, smemRowMax, rowMax); + __syncwarp(); + return rowMax; +} + +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + auto const inputSeqLen = specDec.inputSeqLen; + auto const idxInputSubSeq = specDec.idxInputSubSeq; + constexpr uint64_t fullMask = ~uint64_t{0}; + static_assert(tileSize == sizeof(fullMask) * 8); +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + Range const tileRange = {tileSize * idxTile, tileSize * idxTile + tileSize}; + Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (inputTokensPerCta - 1)}; + bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; + assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); + int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(tileSize * idxTile); +#else + constexpr bool ctaNeedBegMask = false; + uint64_t const begMask = fullMask; + int32_t const tok0NbMaskOut = -2147483648; +#endif + uint32_t const offset = tileSize * idxTile; + uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? cacheSeqLen - offset : 0U, tileSize); + bool const ctaNeedEndMask = (nbValidCols < tileSize); + bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); + bool const needMask = ctaNeedBegMask || ctaNeedEndMask || ctaNeedSpecDecMask; + if (!needMask) { + return; + } + static_assert(tileSize == 64, "not implemented"); + auto const endMask = fullMask >> (tileSize - nbValidCols); + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const idxQTokInCta = row / headGrpSize; + bool const isQTokValid = + (headGrpSize * inputTokensPerCta == ctaNbQHeads) || (idxQTokInCta < inputTokensPerCta); + auto const specDecMask = (isQTokValid && specDec.needMask(idxTile, idxQTokInCta)) + ? specDec.loadTileMaskRow(idxTile, idxQTokInCta) + : SpecDec::TileMaskRow{~0U, ~0U}; +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); + uint64_t const begMask = (begNbMaskOut > 0 ? 
fullMask << begNbMaskOut : fullMask); +#else + uint64_t const begMask = fullMask; +#endif + auto const mask = begMask & endMask & reinterpret_cast(specDecMask); + if (mask == ~uint64_t{0}) { + continue; + } +#if DBG_PRINT + if (idxInQuad == 0) { + printf("mask at row %d: %lx\n", row, mask); + } +#endif +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + assert((col < nbValidCols) == bool(endMask & (1ULL << col))); + if ((mask & (1ULL << col)) == 0) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } +} +#else +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) { + uint32_t const idxInQuad = laneId() % 4; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + if (col >= validColBeg && col < validColEnd) { + continue; + } +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} +#endif + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) { +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float const maxVal = rowMax[m][i]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) { + return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); +} + +__device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, + ShmQWiseVec const& smemVec) { + RegRowWiseVec vec; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; + } + } + return vec; +} + +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec) { + uint32_t const lane = laneId(); + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; + bool const enable = (idxInQuad == 0); +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); + if (enable) { + smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; + } + } + } +} + +// for X +// order: 0,8,1,9, 2,10,3,11, 4,12,5,13, 6,14,7,15, ... 
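+// Worked example of that ordering: each packed 32-bit grain pairs column c of core matrix 2n
+// with column c of core matrix 2n+1, i.e. logical columns c and c+8. So for n == 0 the four
+// lanes of a quad hold {0,8,1,9}, {2,10,3,11}, {4,12,5,13}, {6,14,7,15}, and the pattern
+// repeats at an offset of 16 for n == 1.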
+__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t m = 0; m < Gemm0Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + Vec fp8Data; +#pragma unroll + for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) { + reinterpret_cast&>(fp8Data[n]) = { + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; + } + static_assert(decltype(fp8Data)::size == 4); + stmatrix_4x(this_warp(), + &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), + fp8Data); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadVTileTransposed( + uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) { + Vec fragA; + constexpr uint32_t instK = gmma::instK; +#pragma unroll + for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) { + static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); + constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); +#if CACHE_ELEM_ENUM == 0 + uint32_t idxRow = lane % 8; + uint32_t idxMat = lane / 8; + uint32_t c = idxMat % 2; + uint32_t r = idxMat / 2; + auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{ + 2 * (gmmaWarpsPerGrp * i + warpRank) + c}; + auto const src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i] = reinterpret_cast(data); +#elif CACHE_ELEM_ENUM == 2 + auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; + LdGrain const* src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + lane, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); + fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); + fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); + fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); +#endif + } + return fragA; +} +#else +__device__ inline void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) { + static_assert(cacheHeadPartElems >= gmma::instM); + uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; + constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) { + LdGrain const a = ldmatrix_4x( + this_warp(), &src[idxPart].template at( + 32 * n + lane, exactDiv(gmma::instM, cacheElemsPerGrain) * m - + grainsPerCacheHeadPart * idxPart + warpRank)); + LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), + prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; + uint32_t const i = idxMat % 2; + uint32_t const j = idxMat / 2; + stmatrix_4x( + this_warp(), + &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadShmColWiseVecNoDup( + ShmQWiseVec const& shmVec) { + Vec ret; +#pragma unroll 
+ for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? shmVec[idx] : 0); + } + return ret; +} + +__device__ inline void storeShmColWiseVecNoDup( + ShmQWiseVec& shmVec, Vec const& src) { +#pragma unroll + for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + if (inBound) { + shmVec[idx] = src[i]; + } + } +} +#else +__device__ inline Vec +loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) { + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { + uint32_t const idx = + gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || + (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? shmVec[idx] : 0); + } + return ret; +} + +__device__ inline void storeShmRowWiseVecNoDup( + uint32_t warpRank, ShmQWiseVec& shmVec, + Vec const& src) { + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { + uint32_t const idx = + gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || + (idx < ShmQWiseVec::size)); + if (inBound) { + shmVec[idx] = src[i]; + } + } +} +#endif + +#if SWAP_AB +__device__ inline void rescaleGemm1AccForNewColMax_sync( + uint32_t warpRank, ShmQWiseVec const& shmXColMax, ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], + ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar) { + auto accColSum = loadShmColWiseVecNoDup(shmAccColSum); + + auto const xColMax = loadShmColWiseVecNoDup(shmXColMax); + auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax); + auto token = gemm1WarpGrpBar.arrive(); + auto const needRescaleVec = (accColMax < xColMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + assert(accColMax[i] <= xColMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + if (anyNeedRescale) { + auto const scaleVec = expf(accColMax - xColMax); + auto const lane = laneId(); +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { + uint32_t const vecIdx = gmma::instNBase * n / warp_size; + uint32_t const offset = gmma::instNBase * n % warp_size; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U); + auto getScale = [&] { + return __shfl_sync(~0U, scaleVec[vecIdx], 
+ offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j); + }; + assert((getScale() != 1) == + ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U)); + bool const needRescale = (mask != 0); + if (!needRescale) { // this branch is warp-uniform + continue; + } + float const scale = getScale(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) *= scale; + } + } + } + } + accColSum = accColSum * scaleVec; + } + gemm1WarpGrpBar.wait(mha::move(token)); + + // @fixme: with atomic, we can let the first warp reaching here to do the update, instead of + // always warp 3. + uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1; + if (warpRank == warpRankForUpdate) { + if (anyNeedRescale) { + storeShmColWiseVecNoDup(shmAccColMax, xColMax); + } +#pragma unroll + for (uint32_t i = 0; i < gemm0NbWarps; i++) { + accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); + } + storeShmColWiseVecNoDup(shmAccColSum, accColSum); + } + gemm1WarpGrpBar.arrive_and_wait(); +} +#else +__device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, + ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const& shmXRowSum, + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum) { + auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); + auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); + auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); + assert(all(xRowMax >= accRowMax)); + auto const needRescaleVec = (accRowMax < xRowMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + assert(accRowMax[i] <= xRowMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + + if (anyNeedRescale) { + auto const scaleVec = expf(accRowMax - xRowMax); + auto const lane = laneId(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; + bool const needRescale = (mask != 0); + if (needRescale) { // this branch is warp-uniform + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale; + } + } + } + } + } + accRowSum = accRowSum * scaleVec; + } + __syncwarp(); + auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); + __syncwarp(); +} +#endif + +#if SWAP_AB +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) *= scale[n][j]; + } + } + } + } +} +#else +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) { +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { +#pragma 
unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale[m][i]; + } + } + } + } +} +#endif + +#if SWAP_AB +// @fixme: consider make this noinline +template +__device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRank, DstHead* dst, + SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc const& acc, CtaBarrier& warpGrpBar, + uint32_t nbKHeads) { + uint32_t const lane = laneId(); +#if CACHE_ELEM_ENUM == 0 + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#elif CACHE_ELEM_ENUM == 2 + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; +#endif +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { + auto const& core = acc(m, n); +#if CACHE_ELEM_ENUM == 0 + Vec f16Core; + reinterpret_cast&>(f16Core) = + convert(reinterpret_cast const&>(core)); + auto const dst = idxMat < 2 + ? &swizzleBuf.template at( + 8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) + : nullptr; + stmatrix(dst, f16Core); +#elif CACHE_ELEM_ENUM == 2 + // each row is part of a b16 8x8 matrix and is transposed + Array2D coreTrans; + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); + InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); + auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); + reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; + } + // expect compiler to generate two PRMT instructions + Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), + coreTrans(1, 1)}; + swizzleBuf.template at( + gmma::instNBase * n + idxQuad, + (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] = + data; +#endif + } + } + warpGrpBar.arrive_and_wait(); + + constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; + constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); + uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; + uint32_t const idxGrain = threadRank % nbGrainsPerHead; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxHead = idxHeadBase + iter * headsPerIter; + if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && + (!isHeadPadded || idxGrain < grainsPerIOHead)) { +#if CACHE_ELEM_ENUM == 0 + auto const data = swizzleBuf.template at(idxHead, idxGrain); +#elif CACHE_ELEM_ENUM == 2 + auto const data = reinterpret_cast&>( + swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 2]; +#endif + constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); + auto const outVec = convert( + reinterpret_cast const&>(data)); + uint32_t dstHeadIdx = idxHead; +#ifdef SPEC_Q_SEQ_LEN + if constexpr (dstIsStrided) { + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken < SPEC_Q_SEQ_LEN) { + uint32_t const strideBetweenTokens = nbKHeads * headGrpSize; + dstHeadIdx = idxToken * strideBetweenTokens + (idxHead % headGrpSize); + } + } +#endif + reinterpret_cast, nbGrainsPerHead>&>( + dst[dstHeadIdx])[idxGrain] = outVec; + } + } +} + +template +__device__ inline void finalizeAndWriteOut_sync( + uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& 
swizzleBuf, + Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum, + ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) { + // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of + // mufu.rcp static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + // + shfl to avoid 8x waste of mufu.rcp"); + auto regColSum = loadShmColWiseVecWithDup(accColSum); + if (attentionSinksVec != nullptr) { + auto const regAccColMax = loadShmColWiseVecWithDup(accColMax); + auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1); + auto regColSinks = expf(regAttentionSinks - regAccColMax); + regColSum = regColSum + regColSinks; + } + auto const regOutScale = __frcp_rn(regColSum) * xvoScale; + rescaleAcc(acc, regOutScale); + + saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc, + warpGrpBar, nbKHeads); + warpGrpBar.arrive_and_wait(); +} +#else +template +__device__ inline void finalizeAndWriteOut_sync( + uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, + float xvoScale, ShmQWiseVec const& accRowSum, + uint32_t nbKHeads /* for spec dec. set to 1 for workspace*/, uint32_t ctaNbValidTokens) { + auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum); + auto const regOutScale = __frcp_rn(regRowSum) * xvoScale; + rescaleAcc(acc, regOutScale); + + using DstElem = typename DstHead::Elem; + auto const lane = laneId(); + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; + using Atom = Vec, 4>; + using SwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; + static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf)); + auto& buf = reinterpret_cast(swizzleBuf); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad; + static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2)); +#pragma unroll + for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++) { + Vec const v = + convert(Vec{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0), + acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)}); + //@fixme: without reinterpret_cast to V, the compiler generates wrong code, and require a + //__syncwarp() + // after rescaleAcc() to work around. Likely a bug of the compiler. + //@todo: report a compiler bug. 
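+              // Note: v gathers both core-matrix columns from a pair of adjacent GMMA N
+              // instructions (2n and 2n+1) for one accumulator row, so each lane writes one
+              // 4-element vector into buf(r, n)[idxInQuad]. The aliased store through V below is
+              // functionally identical to the commented-out plain store that follows; it is kept
+              // only to sidestep the miscompilation described above.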
+ using V = Vec; + reinterpret_cast(buf.template at(r, n)[idxInQuad]) = + reinterpret_cast(v); + // buf.template at(r, n)[idxInQuad] = v; + } + } + } + __syncwarp(); + +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { + constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems; + constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes); + constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize); + uint32_t const idxGrp = lane / grpSize; + constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes); + uint32_t const rowBase = gmma::instM * m + 16 * warpRank; + constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16; + uint32_t const nbIters = divUp(totalNbGrains, nbGrps); + constexpr bool wholeIters = (totalNbGrains % nbGrps == 0); + constexpr bool wholeHeads = (validElemsPerHead == headElems); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbGrps * iter + idxGrp; + constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes); + uint32_t const r = idxGrain / grainsPerSrcHead; + if (!wholeIters && r >= 16) { + break; + } + uint32_t const cGrain = idxGrain % grainsPerSrcHead; + uint32_t const cAtom = cGrain / grainsPerAtom; + constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes); + uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r; + if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) { + break; + } + if (wholeHeads || cGrain < grainsPerDstHead) { + uint32_t const srcRow = rowBase + r; + auto const data = reinterpret_cast( + buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom]; +#if SPEC_DEC + static_assert(beamWidth == 1); + uint32_t const idxToken = srcRow / headGrpSize; // inside CTA + if (idxToken >= ctaNbValidTokens) { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const dstRow = srcRow + idxToken * tokenPad; +#else + uint32_t const dstRow = srcRow; +#endif + reinterpret_cast(dst[dstRow])[cGrain] = data; + } + } + } +} +#endif + +template +__device__ inline Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + Vec, nbPairsPerThrd> ret; + if constexpr (forNeox) { + auto const& pairs = + reinterpret_cast, nbWorkingThrds>, 2> const&>(head); + auto const data = isWorkingThrd + ? Vec, 2>{pairs[0][tid], pairs[1][tid]} + : Vec, 2>{}; + Vec, 2> const tmp = {convert(data[0]), + convert(data[1])}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i][0] = tmp[0][i]; + ret[i][1] = tmp[1][i]; + } + } else { + auto const data = + isWorkingThrd ? 
reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] + : Vec, nbPairsPerThrd>{}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(data[i]); + } + } + return ret; +} + +template +__device__ inline mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin) { + Vec, nbPairsPerThrd> r; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + float const x = data[i][0]; + float const y = data[i][1]; + float const c = ropeCosSin[i][0]; + float const s = ropeCosSin[i][1]; + r[i] = Vec{c * x - s * y, s * x + c * y}; + } + if constexpr (forNeox) { + Vec, 2> tmp; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + tmp[0][i] = r[i][0]; + tmp[1][i] = r[i][1]; + } + return Vec, 2>{convert(tmp[0]), + convert(tmp[1])}; + } else { + Vec, nbPairsPerThrd> ret; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(r[i]); + } + return ret; + } +} + +template +__device__ inline void storeRotatedPairsForKV( + GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (!isWorkingThrd) { + return; + } + if constexpr (forNeox) { + auto& pairs = + reinterpret_cast, nbWorkingThrds>, 2>&>(dst); + pairs[0][tid] = src[0]; + pairs[1][tid] = src[1]; + } else { + reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; + } +} + +template +__device__ inline void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (isWorkingThrd) { + if constexpr (forNeox) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + auto const byteOffset = + BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = + byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); + reinterpret_cast&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src[i]; + } + } else { + auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * 2 
* nbPairsPerThrd) == 0); + reinterpret_cast, nbPairsPerThrd>&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src; + } + } + static_assert(validElemsPerHead % 16 == 0); + __syncwarp(); + if constexpr (validElemsPerHead < headElems) { + static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); + constexpr uint32_t nbPadGrainsPerHead = + exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); + constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; + uint32_t const nbIters = divUp(nbPadGrains, nbThrds); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t idx = tid + nbThrds * iter; + if (idx >= nbPadGrains) { + break; + } + uint32_t const r = idx / nbPadGrainsPerHead; + uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; + dst[dst.size - 1].template at(r, c) = LdGrain{}; + } + } +} + +#ifndef GENERATE_CUBIN +void launchHopperF8MHA( + cudaDeviceProp const& prop, uint32_t nbKHeads, +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif +#if USE_INPUT_KV + InputHead const* qkv, +#if ROPE_STYLE != 0 + Vec const* ropeCosSin, +#endif +#else + InputHead const* q, +#endif + float const* attentionSinks, // [headGrpSize] +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] +#else + GMemKVCacheHead* kvCacheData, +#endif + uint32_t maxSeqLen, uint32_t const* seqLen, +#if USE_BEAM_SEARCH + BeamSearchParams const& beamSearchParams, +#endif + uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. 
+#if SPEC_DEC + SpecDecParams const& specDecParams, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + if (beamWidth != 1) { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + // printf("smemSize = %u\n", hostSmemSize); + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) { + int32_t const val = std::stoi(env); + if (val > 0) { + return val; + } + } + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + uint32_t const qSeqLen = specDecParams.qSeqLen; +#else + uint32_t const qSeqLen = 1; +#endif + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == + // nbInputSeqSplit + dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, 
gemm0CtaTileNbTokens); + cudaError_t const err = + cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} +#endif + +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, + uint32_t slidingWinSize, float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + InputHead const* q, float const* attentionSinks, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, +#endif + KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, +#if SPEC_DEC + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + auto specDecParams = SpecDecParams{qSeqLen, qCuSeqLens, mask}; + uint32_t const qLen = qSeqLen; +#else + uint32_t const qLen = 1; +#endif + dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + 
specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, + tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} +#endif diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu new file mode 100644 index 0000000000..d3f0089722 --- /dev/null +++ b/csrc/xqa/mla_sm120.cu @@ -0,0 +1,1958 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "defines.h" +#include "mha.h" +#if IS_MLA +#include "barriers.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mha_stdheaders.cuh" +#include "mla_sm120.cuh" +#include "mma.cuh" +#include "tma.h" +#include "utils.cuh" +#include "utils.h" + +#ifndef GENERATE_CUBIN +#include + +#include "hostUtils.h" +#include "tensorMap.h" +#endif + +#define USE_REG_Q 1 + +__constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA; + +inline constexpr bool allowMultipleInputTokens = true; + +inline constexpr uint32_t partElemsK = 64; // @fixme: change this to 128 to save L2 traffic +inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK); +inline constexpr uint32_t nbQParts = nbKParts; + +inline constexpr uint32_t tokensPerTile = 64; +inline constexpr uint32_t partElemsV = 128; +inline constexpr uint32_t nbVSplit = 2; +inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit); +inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit; + +inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2; +inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; + +using MathElem = CacheElem; +inline constexpr uint32_t mathElemBytes = sizeof(MathElem); +inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes); + +inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes); + +inline constexpr float xScale = 1.f / kE4M3_MAX; +__constant__ constexpr float rcpXScale = kE4M3_MAX; + +inline constexpr uint32_t nbRegsForIOWarps = 32; +inline constexpr uint32_t nbRegsForMathWarps = 232; + +inline constexpr bool computeRowSumFromF8 = true; + +struct KVTilePartLoader { +#if USE_PAGED_KV_CACHE + static_assert(tokensPerPage % tokensPerTile == 0 || tokensPerTile % tokensPerPage == 0); + static inline constexpr uint32_t nbPagesPerTile = + tokensPerTile >= tokensPerPage ? 
exactDiv(tokensPerTile, tokensPerPage) : 1; +#endif + + static inline constexpr uint32_t const nbKHeads = 1; + KVCacheList const& cacheList; + uint32_t const idxReq; + static inline constexpr uint32_t const idxHeadGrp = 0; + + CUtensorMap const& tensorMap; + // if greater than 1, then we need unrolling for the loading loop. Seems 1 is fine for latency. + static inline constexpr uint32_t nbPageBuffers = 1; +#if USE_PAGED_KV_CACHE + uint32_t const nbPages; // for bound check + Vec pageBuffers[nbPageBuffers]; + uint32_t idxTileRef = ~0U; // idxTile used to load the pages +#endif + uint32_t const baseOffset; + + __device__ KVTilePartLoader(KVCacheList const& cacheList, uint32_t idxReq, + CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages +#endif + ); + // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache + template + __device__ void loadData(Array2D& dst, + uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar, + uint32_t idxPageBuf); + + __device__ void loadPages(uint32_t idxTile, uint32_t idxPageBuf); +}; + +__device__ inline KVTilePartLoader::KVTilePartLoader(KVCacheList const& cacheList, + uint32_t idxReq, CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages +#endif + ) + : cacheList{cacheList}, + idxReq{idxReq}, + tensorMap{tensorMap} +#if USE_PAGED_KV_CACHE + , + nbPages{nbPages} +#if PAGED_KV_CACHE_LAYOUT == 1 + , + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} +#else + , + baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq} +#endif +#else + , + baseOffset{(idxReq * beamWidth) * 2} +#endif +{ +#pragma unroll + for (auto& pageBuffer : pageBuffers) { + pageBuffer.fill(kBAD_PAGE_INDEX); + } +} + +// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache +template +__device__ inline void KVTilePartLoader::loadData( + Array2D& dst, uint32_t idxTile, + uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf) { + static_assert(nbTokens == tokensPerTile); +#if USE_PAGED_KV_CACHE + assert(idxTile == idxTileRef); + auto const& pages = pageBuffers[idxPageBuf]; + if constexpr (nbTokens < tokensPerPage) { + assert(nbPagesPerTile == 1); + uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); + if (warpElectSync()) { +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, idxHeadGrp, offset, (uint32_t)pages[0]}, + bar); +#else + tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, offset, idxHeadGrp, (uint32_t)pages[0]}, + bar); +#endif + } + } else { +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { + if (warpElectSync()) { +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{idxElemBeg, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); +#else + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{idxElemBeg, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); +#endif + } + } + } +#else + if (warpElectSync()) { + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{idxElemBeg, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); + } +#endif +} + +__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t idxPageBuf) { +#if USE_PAGED_KV_CACHE + uint32_t const idxPageBeg = tokensPerTile >= tokensPerPage + ? 
nbPagesPerTile * idxTile + : idxTile / exactDiv(tokensPerPage, tokensPerTile); + auto& pages = pageBuffers[idxPageBuf]; +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { + uint32_t const idxPage = idxPageBeg + i; + pages[i] = + idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; + } + idxTileRef = idxTile; +#endif +} + +using Mat16x32 = Vec; + +template +class Mat16x32Loader { + public: + using Src = Array2D; + + // default r and c are for mat A. + __device__ inline Mat16x32Loader(Src const& src, uint32_t baseRow, uint32_t idxInstK, + uint32_t r = laneId() % 16, uint32_t c = laneId() / 16) + : src{src}, baseRow{baseRow}, idxInstK{idxInstK}, r{r}, c{c}, basePtr{getPtrRef(0)} { + static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0); + } + + __device__ inline Mat16x32 load(uint32_t idxInstM) const { + return ldmatrix(getPtr(idxInstM)); + } + + template + __device__ inline Vec loadWholeCol() const { + uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m); + Vec ret; +#pragma unroll + for (uint32_t i = 0; i < nbInstM; i++) { + ret[i] = load(i); + } + return ret; + } + + __device__ inline LdGrain const* getPtr(uint32_t idxInstM) const { + return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM)); + } + + private: + __device__ inline LdGrain const* getPtrRef(uint32_t idxInstM) const { + return &src.template at(baseRow + idxInstM * qmmaShape.m + r, + idxInstK * exactDiv(qmmaShape.k, grainElems) + c); + } + + Src const& src; + uint32_t const baseRow; + uint32_t const idxInstK; + uint32_t const r; + uint32_t const c; + LdGrain const* const basePtr; +}; + +using InstAcc = Array2D; + +using XBuffer = Array2D; + +struct CgaXBuffer { + XBuffer x; + Vec rowSum; + Vec rowMaxLog2e; +}; + +struct PingPongMutex { + using ShmStorage = CtaBarrier[2]; + ShmStorage& barriers; + uint32_t const idxGrp; + bool skipWait = false; + + static __device__ inline void initStorage(ShmStorage& barriers, uint32_t thrdsPerGrp) { + new (&barriers[0]) CtaBarrier(thrdsPerGrp); + new (&barriers[1]) CtaBarrier(thrdsPerGrp); + barriers[0].arrive(thrdsPerGrp); + } + + __device__ inline PingPongMutex(ShmStorage& shmStorage, uint32_t idxGrp) + : barriers{shmStorage}, idxGrp{idxGrp} {} + + __device__ inline void test_lock(uint32_t iter) { + skipWait = barriers[idxGrp].test_wait_parity(toParity<1>(iter)); + } + + __device__ inline void lock(uint32_t iter) { + if (!skipWait) { + barriers[idxGrp].wait_parity(toParity<1>(iter)); + } + } + + __device__ inline void unlock() { + barriers[idxGrp ^ 1U].arrive(); + skipWait = false; + } +}; + +struct PartialResult { + static constexpr uint32_t nbChunks = 4; + static constexpr uint32_t nbRowsPerChunk = exactDiv(headGrpSize, nbChunks); + + struct Chunk { + Vec data; + Vec rowSum; + Vec rowMaxLog2e; + }; + + Chunk chunks[nbChunks]; +}; + +constexpr uint32_t nbMathWarpsA = 8; +constexpr uint32_t nbComputeWarpsB = 8; +constexpr uint32_t nbMathGrpsA = 2; +constexpr uint32_t nbMathWarpsB = 8; + +constexpr uint32_t nbMultiBlockBufs = 2; +constexpr uint32_t multiBlockMathWarps = 8; + +constexpr bool useRegQ = USE_REG_Q; + +struct SharedMemA { + static inline constexpr uint32_t nbKBufs = 12; + + static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0); + static inline constexpr uint32_t shmQParts = nbQParts - regQParts; + + using ShmQPart = Array2D; + using ShmKPart = Array2D; + + Vec q; + ShmKPart k[nbKBufs]; + + // single buffer reused by two groups. 
sendX() warp will arbitrate the order of x buffer access + // via two xBars. + CgaXBuffer x; + + // scaled by log2e. Write by last CGA iteration (from the other producer CTA) and read by current + // producer CTA. + Vec rowMaxLog2e; + // sync rowMaxLog2e between two producer CTAs and .consumed means the buffer for next iteration + // (in next producer) is ready. The 4 groups from 2 producers CTAs form a ring + CgaBarrier rowMaxLog2eBar[nbMathGrpsA]; + + PingPongMutex::ShmStorage tensorCoreMutex; + + CtaBarrierPair kBars[nbKBufs]; + static inline constexpr uint32_t nbXBars = nbMathGrpsA; + CtaBarrierPair xBars[nbXBars]; +#if USE_REG_Q + CtaBarrierPair regQBar; +#endif + CtaBarrier shmQBar; + CgaBarrier cgaXBufConsumed; // for X + + CtaBarrierPair multiBlockBars[nbMultiBlockBufs]; + + __device__ inline void invalidateBarriers(uint32_t thrdIdx) { + constexpr uint32_t nbBars = (useRegQ ? 12 : 10) + 2 * (nbKBufs + nbXBars); +#ifndef __CUDACC_RTC__ + constexpr uint32_t nbBarsRef = + exactDiv(offsetof(SharedMemA, qkScaleLog2e) - offsetof(SharedMemA, rowMaxLog2eBar), 8); + static_assert(nbBars == nbBarsRef); +#endif + if (thrdIdx < nbBars) { + reinterpret_cast(&rowMaxLog2eBar[0])[thrdIdx].~CtaBarrier(); + } + } + + __device__ inline Vec& getMultiBlockBufs() { +#ifndef __CUDACC_RTC__ + assert(sizeof(Vec) < + offsetof(SharedMemA, rowMaxLog2eBar)); +#endif + return *reinterpret_cast*>(this); + } + + float qkScaleLog2e; + bool isLastSubSeq; +}; + +struct SharedMemB { + static inline constexpr uint32_t nbXVBufs = 2; + static inline constexpr uint32_t nbXBufs = nbXVBufs; + static inline constexpr uint32_t nbVBufs = nbXVBufs; + + using VBuffer = Vec, + exactDiv(gemm1V, partElemsV)>; + + // x and v are using gemmK=128 per iteration. If we see high pressure on shared memory capacity, + // we can change to 64 in the future. 
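+  // Layout note: each XVBuffer pairs one V buffer with the CgaXBuffer staged by a producer CTA
+  // for the same chunk of work; the pad member sizes the struct to headGrpSize * 128 * 2 bytes so
+  // the same storage can be reused for output swizzling (see the comment on `pad`). With
+  // nbXVBufs == 2, these X/V inputs are double-buffered between the warps that fill them and the
+  // math warps via xBars/vBars.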
+ struct XVBuffer { + VBuffer v; + CgaXBuffer x; + uint8_t + pad[headGrpSize * 128 * 2 - sizeof(VBuffer) - sizeof(CgaXBuffer)]; // for output swizzling + }; + + XVBuffer xv[nbXVBufs]; + + __device__ inline XBuffer& x(uint32_t idx) { return xv[idx].x.x; } + + __device__ inline VBuffer& v(uint32_t idx) { return xv[idx].v; } + + __device__ inline Vec& xRowSum(uint32_t idx) { return xv[idx].x.rowSum; } + + __device__ inline Vec& xRowMaxLog2e(uint32_t idx) { + return xv[idx].x.rowMaxLog2e; + } + + static inline constexpr uint32_t nbAccRowMaxSumCopies = 2; + Vec accRowMaxLog2e[nbAccRowMaxSumCopies]; + Vec accRowSum[nbAccRowMaxSumCopies]; + + CtaBarrierPair xBars[nbXBufs]; + CtaBarrierPair vBars[nbVBufs]; + + CgaBarrier cgaXBufProduced[nbProducerCtasPerCga]; + CtaBarrier mathWarpsBar; + + CtaBarrierPair multiBlockBars[nbMultiBlockBufs]; + + __device__ inline void invalidateBarriers(uint32_t thrdIdx) { + constexpr uint32_t nbBars = 15; +#ifndef __CUDACC_RTC__ + constexpr uint32_t nbBarsRef = + exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xBars), 8); + static_assert(nbBars == nbBarsRef); +#endif + if (thrdIdx < nbBars) { + reinterpret_cast(&xBars[0])[thrdIdx].~CtaBarrier(); + } + } + + __device__ inline Vec& getMultiBlockBufs() { +#ifndef __CUDACC_RTC__ + static_assert(sizeof(Vec) < + offsetof(SharedMemB, xBars)); +#endif + return *reinterpret_cast*>(this); + } + + bool isLastSubSeq; +}; + +__device__ void mergePartialOutputs(uint32_t& semaphore, + Vec& dst, + PartialResult const* reqPartialResults, uint32_t nbSubSeq, + uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx, + void* sharedMem); + +struct KernelArgs { + CUtensorMap const& tensorMapQ; // MhaIOHead[nbQHeads * totalNbInputTokens] + CUtensorMap const& tensorMapK; + CUtensorMap const& tensorMapV; + float const& qScale; + OutputHead* __restrict__ const& output; // [totalNbIntputTokens][nbQHeads] + KVCacheList const& cacheList; + uint32_t const& batchSize; + float const* __restrict__ const& kvCacheScale; // Device memory scalar. Same scale for K and V + // cache. Used only for int8/fp8 KV cache. 
+ Vec* __restrict__ const& + cgaXBuf; // [totalNbInputTokens][maxNbSubSeq] + uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens] + PartialResult* __restrict__ const& partialResults; // [totalNbInputTokens][maxNbSubSeq] +}; + +struct Producer { + static inline constexpr uint32_t nbMathGrps = nbMathGrpsA; + static inline constexpr uint32_t nbMathWarps = nbMathWarpsA; + static inline constexpr uint32_t nbMathThrds = nbMathWarps * warp_size; + static inline constexpr uint32_t warpsPerGrp = exactDiv(nbMathWarps, nbMathGrps); + static inline constexpr uint32_t thrdsPerGrp = warpsPerGrp * warp_size; + static inline constexpr uint2 warpTile = {tokensPerTile, exactDiv(headGrpSize, warpsPerGrp)}; + using WarpAcc = WarpAccT; + using ThrdRegRowMax = ThrdRegRowMaxT; + using QuadRegRowMax = QuadRegRowMaxT; + + KernelArgs const& args; + SharedMemA& smem; + uint32_t const maxNbSubSeq; + uint32_t const idxReq; + uint32_t const idxInputTokenGlobal; + uint32_t const nbSubSeq; + uint32_t const idxSubSeq; + uint32_t const seqLen; + uint32_t const ctaRank; + uint32_t const warpRank; + uint2 const warpIdx; + + __device__ inline Producer(KernelArgs const& args, SharedMemA& smem, uint32_t const maxNbSubSeq, + uint32_t const idxReq, uint32_t idxInputTokenGlobal, + uint32_t const seqLen, uint32_t const nbSubSeq, + uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, + uint2 const warpIdx) + : args(args), + smem(smem), + maxNbSubSeq(maxNbSubSeq), + idxReq(idxReq), + idxInputTokenGlobal(idxInputTokenGlobal), + seqLen(seqLen), + nbSubSeq(nbSubSeq), + idxSubSeq(idxSubSeq), + ctaRank(ctaRank), + warpRank(warpRank), + warpIdx(warpIdx) { +#ifndef NDEBUG + if (threadIdx.x == 0) { + asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemA)) : "memory"); + } + __syncthreads(); +#endif + if (threadIdx.x == 0) { + smem.qkScaleLog2e = args.qScale * args.kvCacheScale[0] * log2e; + } + + if (threadIdx.x < headGrpSize) { + smem.rowMaxLog2e[threadIdx.x] = safeInitRowMax; + } + if (warpElectSync()) { + if (warpRank < SharedMemA::nbKBufs) { + auto& b = smem.kBars[warpRank]; + b.initialize(1, thrdsPerGrp); + b.consumed.arrive(thrdsPerGrp); + } + if (warpRank < SharedMemA::nbXBars) { + auto& b = smem.xBars[warpRank]; + b.initialize(thrdsPerGrp, 1); + } +#if USE_REG_Q + if (warpRank == 0) { + smem.regQBar.initialize(1, nbMathThrds); + smem.regQBar.consumed.arrive(nbMathThrds); + } +#endif + if (warpRank < nbMathGrpsA) { + auto& b = smem.rowMaxLog2eBar[warpRank]; + init(&b, thrdsPerGrp); + } + if (ctaRank == 0 && warpRank == 0) { + smem.rowMaxLog2eBar[0].arrive(thrdsPerGrp); + } + if (warpRank == 0) { + init(&smem.shmQBar, 1); + init(&smem.cgaXBufConsumed, 1 * nbVSplit); + smem.cgaXBufConsumed.arrive(1 * nbVSplit); + PingPongMutex::initStorage(smem.tensorCoreMutex, thrdsPerGrp); + } + if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) { + auto& b = smem.multiBlockBars[warpRank]; + b.initialize(1, warp_size * multiBlockMathWarps); + b.consumed.arrive(warp_size * multiBlockMathWarps); + } + } + clusterBarArrive(); + clusterBarWait(); + } + + __device__ inline ~Producer() { + clusterBarArrive(); + clusterBarWait(); + smem.invalidateBarriers(threadIdx.x); + } + + __device__ inline void run() { + if (warpIdx.y == 2) { // IO warps + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); + if (warpIdx.x == 0) { // q + loadQ(); + } else if (warpIdx.x == 1) { // k + loadK(); + } else if (warpIdx.x == 2) { // x + sendX(); + } + } else { // Compute warps + asm 
volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); + compute(); + } + if (nbSubSeq > 1) { + mergePartialOutputs(args.semaphores[idxInputTokenGlobal], + reinterpret_cast&>( + args.output[headGrpSize * idxInputTokenGlobal + + PartialResult::nbRowsPerChunk * ctaRank]), + args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, + ctaRank, warpRank, warpIdx, &smem); + } + } + + private: + __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; } + + __device__ inline uint32_t idxTileBeg() const { + return nbProducerCtasPerCga * idxSubSeq + ctaRank; + } + + __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); } + + __device__ inline SharedMemB& getConsumerShm(uint32_t const idxConsumer) { + return *mapa(reinterpret_cast(&smem), nbProducerCtasPerCga + idxConsumer); + }; + + static constexpr uint32_t regQPartShmBeg = SharedMemA::shmQParts - SharedMemA::regQParts; + + __device__ inline void loadQ() { +#if USE_REG_Q + static_assert(SharedMemA::regQParts <= SharedMemA::shmQParts); + smem.regQBar.consumed.wait_parity(toParity<1>(0)); +#pragma unroll 1 + for (uint32_t i = 0; i < SharedMemA::regQParts; i++) { + if (warpElectSync()) { + tma::loadAsync(&smem.q[regQPartShmBeg + i], args.tensorMapQ, + DimsLE<2>{partElemsK * i, headGrpSize * idxInputTokenGlobal}, + smem.regQBar.produced); + } + } + if (warpElectSync()) { + smem.regQBar.produced.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::regQParts); + } +#endif +#pragma unroll 1 + for (uint32_t i = 0; i < SharedMemA::shmQParts; i++) { + uint32_t const idxPart = SharedMemA::regQParts + i; +#if USE_REG_Q + if (i == regQPartShmBeg) { + smem.regQBar.consumed.wait_parity(toParity<1>(1)); + } +#endif + if (warpElectSync()) { + tma::loadAsync(&smem.q[i], args.tensorMapQ, + DimsLE<2>{partElemsK * idxPart, headGrpSize * idxInputTokenGlobal}, + smem.shmQBar); + } + } + if (warpElectSync()) { + smem.shmQBar.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::shmQParts); + } + } + + __device__ inline void loadK(); + + __device__ inline void sendX(); + + __device__ inline void compute() { + uint32_t const grpIdx = warpIdx.y; + uint32_t const tileBaseRow = warpTile.y * warpIdx.x; + PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx}; + + constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k); + using AtomA = Vec; // for 16x32 data, working as mat A of QMMA.16832 + using RegQPartCol = Vec; + using RegQPart = Vec; + using RegQ = Vec; + constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, qmmaShape.n * 2); + using AtomBx2 = Vec; // one AtomB is 8x32 and AtomBx2 is 16x32 + using RegKPartCol = Vec; + using RegKPart = Vec; + + uint32_t const lane = laneId(); + uint32_t const rA = lane % 16; + uint32_t const cA = lane / 16; + uint32_t const rB = (lane / 16) * 8 + lane % 8; + uint32_t const cB = (lane % 16) / 8; + auto loadRegQCol = [&](SharedMemA::ShmQPart const& q, uint32_t idxInstK) -> RegQPartCol { + Mat16x32Loader const loaderQ(q, tileBaseRow, idxInstK, rA, cA); + return loaderQ.loadWholeCol(); + }; + auto loadRegKCol = [&](SharedMemA::ShmKPart const& k, uint32_t idxInstK) -> RegKPartCol { + Mat16x32Loader const loaderK(k, 0, idxInstK, rB, cB); + return loaderK.loadWholeCol(); + }; + auto loadPart = [&](auto const& loadCol, auto const& shmPart) { + mha::conditional_t>, + RegQPart, RegKPart> + regPart; +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) { + regPart[idxInstK] = loadCol(shmPart, idxInstK); + } + 
return regPart; + }; + +#if USE_REG_Q + // load regQ + smem.regQBar.produced.wait_parity(toParity<1>(0)); + RegQ regQ; +#pragma unroll + for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) { + uint32_t const idxBuf = regQPartShmBeg + idxPart; + regQ[idxPart] = loadPart(loadRegQCol, smem.q[idxBuf]); + } + smem.regQBar.consumed.arrive(); +#endif +// main loop +#pragma unroll 1 + for (uint32_t grpIter = 0; true; grpIter++) { + uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps; + uint32_t const idxTile = idxTileBeg() + iterStride() * ctaIter; + if (idxTile >= nbTiles()) { + break; + } + WarpAcc acc{}; + // wait until it's our turn + tensorCoreMutex.lock(grpIter); + BarWaiter kBarWaiter(smem.kBars, ctaIter * nbKParts + 0); + kBarWaiter.testWait(); + RegQPart regQBuf; +#if USE_REG_Q + static_assert(SharedMemA::regQParts > 0); + regQBuf[0] = regQ[0][0]; +#else + regQBuf[0] = loadRegQCol(smem.q[0], 0); +#endif + kBarWaiter.wait(); + RegKPart regKBuf; + regKBuf[0] = loadRegKCol(smem.k[kBarWaiter.idxBuf], 0); + + auto shouldTestWait = [](uint32_t idxInstK, uint32_t idxAtomBx2) { + return idxInstK == partNbInstK - 1 && idxAtomBx2 == tileNbAtomBx2 - 2; + }; + BarWaiter kBarWaiterNext = kBarWaiter.next(); +#if USE_REG_Q +#pragma unroll + for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) { +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) { + bool const prefetchNextPart = (idxInstK == partNbInstK - 1); + uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart; + uint32_t const idxInstKPrefetch = prefetchNextPart ? 0 : idxInstK + 1; + bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1)); + + if (prefetchNextPart) { + kBarWaiter = kBarWaiterNext; + kBarWaiterNext = kBarWaiter.next(); + if (prefetch) { + kBarWaiter.wait(); + } + } + + Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB); +#pragma unroll + for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++) { + if (idxAtomBx2 == 2 && prefetch) { + if (idxPartPrefetch < SharedMemA::regQParts) { + regQBuf[idxInstKPrefetch] = regQ[idxPartPrefetch][idxInstKPrefetch]; + } else { + regQBuf[idxInstKPrefetch] = + loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch); + } + } + AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2]; + regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2); + if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch) { + kBarWaiterNext.testWait(); + } +#pragma unroll + for (uint32_t i = 0; i < WarpAcc::rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + mma<__nv_fp8_e4m3>(reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), + reinterpret_cast(regQBuf[idxInstK][i]), + reinterpret_cast(atomBx2[2 * j])); + } + } + if (prefetch) { + regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2); + } + } + if (idxInstKPrefetch == partNbInstK - 1) { + assert(prefetch); + kBarWaiter.consumed(); + } + } + } +#endif + if (ctaIter == 0) { + smem.shmQBar.wait_parity(false); + } +#pragma unroll + for (uint32_t idxPart = SharedMemA::regQParts; idxPart < nbQParts; idxPart++) { +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++) { + bool const prefetchNextPart = (idxInstK == partNbInstK - 1); + uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart; + uint32_t const idxInstKPrefetch = prefetchNextPart ? 
0 : idxInstK + 1; + bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1)); + + if (prefetchNextPart) { + kBarWaiter = kBarWaiterNext; + kBarWaiterNext = kBarWaiter.next(); + if (prefetch) { + kBarWaiter.wait(); + } + } + + Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB); +#pragma unroll + for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++) { + if (idxAtomBx2 == 2 && prefetch) { + regQBuf[idxInstKPrefetch] = + loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch); + } + AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2]; + if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch) { + kBarWaiterNext.testWait(); + } +#pragma unroll + for (uint32_t i = 0; i < WarpAcc::rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + mma<__nv_fp8_e4m3>(reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), + reinterpret_cast(regQBuf[idxInstK][i]), + reinterpret_cast(atomBx2[2 * j])); + } + } + if (prefetch) { + regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2); + } + } + if (idxInstKPrefetch == partNbInstK - 1) { + assert(prefetch); + kBarWaiter.consumed(); + if (idxPartPrefetch == nbKParts - 1) { + tensorCoreMutex.unlock(); // let the other group to use tensor cores + } + } + } + } + uint32_t const validTokens = seqLen - tokensPerTile * idxTile; + if (validTokens < tokensPerTile) { + applyMask(this_warp(), acc, 0, validTokens); + } + ThrdRegRowMax rowMaxLog2e; + WarpAcc const xF32 = scaleAndSoftmax(rowMaxLog2e, acc, grpIdx, grpIter, tileBaseRow); + + auto& xBar = smem.xBars[grpIdx]; + bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter)); + // convert to fp8 + WarpAcc const xF32Quant = xF32 * rcpXScale; + // 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15 + Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8; +#pragma unroll + for (uint32_t i = 0; i < WarpAcc::rows; i++) { +#pragma unroll + for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++) { +#pragma unroll + for (uint32_t j = 0; j < WarpAcc::cols; j += 2) { + auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0)); + dst[0] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j)(m, 0), xF32Quant(i, j)(m, 1)}); + dst[1] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j + 1)(m, 0), xF32Quant(i, j + 1)(m, 1)}); + } + } + } + // use tensor core to compute rowSum + ThrdRegRowMax const rowSum = + computeRowSumFromF8 ? 
computeRowSumF8(this_warp(), xF8) + : computeRowSumF32(this_warp(), xF32); + + // store xF8 and rowSum into L2 scratch buffer + if (!skipXBarWait) { + xBar.consumed.wait_parity(toParity<1>(grpIter)); + } + storeRowMax(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane); + storeRowMax(smem.x.rowSum, rowSum, tileBaseRow, lane); + storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane); + xBar.produced.arrive(); + } + } + + __device__ inline WarpAcc scaleAndSoftmax(ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc, + uint32_t grpIdx, uint32_t grpIter, + uint32_t tileBaseRow); + + __device__ inline void storeOrderedXToShm( + XBuffer& dst, + Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, + uint32_t const tileBaseRow, uint32_t const lane = laneId()); +}; + +__device__ inline void Producer::loadK() { + KVTilePartLoader loader{args.cacheList, idxReq, args.tensorMapK +#if USE_PAGED_KV_CACHE + , + divUp(seqLen, tokensPerPage) +#endif + }; + +#pragma unroll 1 + for (uint32_t iter = 0; true; iter++) { + uint32_t const idxTile = idxTileBeg() + iterStride() * iter; + if (idxTile >= nbTiles()) { + break; + } + uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers; + loader.loadPages(idxTile, idxPageBuf); +#pragma unroll 1 + for (uint32_t idxPart = 0; idxPart < nbKParts; idxPart++) { + uint32_t const idxPartGlobal = iter * nbKParts + idxPart; + uint32_t const idxBuf = idxPartGlobal % SharedMemA::nbKBufs; + auto& bar = smem.kBars[idxBuf]; + bar.consumed.wait_parity(toParity(idxPartGlobal)); + loader.loadData(smem.k[idxBuf], idxTile, partElemsK * idxPart, bar.produced, idxPageBuf); + if (warpElectSync()) { + bar.produced.arrive_tx(sizeof(SharedMemA::ShmKPart)); + } + } + } +} + +__device__ inline void Producer::sendX() { + // let group 0 to produce first. + if (warpElectSync()) { + smem.xBars[0].consumed.arrive(); + } + for (uint32_t iter = 0; true; iter++) { + uint32_t const idxTile = idxTileBeg() + iterStride() * iter; + if (idxTile >= nbTiles()) { + break; + } + uint32_t const idxBar = iter % SharedMemA::nbXBars; + auto& xBar = smem.xBars[idxBar]; + xBar.produced.wait_parity(toParity(iter)); + smem.cgaXBufConsumed.wait_parity(toParity<1>(iter)); + if (warpElectSync()) { + auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank]; + tma::store1DAsync(&dst, &smem.x, sizeof(CgaXBuffer)); + tma::commitGroup(); + tma::waitGroup<0>(); + // it's turn for the other math group to produce. 
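+      // waitGroup<0>() above drains the bulk store of smem.x into this CTA's slot of
+      // args.cgaXBuf in L2, so it is safe to hand the X buffer back to the other math group
+      // (xBarNext.consumed below) and, after the cluster release fence, to signal
+      // cgaXBufProduced so the consumer CTAs can start copying it.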
+ uint32_t const idxBarNext = (iter + 1) % SharedMemA::nbXBars; + auto& xBarNext = smem.xBars[idxBarNext]; + xBarNext.consumed.arrive(); + asm volatile("fence.release.cluster;\n"); +#pragma unroll + for (uint32_t i = 0; i < nbVSplit; i++) { + auto& producedBar = getConsumerShm(i).cgaXBufProduced[ctaRank]; + producedBar.arrive(); + } + } + } +} + +__device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(ThrdRegRowMax& rowMaxLog2e, + WarpAcc const& acc, uint32_t grpIdx, + uint32_t grpIter, + uint32_t tileBaseRow) { + uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps; + uint32_t const cgaIter = ctaRank + ctaIter * nbProducerCtasPerCga; + auto const warp = this_warp(); + uint32_t const lane = laneId(); + uint32_t const idxProducer = ctaRank; + assert(ctaRank < nbProducerCtasPerCga); + + float const qkScaleLog2e = smem.qkScaleLog2e; + bool const skipWaitLastShmRowMax = + smem.rowMaxLog2eBar[grpIdx].test_wait_parity(toParity<1>(grpIter)); + QuadRegRowMax const tileRowMaxLog2e = computeRowMax(acc) * qkScaleLog2e; + // get max with previous CTA's rowMax + if (!skipWaitLastShmRowMax) { + smem.rowMaxLog2eBar[grpIdx].wait_parity(toParity<1>(grpIter)); + } + auto const lastRowMaxLog2e = loadShmRowMax(smem.rowMaxLog2e, tileBaseRow, lane); + + auto const quadRowMaxLog2e = fmaxf(tileRowMaxLog2e, replicateForQuad(warp, lastRowMaxLog2e)); + + // transfer new row max to the other producer CTA for next iteration + SharedMemA& smemNext = mapa(smem, ctaRank ^ 1U); + CgaBarrier& nextRowMaxLog2eBar = + smemNext.rowMaxLog2eBar[(cgaIter + 1) % (nbMathGrps * nbProducerCtasPerCga) / nbMathGrps]; + rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e); + storeRowMaxAsync(nextRowMaxLog2eBar, smemNext.rowMaxLog2e, rowMaxLog2e, tileBaseRow, + lane); + nextRowMaxLog2eBar.arrive_tx_relaxed( + sizeof(rowMaxLog2e)); // notify that the next CTA can read rowMax now. 
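+  // Softmax is evaluated in base 2 (a sketch of the math implemented by the loop below):
+  //   maxVal = rowMax * qkScaleLog2e          (qkScaleLog2e folds the QK scale into log2e)
+  //   p      = exp2f(x * qkScaleLog2e - maxVal) = exp(qkScale * (x - rowMax))
+  // so each probability costs one FMA plus one exp2f, with no explicit exp().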
+ + WarpAcc x; +// apply softmax +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + float const maxVal = quadRowMaxLog2e[m * InstAcc::rows + i]; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + float elem = acc(m, n)(i, j); + assert(maxVal >= elem * qkScaleLog2e); + x(m, n)(i, j) = exp2f(elem * qkScaleLog2e - maxVal); + } + } + } + } + + return x; +} + +__device__ inline void Producer::storeOrderedXToShm( + XBuffer& dst, + Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, + uint32_t const tileBaseRow, uint32_t const lane) { + uint32_t const r = lane % 16; + uint32_t const c = lane / 16; + using Src = mha::decay_t; + LdGrain* ptrs[exactDiv(Src::cols, 2)][Src::rows]; +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++) { + Mat16x32Loader const loader(dst, tileBaseRow, idxInstK, r, c); +#pragma unroll + for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++) { + auto const p = const_cast(loader.getPtr(idxInstM)); + stmatrix(p, reinterpret_cast(src(idxInstM, idxInstK * 2))); + ptrs[idxInstK][idxInstM] = p; + } + } + // reorder from 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15 + // to 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + __syncwarp(); +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++) { +#pragma unroll + for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++) { + auto const p = ptrs[idxInstK][idxInstM]; + auto const i = *p; + LdGrain const o = { + prmt(i[0], i[1], PermuteOrder{0, 1, 4, 5}), prmt(i[2], i[3], PermuteOrder{0, 1, 4, 5}), + prmt(i[0], i[1], PermuteOrder{2, 3, 6, 7}), prmt(i[2], i[3], PermuteOrder{2, 3, 6, 7})}; + *p = o; + } + } +} + +struct Consumer { + static inline constexpr uint32_t nbMathWarps = nbMathWarpsB; + static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps; + static inline constexpr uint2 ctaShape = {2, 4}; + static_assert(SharedMemB::nbAccRowMaxSumCopies == ctaShape.x); + static_assert(ctaShape.x * ctaShape.y == nbMathWarps); + static inline constexpr uint2 warpTile = {exactDiv(gemm1V, ctaShape.x), + exactDiv(headGrpSize, ctaShape.y)}; + + static inline constexpr uint32_t nbWarpOutSwizzleBuf = nbMathWarps; + using WarpOutSwizzleBuf = + Array2D; + static_assert(WarpOutSwizzleBuf::rows % 8 == 0); + + using WarpAcc = WarpAccT; + using ThrdRegRowMax = ThrdRegRowMaxT; + using UniformNeedRescaleMask = Vec; + + KernelArgs const& args; + SharedMemB& smem; + uint32_t const maxNbSubSeq; + uint32_t const idxReq; + uint32_t const idxInputTokenGlobal; + uint32_t const nbSubSeq; + uint32_t const idxSubSeq; + uint32_t const seqLen; + uint32_t const ctaRank; + uint32_t const warpRank; + uint2 const warpIdx; + + __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; } + + __device__ inline uint32_t idxTileBeg() const { return nbProducerCtasPerCga * idxSubSeq; } + + __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); } + + __device__ inline uint32_t idxConsumer() const { return ctaRank - 2; } + + __device__ inline Consumer(KernelArgs const& args, SharedMemB& smem, uint32_t const maxNbSubSeq, + uint32_t const idxReq, uint32_t const idxInputTokenGlobal, + uint32_t const seqLen, uint32_t const nbSubSeq, + uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, + uint2 const warpIdx) + : args(args), + smem(smem), + 
maxNbSubSeq(maxNbSubSeq), + idxReq(idxReq), + idxInputTokenGlobal(idxInputTokenGlobal), + seqLen(seqLen), + nbSubSeq(nbSubSeq), + idxSubSeq(idxSubSeq), + ctaRank(ctaRank), + warpRank(warpRank), + warpIdx(warpIdx) { +#ifndef NDEBUG + if (threadIdx.x == 0) { + asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemB)) : "memory"); + } + __syncthreads(); +#endif + if (threadIdx.x < headGrpSize) { + for (uint32_t i = 0; i < SharedMemB::nbAccRowMaxSumCopies; i++) { + smem.accRowMaxLog2e[i][threadIdx.x] = safeInitRowMax; + smem.accRowSum[i][threadIdx.x] = 0; + } + } + if (warpElectSync()) { + if (warpRank < nbProducerCtasPerCga) { + init(&smem.cgaXBufProduced[warpRank], 1); + } + if (warpRank < SharedMemB::nbXBufs) { + auto& bar = smem.xBars[warpRank]; + bar.initialize(1, nbMathThrds); + bar.consumed.arrive(nbMathThrds); + } + if (warpRank < SharedMemB::nbVBufs) { + auto& bar = smem.vBars[warpRank]; + bar.initialize(1, nbMathThrds); + bar.consumed.arrive(nbMathThrds); + } + if (warpRank == 0) { + init(&smem.mathWarpsBar, warp_size * nbMathWarps); + } + if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) { + auto& b = smem.multiBlockBars[warpRank]; + b.initialize(1, warp_size * multiBlockMathWarps); + b.consumed.arrive(warp_size * multiBlockMathWarps); + } + } + clusterBarArrive(); + clusterBarWait(); + } + + __device__ inline ~Consumer() { + clusterBarArrive(); + clusterBarWait(); + smem.invalidateBarriers(threadIdx.x); + } + + __device__ inline void run() { + if (warpIdx.y == 2) { + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); + if (warpIdx.x == 0) { + loadX(); + } else if (warpIdx.x == 1) { + loadV(); + } + } else { + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); + compute(); + } + if (nbSubSeq > 1) { + mergePartialOutputs(args.semaphores[idxInputTokenGlobal], + reinterpret_cast&>( + args.output[headGrpSize * idxInputTokenGlobal + + PartialResult::nbRowsPerChunk * ctaRank]), + args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, + ctaRank, warpRank, warpIdx, &smem); + } + } + + __device__ inline void loadX(); + __device__ inline void loadV(); + __device__ inline void compute(); + + __device__ inline uint32_t iterToTile(uint32_t iter) const { + return idxTileBeg() + iterStride() * (iter / 2) + iter % 2; + } + + __device__ inline SharedMemA& getProducerShm(uint32_t idxProducer) const { + return mapa(reinterpret_cast(smem), idxProducer); + } + + using WarpOutputTile = + Array2D; + __device__ inline WarpOutputTile finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum, + float xvScale, uint32_t lane = laneId()); + __device__ inline void storeOutput(Vec& dst, uint32_t dstBaseCol, + WarpOutputTile const& regTile, WarpOutSwizzleBuf& swizzleBuf, + uint32_t lane = laneId()); +}; + +__device__ inline void Consumer::compute() { + uint2 const tileIdx = {warpIdx.y, warpIdx.x}; + uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y}; + + constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, qmmaShape.k); + constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, qmmaShape.n * 2); + + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / 16; + uint32_t const laneInHalf = lane % 16; + uint32_t const rA = laneInHalf; + uint32_t const cA = idxHalf; + uint32_t const rB = lane; + uint32_t const cB = 0; + + WarpAcc acc{}; + uint32_t idxXVBufLast{}; + for (uint32_t iter = 0; true; iter++) { + uint32_t const idxTile = iterToTile(iter); + if (idxTile >= nbTiles()) { + break; + } + + 
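+    // Per-tile update in flash-attention style (sketch of what the code below does): if the new
+    // tile raises a row's max, rescale the accumulator and the running row sum by
+    // exp2(oldMax - newMax); then accumulate acc += X * V and rowSum += xRowSum. The single
+    // division by the final rowSum is deferred to finalize().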
ThrdRegRowMax accRowMaxLog2e = + loadShmRowMax(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane); + ThrdRegRowMax accRowSum = + loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane); + + uint32_t const idxXBuf = iter % SharedMemB::nbXBufs; + uint32_t const idxVBuf = iter % SharedMemB::nbVBufs; + auto& xBar = smem.xBars[idxXBuf]; + auto& vBar = smem.vBars[idxVBuf]; + // @fixme: merge these two barriers and use test_wait_parity() early to avoid latency. + bool const skipVBarWait = vBar.produced.test_wait_parity(toParity(iter)); + xBar.produced.wait_parity(toParity(iter)); + + ThrdRegRowMax const xRowMaxLog2e = + loadShmRowMax(smem.xRowMaxLog2e(idxXBuf), tileBase.y, lane); + assert(all(accRowMaxLog2e <= xRowMaxLog2e)); + + auto const needRescaleVec = (xRowMaxLog2e > accRowMaxLog2e); + UniformNeedRescaleMask rescaleMask{}; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + } + bool const anyNeedRescale = any(rescaleMask != UniformNeedRescaleMask::filled(0)); + if (anyNeedRescale) { + auto const scaleVec = exp2f(accRowMaxLog2e - xRowMaxLog2e); +#pragma unroll + for (uint32_t m = 0; m < WarpAcc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + uint8_t const mask = + reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; + bool const needRescale = (mask != 0); + if (needRescale) { // this branch is warp-uniform + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); +#pragma unroll + for (uint32_t n = 0; n < WarpAcc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + acc(m, n)(i, j) *= scale; + } + } + } + } + } + accRowSum = accRowSum * scaleVec; + } + accRowMaxLog2e = xRowMaxLog2e; + storeRowMax(smem.accRowMaxLog2e[tileIdx.x], accRowMaxLog2e, tileBase.y, lane); + if (!skipVBarWait) { + vBar.produced.wait_parity(toParity(iter)); + } + auto const& xBuf = smem.x(idxXBuf); + auto const& vBuf = smem.v(idxVBuf)[tileIdx.x]; + auto const xRowSum = loadShmRowMax(smem.xRowSum(idxXBuf), tileBase.y, lane); + accRowSum = accRowSum + xRowSum; + storeRowMax(smem.accRowSum[tileIdx.x], accRowSum, tileBase.y, lane); + +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < tileNbInstK; idxInstK++) { + Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA); + Vec const x = loaderX.loadWholeCol(); + using AtomB = Vec; +#pragma unroll + for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < warpTileNbAtomBx2; idxAtomBx2++) { + auto const data = ldmatrix_16x16_trans<2>( + &vBuf.template at(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB)); + AtomB const v[2] = {data[0], data[2], data[1], data[3]}; +#pragma unroll + for (uint32_t i = 0; i < WarpAcc::rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#if 1 + mma<__nv_fp8_e4m3>( +#else + mmaF8_k32_2inst( +#endif + reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), + reinterpret_cast(x[i]), + reinterpret_cast(v[j])); + } + } + } + } + bool const isLastIter = (iterToTile(iter + 1) >= nbTiles()); + if (isLastIter) { + idxXVBufLast = idxXBuf; + assert(idxXBuf == idxVBuf); + } else { + xBar.consumed.arrive(); + vBar.consumed.arrive(); + } + } + + smem.mathWarpsBar.arrive(); + + ThrdRegRowMax const accRowSum = + loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane); + float const xvScale = computeRowSumFromF8 ? 
args.kvCacheScale[0] : args.kvCacheScale[0] * xScale; + WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane); + + bool const isMultiBlockMode = (nbSubSeq != 1); + static_assert(PartialResult::nbRowsPerChunk == warpTile.y); + auto& dst = isMultiBlockMode ? args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq] + .chunks[tileIdx.y] + .data + : reinterpret_cast&>( + args.output[headGrpSize * idxInputTokenGlobal + tileBase.y]); + + assert(warpRank < nbMathWarps); + WarpOutSwizzleBuf& swizzleBuf = reinterpret_cast&>( + smem.xv[idxXVBufLast])[warpRank]; + // make sure all math warps have finished using XVBuffer. + smem.mathWarpsBar.wait_parity(false); + + storeOutput(dst, gemm1V * idxConsumer() + tileBase.x, output, swizzleBuf, lane); + if (isMultiBlockMode && tileIdx.x == 0) { + ThrdRegRowMax const accRowMaxLog2e = + loadShmRowMax(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane); + auto& chunk = + args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y]; +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + chunk.rowMaxLog2e[warp_size * i + lane] = accRowMaxLog2e[i]; + chunk.rowSum[warp_size * i + lane] = accRowSum[i]; + } + } + smem.xBars[idxXVBufLast].consumed.arrive(); + smem.vBars[idxXVBufLast].consumed.arrive(); +} + +__device__ inline void Consumer::loadX() { +#pragma unroll 1 + for (uint32_t iter = 0; true; iter++) { + uint32_t const idxTile = iterToTile(iter); + if (idxTile >= nbTiles()) { + break; + } + // @todo: merge these two barriers. + uint32_t const idxScratchXBuf = iter % nbProducerCtasPerCga; + auto& srcProducedBar = smem.cgaXBufProduced[idxScratchXBuf]; + srcProducedBar.wait_parity(toParity(iter)); + uint32_t const idxXBuf = iter % SharedMemB::nbXBufs; + auto& xBar = smem.xBars[idxXBuf]; + xBar.consumed.wait_parity(toParity(iter)); + if (warpElectSync()) { + auto& src = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf]; + auto& dst = smem.xv[idxXBuf].x; + tma::loadLinearAsync(&dst, &src.x, sizeof(CgaXBuffer), xBar.produced); + xBar.produced.arrive_tx(sizeof(CgaXBuffer)); + xBar.produced.wait_parity(toParity(iter)); + uint32_t const idxProducer = idxScratchXBuf; + // @fixme: check if this works. If it doesn't, randomly pick some data from dstX and dstRowSum + // and use STAS + arrive_tx to avoid fence. 
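+      // xBar.produced.wait_parity above guarantees the bulk copy into smem.xv[idxXBuf].x has
+      // landed, so the producer CTA can be told below that its L2 cgaXBuf slot is reusable.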
+ getProducerShm(idxProducer).cgaXBufConsumed.arrive(); + } + } +} + +__device__ inline void Consumer::loadV() { + KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV +#if USE_PAGED_KV_CACHE + , + divUp(seqLen, tokensPerPage) +#endif + ); + for (uint32_t iter = 0; true; iter++) { + uint32_t const idxTile = iterToTile(iter); + if (idxTile >= nbTiles()) { + break; + } + uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers; + loader.loadPages(idxTile, idxPageBuf); + uint32_t const idxVBuf = iter % SharedMemB::nbVBufs; + auto& vBar = smem.vBars[idxVBuf]; + vBar.consumed.wait_parity(toParity(iter)); +#pragma unroll + for (uint32_t idxPart = 0; idxPart < SharedMemB::VBuffer::size; idxPart++) { + loader.loadData( + smem.v(idxVBuf)[idxPart], idxTile, + gemm1V * idxConsumer() + exactDiv(gemm1V, SharedMemB::VBuffer::size) * idxPart, + vBar.produced, idxPageBuf); + } + if (warpElectSync()) { + vBar.produced.arrive_tx(sizeof(SharedMemB::VBuffer)); + } + } +} + +__device__ inline Array2D +Consumer::finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float const xvScale, + uint32_t const lane) { + ThrdRegRowMax const scaleVec = 1.F / (accRowSum)*xvScale; + WarpOutputTile ret; +#pragma unroll + for (uint32_t m = 0; m < WarpAcc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + uint32_t retRow = m * InstAcc::rows + i; + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); +#pragma unroll + for (uint32_t n = 0; n < WarpAcc::cols; n++) { + float data[InstAcc::cols]; +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + data[j] = acc(m, n)(i, j) * scale; + } + assert(InstAcc::cols == 2); + reinterpret_cast<__nv_bfloat162&>(ret(retRow, n)) = + __float22bfloat162_rn(float2{data[0], data[1]}); + } + } + } + return ret; +} + +__device__ inline void Consumer::storeOutput(Vec& dst, uint32_t dstBaseCol, + WarpOutputTile const& src, + WarpOutSwizzleBuf& swizzleBuf, uint32_t lane) { + using Dst = mha::decay_t; + static_assert(Dst::size == WarpOutputTile::rows * 8 && Dst::size % WarpOutSwizzleBuf::rows == 0); + uint32_t const nbIters = exactDiv(Dst::size, WarpOutSwizzleBuf::rows); + + uint32_t const rS = lane % 8; + uint32_t const cS = lane / 8; + + uint32_t const thrdsPerRow = + exactDiv(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols, grainBytes); + static_assert(thrdsPerRow <= 32); + uint32_t const rL = lane / thrdsPerRow; + uint32_t const cL = lane % thrdsPerRow; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { +#pragma unroll + for (uint32_t j = 0; j < WarpOutputTile::cols; j += 4) { + auto const baseSwzPtr = &swizzleBuf.template at(rS, j + cS); + constexpr uint32_t srcRowsPerIter = exactDiv(WarpOutputTile::rows, nbIters); +#pragma unroll + for (uint32_t i = 0; i < srcRowsPerIter; i++) { + static_assert(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols * 8 % 1024 == 0); + auto const swzPtr = checkedVal(baseSwzPtr + WarpOutputTile::cols * 8 * i, + &swizzleBuf.template at(8 * i + rS, j + cS)); + stmatrix( + swzPtr, reinterpret_cast const&>(src(srcRowsPerIter * iter + i, j))); + } + } + __syncwarp(); + + uint32_t const dstRowsPerIter = WarpOutSwizzleBuf::rows; + uint32_t const rowsPerOp = exactDiv(warp_size, thrdsPerRow); + LdGrain* const baseDstPtr = reinterpret_cast( + &dst[dstRowsPerIter * iter + rL] + [dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]); +#pragma unroll + for (uint32_t i = 0; i < dstRowsPerIter; i += rowsPerOp) { + LdGrain* const dstPtr = 
+ checkedVal(baseDstPtr + i * exactDiv(sizeof(OutputHead), grainBytes), + reinterpret_cast( + &dst[dstRowsPerIter * iter + i + rL] + [dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL])); + LdGrain* const srcPtr = &swizzleBuf.template at(i + rL, cL); + *dstPtr = *srcPtr; + } + __syncwarp(); + } +} + +__device__ inline void mergePartialOutputs(uint32_t& semaphore, + Vec& dst, + PartialResult const* reqPartialResults, + uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank, + uint2 warpIdx, void* sharedMem) { + assert(nbSubSeq > 1); + clusterBarArrive(); + clusterBarWait(); + bool const isProducer = (ctaRank < nbProducerCtasPerCga); + + bool& shmIsLastSubSeq = isProducer ? static_cast(sharedMem)->isLastSubSeq + : static_cast(sharedMem)->isLastSubSeq; + + if (ctaRank == 3 && threadIdx.x == 0) { + uint32_t old; + uint32_t const lastOld = nbSubSeq - 1; + asm volatile("atom.relaxed.gpu.global.inc.u32 %0, [%1], %2;\n" + : "=r"(old) + : "l"(&semaphore), "r"(lastOld)); + bool const isLastSubSeq = (old == lastOld); +#pragma unroll + for (uint32_t i = 0; i < nbProducerCtasPerCga; i++) { + static_cast(mapa(sharedMem, i))->isLastSubSeq = isLastSubSeq; + } + mapa(shmIsLastSubSeq, 2) = isLastSubSeq; + shmIsLastSubSeq = isLastSubSeq; + } + clusterBarArrive(); + clusterBarWait(); + bool const isLastCga = shmIsLastSubSeq; + if (!isLastCga) { + return; + } + + CtaBarrierPair(&bars)[nbMultiBlockBufs] = + isProducer ? static_cast(sharedMem)->multiBlockBars + : static_cast(sharedMem)->multiBlockBars; + Vec& shmBufs = + isProducer ? static_cast(sharedMem)->getMultiBlockBufs() + : static_cast(sharedMem)->getMultiBlockBufs(); + + constexpr uint32_t nbShmBufs = nbMultiBlockBufs; + + if (warpIdx.y == 2) { + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); + if (warpIdx.x == 0) { +#pragma unroll 1 + for (uint32_t idxSubSeq = 0; idxSubSeq < nbSubSeq; idxSubSeq++) { + uint32_t const idxBuf = idxSubSeq % nbShmBufs; + auto& bar = bars[idxBuf]; + bar.consumed.wait_parity(toParity(idxSubSeq)); + if (warpElectSync()) { + tma::loadLinearAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank], + sizeof(PartialResult::Chunk), bar.produced); + bar.produced.arrive_tx(sizeof(PartialResult::Chunk)); + } + } + } + } else { + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); + constexpr uint32_t nbMathWarps = 8; + constexpr uint32_t rowsPerWarp = exactDiv(PartialResult::nbRowsPerChunk, nbMathWarps); + constexpr uint32_t regGrainsPerRow = exactDiv(sizeof(OutputHead), grainBytes * warp_size); + constexpr uint32_t grainOutElems = exactDiv(grainBytes, sizeof(OutputElem)); + uint32_t const lane = laneId(); + + uint32_t const tileRowBase = rowsPerWarp * warpRank; + using RowWise = Vec; + using RegChunk = Array2D, rowsPerWarp, regGrainsPerRow>; + auto loadBuf = [&](RowWise& rowMaxLog2e, RowWise& rowSum, RegChunk& regChunk, + PartialResult::Chunk const& chunk) { + auto loadRowWise = [&](Vec const& src) { + return reinterpret_cast(src[tileRowBase]); + }; + rowMaxLog2e = loadRowWise(chunk.rowMaxLog2e); + rowSum = loadRowWise(chunk.rowSum); + regChunk; +#pragma unroll + for (uint32_t i = 0; i < rowsPerWarp; i++) { +#pragma unroll + for (uint32_t j = 0; j < regGrainsPerRow; j++) { + regChunk(i, j) = reinterpret_cast const&>( + chunk.data[tileRowBase + i][grainOutElems * (warp_size * j + lane)]); + } + } + }; + + uint32_t const idxSubSeqInit = 0; + uint32_t const idxBufInit = idxSubSeqInit % nbShmBufs; + 
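+    // Merge rule for the partial results (sketch): with per-row maxima m_i (log2 domain) and row
+    // sums s_i, the loop below maintains m = max_i(m_i) and accumulates
+    //   acc    = sum_i data_i * s_i * exp2(m_i - m)
+    //   rowSum = sum_i s_i * exp2(m_i - m)
+    // and the epilogue writes acc / rowSum to the output head.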
bars[idxBufInit].produced.wait_parity(toParity(idxSubSeqInit)); + RowWise accRowMaxLog2e; + RowWise accRowSum; + RegChunk chunk; + loadBuf(accRowMaxLog2e, accRowSum, chunk, shmBufs[idxBufInit]); + bars[idxBufInit].consumed.arrive(); + + using Acc = Array2D, rowsPerWarp, regGrainsPerRow>; + Acc acc; +#pragma unroll + for (uint32_t i = 0; i < rowsPerWarp; i++) { +#pragma unroll + for (uint32_t j = 0; j < regGrainsPerRow; j++) { + acc(i, j) = convert(chunk(i, j)) * accRowSum[i]; + } + } + +#pragma unroll 1 + for (uint32_t idxSubSeq = idxSubSeqInit + 1; idxSubSeq < nbSubSeq; idxSubSeq++) { + uint32_t const idxBuf = idxSubSeq % nbShmBufs; + auto& bar = bars[idxBuf]; + bar.produced.wait_parity(toParity(idxSubSeq)); + RowWise chunkRowMaxLog2e; + RowWise chunkRowSum; + loadBuf(chunkRowMaxLog2e, chunkRowSum, chunk, shmBufs[idxBuf]); + bar.consumed.arrive(); +#pragma unroll + for (uint32_t i = 0; i < rowsPerWarp; i++) { + bool const newChunkGreater = (chunkRowMaxLog2e[i] > accRowMaxLog2e[i]); + if (newChunkGreater) { + float const scale = exp2f(accRowMaxLog2e[i] - chunkRowMaxLog2e[i]); +#pragma unroll + for (uint32_t j = 0; j < regGrainsPerRow; j++) { + acc(i, j) = acc(i, j) * scale + convert(chunk(i, j)) * chunkRowSum[i]; + } + accRowSum[i] = accRowSum[i] * scale + chunkRowSum[i]; + accRowMaxLog2e[i] = chunkRowMaxLog2e[i]; + } else { + float const scale = exp2f(chunkRowMaxLog2e[i] - accRowMaxLog2e[i]); + float const fusedScale = scale * chunkRowSum[i]; +#pragma unroll + for (uint32_t j = 0; j < regGrainsPerRow; j++) { + acc(i, j) = acc(i, j) + convert(chunk(i, j)) * fusedScale; + } + accRowSum[i] = accRowSum[i] + chunkRowSum[i] * scale; + } + } + } + +#pragma unroll + for (uint32_t i = 0; i < rowsPerWarp; i++) { + float const scale = 1.F / accRowSum[i]; + auto const dstHead = reinterpret_cast*>(&dst[tileRowBase + i]); +#pragma unroll + for (uint32_t j = 0; j < regGrainsPerRow; j++) { + dstHead[warp_size * j + lane] = convert(acc(i, j) * scale); + } + } + } +} + +inline constexpr uint32_t cgaSize = nbProducerCtasPerCga + nbVSplit; + +CUBIN_EXPORT __global__ +__launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha( + __grid_constant__ CUtensorMap const tensorMapQ, // MhaIOHead[nbQHeads * totalNbInputTokens], + __grid_constant__ CUtensorMap const tensorMapK, // with box=64 for the least significant dim + __grid_constant__ CUtensorMap const tensorMapV, // with box=128 for the least significant dim + float const qScale, + OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads] + KVCacheList const cacheList, uint32_t const batchSize, + float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V + // cache. Used only for int8/fp8 KV cache. + Vec* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq] + uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens] + PartialResult* __restrict__ const partialResults = + nullptr) // [totalNbInputTokens][maxNbSubSeq] +{ + assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1); + extern __shared__ char smemBuf[]; + uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size); + uint2 const warpIdx = {warpRank % 4, warpRank / 4}; + + uint3 const& cgaId = clusterId(); + uint32_t const& idxReq = cgaId.z; + uint32_t const& maxNbSubSeq = nbClusters().y; + uint32_t const& idxSubSeq = cgaId.y; + uint32_t const inputSeqLen = + (allowMultipleInputTokens ? 
exactDiv(gridDim.x, cgaSize) + : checkedVal(1U, exactDiv(gridDim.x, cgaSize))); + uint32_t const reqIdxInputToken = + (allowMultipleInputTokens ? blockIdx.x / cgaSize : checkedVal(0U, blockIdx.x / cgaSize)); + uint32_t const idxInputTokenGlobal = inputSeqLen * idxReq + reqIdxInputToken; + uint32_t const cacheSeqLen = cacheList.seqLenList[idxReq] - (inputSeqLen - 1) + reqIdxInputToken; + assert(beamWidth == 1); + uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tokensPerTile) : 0; + bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTiles >= multiBlockMinNbTiles); + uint32_t const nbSubSeq = + isMultiBlockMode ? mha::min(nbTiles / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; + static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); + assert(isMultiBlockMode == (nbSubSeq > 1)); + if (idxSubSeq >= nbSubSeq) { + return; + } + + uint32_t const ctaRank = clusterCtaRank(); + bool const isProducer = (ctaRank < nbProducerCtasPerCga); + + KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, + batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults}; + + if (isProducer) { + Producer{args, + *reinterpret_cast(smemBuf), + maxNbSubSeq, + idxReq, + idxInputTokenGlobal, + cacheSeqLen, + nbSubSeq, + idxSubSeq, + ctaRank, + warpRank, + warpIdx} + .run(); + } else { + Consumer{args, + *reinterpret_cast(smemBuf), + maxNbSubSeq, + idxReq, + idxInputTokenGlobal, + cacheSeqLen, + nbSubSeq, + idxSubSeq, + ctaRank, + warpRank, + warpIdx} + .run(); + } +} + +__constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB)); +static_assert(smemSize <= 99 * 1024, "Shared memory size exceeded"); +#endif // is_MLA + +#ifndef GENERATE_CUBIN +#if IS_MLA +CUtensorMap makeTensorMapForQ(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems) { + CUtensorMap tensorMap{}; + uint64_t const globalDims[] = {headElems, totalNbHeads}; + uint32_t elemBytes = getElemBytes(dataType); + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes}; + uint32_t const boxDims[] = {partElems, headGrpSize}; + uint32_t const elemStrides[] = {1, 1}; + auto const swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 2, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} +#endif // IS_MLA + +void launchMLA( + cudaDeviceProp const& prop, + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + float qScale, OutputHead* output, InputHead const* q, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout + GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or + // [batchSize][maxNbPagesPerSeq] (Layout 1) +#else + GMemKVCacheHead* kvCacheData, +#endif + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. 
+ uint32_t* semaphores, void* scratch, cudaStream_t stream) { +#if IS_MLA + static_assert( + SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, + "not implemented"); + if (beamWidth != 1) { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() { + // printf("smemSize = %u\n", smemSize); + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbKHeads = 1; + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) { + int32_t const val = std::stoi(env); + if (val > 0) { + return val; + } + } + float const factor = 4.f; + return mha::min( + mha::max( + 1U, (uint32_t)round(prop.multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, tokensPerTile * 2)); + }(); + // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq); + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == + // nbInputSeqSplit + dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * 4 * 3, 1, 1}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; +#endif + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + + auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead, + headGrpSize * inputSeqLen * batchSize, partElemsK); +#if PAGED_KV_CACHE_LAYOUT == 1 + auto const tensorMapK = makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile); + auto const tensorMapV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile); +#else + auto const tensorMapK = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, + tokensPerPage, partElemsK, tokensPerTile); + auto const tensorMapV = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, + tokensPerPage, partElemsV, tokensPerTile); +#endif + + uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; + auto const cgaXBuf = static_cast*>(scratch); + auto const partialResults = reinterpret_cast(cgaXBuf + nbCgas); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, + tensorMapV, qScale, output, cacheList, batchSize, + kvCacheScale, cgaXBuf, semaphores, partialResults); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + 
kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, gemm0CtaTileNbTokens); + cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, tensorMap, semaphores, scratch); +#endif + checkCuda(err); +#endif +} + +void launchMLAFlashInfer( + uint32_t multiProcessorCount, + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + float qScale, OutputHead* output, InputHead const* q, +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout + GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or + // [batchSize][maxNbPagesPerSeq] (Layout 1) + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. + uint32_t* semaphores, void* scratch, cudaStream_t stream) { +#if IS_MLA + static_assert( + SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, + "not implemented"); + if (beamWidth != 1) { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() { + // printf("smemSize = %u\n", smemSize); + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbKHeads = 1; + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + float const factor = 4.f; + return mha::min( + mha::max( + 1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, tokensPerTile * 2)); + }(); + // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq); + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == + // nbInputSeqSplit + dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * 4 * 3, 1, 1}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; +#endif + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + + auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead, + headGrpSize * inputSeqLen * batchSize, partElemsK); +#if PAGED_KV_CACHE_LAYOUT == 1 + auto const tensorMapK = 
makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile); + auto const tensorMapV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile); +#else + auto const tensorMapK = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, + tokensPerPage, partElemsK, tokensPerTile); + auto const tensorMapV = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, + tokensPerPage, partElemsV, tokensPerTile); +#endif + + uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; + auto const cgaXBuf = static_cast*>(scratch); + auto const partialResults = reinterpret_cast(cgaXBuf + nbCgas); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, + tensorMapV, qScale, output, cacheList, batchSize, + kvCacheScale, cgaXBuf, semaphores, partialResults); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, gemm0CtaTileNbTokens); + cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, tensorMap, semaphores, scratch); +#endif + checkCuda(err); +#endif +} +#endif diff --git a/csrc/xqa/mla_sm120.cuh b/csrc/xqa/mla_sm120.cuh new file mode 100644 index 0000000000..f4349acab0 --- /dev/null +++ b/csrc/xqa/mla_sm120.cuh @@ -0,0 +1,76 @@ +#pragma once +#include "mha_components.cuh" +#include "mha_stdheaders.cuh" +#include "tma.h" +#include "utils.cuh" + +template +__device__ inline ThrdRegRowMaxT loadShmRowMax(Vec const& shm, + uint32_t tileBaseRow, + uint32_t lane = laneId()) { + ThrdRegRowMaxT result{}; +#pragma unroll + for (uint32_t i = 0; i < result.size; i++) { + result[i] = shm[tileBaseRow + i * warp_size + lane]; + } + return result; +} + +template +__device__ inline void storeRowMax(Vec& shm, ThrdRegRowMaxT const& src, + uint32_t tileBaseRow, uint32_t lane = laneId()) { +#pragma unroll + for (uint32_t i = 0; i < src.size; i++) { + shm[tileBaseRow + i * warp_size + lane] = src[i]; + } +} + +template +__device__ inline void storeRowMaxAsync(CgaBarrier& bar, Vec& shm, + ThrdRegRowMaxT const& src, uint32_t tileBaseRow, + uint32_t lane = laneId()) { +#pragma unroll + for (uint32_t i = 0; i < src.size; i++) { + tma::storeAsync(&shm[tileBaseRow + i * warp_size + lane], src[i], bar); + } +} + +template +__device__ inline QuadRegRowMaxT computeRowMax(WarpAccT const& acc) { + QuadRegRowMaxT rowMaxLog2e{}; +// compute per-thread row max +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + float& dst = rowMaxLog2e[m * InstAcc::rows + i]; + dst = ((n == 0 && j == 0) ? 
acc(m, n)(i, j) : fmaxf(dst, acc(m, n)(i, j))); + } + } + } + } +// compute warp row max +#pragma unroll + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { +#pragma unroll + for (uint32_t i = 0; i < rowMaxLog2e.size; i++) { + rowMaxLog2e[i] = fmaxf(rowMaxLog2e[i], __shfl_xor_sync(~0U, rowMaxLog2e[i], xorMask)); + } + } + return rowMaxLog2e; +} + +template +__device__ inline uint32_t hashRegData(Vec const& data) { + static_assert(sizeof(T) == 4); + uint32_t result = 0; +#pragma unroll + for (uint32_t i = 0; i < n; i++) { + result ^= reinterpret_cast(data[i]); + } + return result; +} diff --git a/csrc/xqa/tensorMap.cpp b/csrc/xqa/tensorMap.cpp new file mode 100644 index 0000000000..e79272b018 --- /dev/null +++ b/csrc/xqa/tensorMap.cpp @@ -0,0 +1,117 @@ +#include "tensorMap.h" + +#include +#include + +#include + +#include "utils.h" + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType) { + switch (dataType) { + case CU_TENSOR_MAP_DATA_TYPE_UINT8: + return 1; + case CU_TENSOR_MAP_DATA_TYPE_UINT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_UINT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_INT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_UINT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_INT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: + return 4; + default: + throw std::runtime_error("unsupported data type"); + } +} + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens) { + CUtensorMap tensorMap{}; + uint64_t const globalDims[] = {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; + uint32_t elemBytes = getElemBytes(dataType); + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, + headBytes * maxCacheLen * nbKHeads}; + uint32_t const boxDims[] = {partElems, nbTokens, 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (partElems) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile) { + CUtensorMap tensorMap{}; + uint32_t elemBytes = getElemBytes(dataType); +// VLLM Layout +#if PAGED_KV_CACHE_LAYOUT == 1 + uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, + headBytes * nbKHeads * tokensPerPage}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, 1, 
mha::min(tokensPerPage, nbTokensPerTile), 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + // XQA Original Layout +#else + uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, + headBytes * tokensPerPage * nbKHeads}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; +#endif + + auto const swizzle = [&] { + switch (partBytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} diff --git a/csrc/xqa/tensorMap.h b/csrc/xqa/tensorMap.h new file mode 100644 index 0000000000..d0b2c76b96 --- /dev/null +++ b/csrc/xqa/tensorMap.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType); + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens); + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile); diff --git a/csrc/xqa/tma.h b/csrc/xqa/tma.h new file mode 100644 index 0000000000..5cf67238a2 --- /dev/null +++ b/csrc/xqa/tma.h @@ -0,0 +1,302 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include "cuda_hint.cuh" +#include "utils.h" +#ifndef GENERATE_CUBIN +#include +#include + +#include +#endif +#include "barriers.cuh" + +enum class StateSpace { kCONSTANT, kPARAMETER, kGENERIC }; + +#ifdef GENERATE_CUBIN +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +typedef struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; +#endif + +namespace tma { + +__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, + CtaBarrier& bar) { + asm volatile( + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) { + asm volatile( + "cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), + "r"(nbBytes) + : "memory"); +} + +// dsr and &bar must be remote address generated by mapa and src must be local address +__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes, + CgaBarrier& bar) { + asm volatile( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, " + "[%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +template +__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, + CtaBarrier& bar) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2}], [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3}], [%4];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4}], " + "[%5];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 4) { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4, " + "%5}], [%6];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4, %5, " + "%6}], [%7];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else { + 
static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+template <uint32_t nbDims>
+__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE<nbDims> offset,
+                                 CtaBarrier& bar, uint64_t cacheHint) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need tensormap and should just use cp.async.bulk
+    asm volatile(
+        "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2}], [%3], %4;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)),
+          "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3}], [%4], %5;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4}], [%5], %6;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)),
+          "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4, %5}], [%6], %7;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]),
+          "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4, %5, %6}], [%7], %8;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]),
+          "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else {
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+// shared::cta -> global
+__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes) {
+  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
+               :
+               : "l"(reinterpret_cast<uint64_t>(dst)), "l"(__cvta_generic_to_shared(src)),
+                 "r"(nbBytes));
+}
+
+template <uint32_t nbDims>
+__device__ inline void storeAsync(CUtensorMap const& tensorMap, DimsLE<nbDims> const& offset,
+                                  void* src) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need tensormap and should just use cp.async.bulk
+    asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n"
+                 :
+                 : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+                   "l"(__cvta_generic_to_shared(src))
+                 : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
+        "[%6];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else {
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) {
+  asm volatile(
+      "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap),
+      "l"(ptr)
+      : "memory");
+}
+
+__device__ inline void commitGroup() {
+  asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
+}
+
+// wait until only targetNbInFlightGroups groups are still in-flight.
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void waitGroup() {
+  asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory");
+}
+
+__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap,
+                                         StateSpace loc = StateSpace::kGENERIC) {
+  assert(reinterpret_cast<uintptr_t>(&tensorMap) % alignof(CUtensorMap) == 0);
+  switch (loc) {
+    case StateSpace::kCONSTANT:
+      asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap))
+                   : "memory");
+      break;
+    case StateSpace::kPARAMETER:
+      asm volatile(
+          "prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap))
+          : "memory");
+      break;
+    case StateSpace::kGENERIC:
+      asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast<uint64_t>(&tensorMap))
+                   : "memory");
+      break;
+    default:
+      asm volatile("trap;\n");
+  }
+}
+
+template <typename T>
+__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar) {
+  constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t));
+  Vec<uint32_t, nbWords> const& srcVec = reinterpret_cast<Vec<uint32_t, nbWords> const&>(src);
+  if constexpr (nbWords == 1) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"(
+            __cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbWords == 2) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, "
+        "[%3];\n" ::"l"(__cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbWords == 4) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, "
+        "[%5];\n" ::"l"(__cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]),
+        "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else {
+    static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4,
+                  "src size must be 4, 8 or 16 bytes");
+  }
+}
+
+}  // namespace tma
diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh
index 5883e5b834..9d0dbd8afe 100644
--- a/csrc/xqa/utils.cuh
+++ b/csrc/xqa/utils.cuh
@@ -31,7 +31,13 @@
 #include "barriers.cuh"
 
 inline constexpr float log2e = 1.4426950408889634;  // std::log2(M_E)
-inline constexpr float safeInitRowMax = -1e+30F;
+// we use an optimization where exp(x-rowMax) is computed as:
+/* bias = rowMax * log2e // shared for the
whole row + exp(x-rowMax) = exp2f(x * log2e - bias) +*/ +// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and +// x is too large. For this reason, don't set safeInitRowMax with a huge absolute value. +inline constexpr float safeInitRowMax = -1e+5F; inline constexpr int32_t kBAD_PAGE_INDEX = -1; __constant__ constexpr float kE4M3_MAX = 448.F; @@ -40,7 +46,7 @@ __constant__ constexpr float kE4M3_MAX = 448.F; constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10); #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10); -#elif __CUDA_ARCH__ == 900 +#elif __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 constexpr uint32_t kMAX_SMEM_SIZE = (227u << 10); #endif #endif diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 1a5d636e10..4ca35d6bd2 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -19,12 +19,45 @@ using tvm::ffi::Optional; -void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, TensorView output, +#if MLA_WRAPPER +void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, +#if PAGED_KV_CACHE_LAYOUT == 1 + TensorView kCacheVLLM, TensorView vCacheVLLM, +#else + TensorView pool, +#endif + TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, + int64_t batchSize, TensorView kvCacheScale, TensorView semaphores, + TensorView scratch) { + auto stream = get_stream(output->device); + + launchMLAFlashInfer(multiProcessorCount, 1, qScale, reinterpret_cast(output->data), + reinterpret_cast(q->data), +#if PAGED_KV_CACHE_LAYOUT == 1 + reinterpret_cast(kCacheVLLM->data), + reinterpret_cast(vCacheVLLM->data), +#else + reinterpret_cast(pool->data), +#endif + reinterpret_cast(kvCachePageList->data), maxSeqLen, + reinterpret_cast(seqLen->data), batchSize, + reinterpret_cast(kvCacheScale->data), + reinterpret_cast(semaphores->data), + reinterpret_cast(scratch->data), stream); +} +#else + +void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, + int64_t slidingWinSize, double qScale, TensorView output, #if LOW_PREC_OUTPUT TensorView rcpOutScale, #endif - TensorView q, Optional attentionSinks, TensorView pool, + TensorView q, Optional attentionSinks, +#if PAGED_KV_CACHE_LAYOUT == 1 + TensorView kCacheVLLM, TensorView vCacheVLLM, +#else + TensorView pool, +#endif TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC @@ -35,21 +68,28 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW float const* attentionSinksPtr = attentionSinks.has_value() ? reinterpret_cast(attentionSinks.value()->data) : nullptr; + auto const mha_func = run_sm90_fp8_mha ? 
&launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; - launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output->data), + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, + reinterpret_cast(output->data), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale->data), + reinterpret_cast(rcpOutScale->data), #endif - reinterpret_cast(q->data), attentionSinksPtr, - reinterpret_cast(pool->data), - reinterpret_cast(kvCachePageList->data), maxSeqLen, - reinterpret_cast(seqLen->data), batchSize, - reinterpret_cast(kvCacheScale->data), + reinterpret_cast(q->data), attentionSinksPtr, +#if PAGED_KV_CACHE_LAYOUT == 1 + reinterpret_cast(kCacheVLLM->data), + reinterpret_cast(vCacheVLLM->data), +#else + reinterpret_cast(pool->data), +#endif + reinterpret_cast(kvCachePageList->data), maxSeqLen, + reinterpret_cast(seqLen->data), batchSize, + reinterpret_cast(kvCacheScale->data), #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens->data), - reinterpret_cast(mask->data), + qSeqLen, reinterpret_cast(qCuSeqLens->data), + reinterpret_cast(mask->data), #endif - reinterpret_cast(semaphores->data), - reinterpret_cast(scratch->data), stream); + reinterpret_cast(semaphores->data), reinterpret_cast(scratch->data), + stream); } +#endif diff --git a/docs/api/attention.rst b/docs/api/attention.rst index a468d20e17..bb65664c83 100644 --- a/docs/api/attention.rst +++ b/docs/api/attention.rst @@ -38,6 +38,16 @@ Batch Decoding .. automethod:: __init__ +XQA +--- + +.. currentmodule:: flashinfer.xqa + +.. autosummary:: + :toctree: ../generated + + xqa + flashinfer.prefill ================== diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 4b34a6418b..93813577b7 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -147,3 +147,4 @@ ) from .utils import next_positive_power_of_2 as next_positive_power_of_2 from .xqa import xqa as xqa +from .xqa import xqa_mla as xqa_mla diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 69c78ec00c..33a93bd760 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -67,7 +67,7 @@ from .jit.rope import gen_rope_module from .jit.sampling import gen_sampling_module from .jit.tllm_utils import gen_trtllm_utils_module -from .jit.xqa import gen_xqa_module +from .jit.xqa import gen_xqa_module, gen_xqa_module_mla from .jit.attention import ( gen_batch_attention_module, gen_batch_decode_module, @@ -356,29 +356,47 @@ def gen_attention( def gen_xqa( - use_fp16_: List[bool], + fp16_input_: List[bool], + fp8_kv_cache_: List[bool], token_per_page_: List[int], head_size_: List[int], head_grp_size_: List[int], use_sliding_window_: List[bool], has_sm90: bool, + has_sm100: bool, + has_sm120: bool, + has_sm121: bool, ) -> Iterator[JitSpec]: """Generate XQA modules for various configurations.""" - if not has_sm90: + if not has_sm90 and not has_sm100 and not has_sm120 and not has_sm121: return # XQA requires SM90+ + sm_versions = [] + if has_sm90: + sm_versions.append(90) + if has_sm100: + sm_versions.append(100) + if has_sm120: + sm_versions.append(120) + if has_sm121: + sm_versions.append(121) + for ( - use_fp16, + fp16_input, + fp8_kv_cache, token_per_page, head_size, head_grp_size, use_sliding_window, + sm_version, ) in product( - use_fp16_, + fp16_input_, + fp8_kv_cache_, token_per_page_, head_size_, head_grp_size_, use_sliding_window_, + sm_versions, ): # Skip invalid configurations if head_size % 16 != 0 or head_size > 256 or head_size < 16: @@ -386,14 +404,46 @@ def gen_xqa( if token_per_page not in [16, 32, 64, 
128]: continue + if fp8_kv_cache: + kv_cache_dtype = torch.float8_e4m3fn + elif fp16_input: + kv_cache_dtype = torch.float16 + else: + kv_cache_dtype = torch.bfloat16 yield gen_xqa_module( - use_fp16=use_fp16, - token_per_page=token_per_page, - head_size=head_size, - head_grp_size=head_grp_size, + input_dtype=torch.float16 if fp16_input else torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + page_size=token_per_page, + head_dim=head_size, + head_group_ratio=head_grp_size, use_sliding_window=use_sliding_window, + sm_version=sm_version, ) + if has_sm120: + for token_per_page in token_per_page_: + yield gen_xqa_module_mla( + input_dtype=torch.float8_e4m3fn, + kv_cache_dtype=torch.float8_e4m3fn, + page_size=token_per_page, + head_dim=576, + head_group_ratio=128, + use_sliding_window=False, + sm_version=120, + ) + + if has_sm121: + for token_per_page in token_per_page_: + yield gen_xqa_module_mla( + input_dtype=torch.float8_e4m3fn, + kv_cache_dtype=torch.float8_e4m3fn, + page_size=token_per_page, + head_dim=576, + head_group_ratio=128, + use_sliding_window=False, + sm_version=121, + ) + def gen_all_modules( f16_dtype_: List[torch.dtype], @@ -508,19 +558,24 @@ def gen_all_modules( if add_xqa: # Define XQA configurations to iterate over - xqa_use_fp16_ = [True, False] # fp16 and bf16 + xqa_fp16_input_ = [True, False] # fp16 and bf16 + xqa_fp8_kv_cache_ = [True, False] xqa_token_per_page_ = [16, 32, 64, 128] xqa_head_size_ = [64, 128, 256] xqa_head_grp_size_ = [1, 2, 4, 8] # Different group sizes for MQA/GQA jit_specs += list( gen_xqa( - xqa_use_fp16_, + xqa_fp16_input_, + xqa_fp8_kv_cache_, xqa_token_per_page_, xqa_head_size_, xqa_head_grp_size_, use_sliding_window_, has_sm90, + has_sm100, + has_sm120, + has_sm121, ) ) diff --git a/flashinfer/jit/xqa.py b/flashinfer/jit/xqa.py index c9196c9ec6..2b4da5aa12 100644 --- a/flashinfer/jit/xqa.py +++ b/flashinfer/jit/xqa.py @@ -15,12 +15,23 @@ """ from . 
import env as jit_env -from .core import JitSpec, gen_jit_spec, sm90a_nvcc_flags +import torch +from .utils import filename_safe_dtype_map +from .core import ( + JitSpec, + gen_jit_spec, + sm90a_nvcc_flags, + sm100f_nvcc_flags, + sm120a_nvcc_flags, + sm121a_nvcc_flags, +) xqa_nvcc_flags = [ "-DNDEBUG=1", + "-DUSE_PAGED_KV_CACHE=1", + "-DPAGED_KV_CACHE_LAYOUT=1", "-DBEAM_WIDTH=1", - "-DCACHE_ELEM_ENUM=0", + "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", "-DLOW_PREC_OUTPUT=0", "-DSPEC_DEC=0", @@ -28,48 +39,142 @@ def gen_xqa_module( - use_fp16: bool, - token_per_page: int, - head_size: int, - head_grp_size: int, + input_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + page_size: int, + head_dim: int, + head_group_ratio: int, use_sliding_window: bool, + sm_version: int = 90, ) -> JitSpec: - if use_fp16: - flag_use_fp16 = ["-DINPUT_FP16=1", "-DDTYPE=__half"] + if input_dtype == torch.float16: + flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"] + elif input_dtype == torch.bfloat16: + flag_input_dtype = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] else: - flag_use_fp16 = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] + raise ValueError( + f"Invalid dtype: {input_dtype} for XQA, only float16 and bfloat16 input are supported" + ) + + if kv_cache_dtype == torch.float8_e4m3fn: + flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=2"] + elif kv_cache_dtype == torch.int8: + flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=1"] + else: + flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=0"] - if token_per_page not in [16, 32, 64, 128]: + if page_size not in [16, 32, 64, 128]: raise ValueError( - f"Invalid token_per_page: {token_per_page}, only 16, 32, 64, 128 are supported" + f"Invalid page_size: {page_size}, only 16, 32, 64, 128 are supported" ) - flag_tokens_per_page = [f"-DTOKENS_PER_PAGE={token_per_page}"] + flag_tokens_per_page = [f"-DTOKENS_PER_PAGE={page_size}"] - if head_size % 16 != 0 or head_size > 256 or head_size < 16: + if head_dim % 16 != 0 or head_dim > 256 or head_dim < 16: raise ValueError( - f"Invalid head_size: {head_size}, must be divisible by 16 and in range [16, 256]" + f"Invalid head_dim: {head_dim}, must be divisible by 16 and in range [16, 256]" ) - flag_head_size = [f"-DHEAD_ELEMS={head_size}"] + flag_head_dim = [f"-DHEAD_ELEMS={head_dim}"] - flag_head_grp_size = [f"-DHEAD_GRP_SIZE={head_grp_size}"] + flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"] if use_sliding_window: flag_sliding_window = ["-DSLIDING_WINDOW=1"] else: flag_sliding_window = ["-DSLIDING_WINDOW=0"] + if sm_version == 100: + sm_nvcc_flags = sm100f_nvcc_flags + elif sm_version == 120: + sm_nvcc_flags = sm120a_nvcc_flags + elif sm_version == 121: + sm_nvcc_flags = sm121a_nvcc_flags + else: + sm_nvcc_flags = sm90a_nvcc_flags + + flag_mla_wrapper = ["-DMLA_WRAPPER=0"] + return gen_jit_spec( - f"xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", + f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}", [ jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp", + jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu", + ], + extra_cuda_cflags=xqa_nvcc_flags + + sm_nvcc_flags + + flag_tokens_per_page + 
+ flag_head_dim + + flag_input_dtype + + flag_kv_cache_dtype + + flag_head_group_ratio + + flag_sliding_window + + flag_mla_wrapper, + extra_ldflags=["-lcuda"], # Add CUDA Driver API library + extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"], + ) + + +def gen_xqa_module_mla( + input_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + page_size: int, + head_dim: int, + head_group_ratio: int, + use_sliding_window: bool = False, + sm_version: int = 120, +) -> JitSpec: + assert sm_version == 120 or sm_version == 121, ( + "Only SM 120 and 121 are supported for xqa MLA" + ) + assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA" + assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla" + assert input_dtype == torch.float8_e4m3fn, ( + "Only fp8 input is supported for xqa_module_mla" + ) + assert kv_cache_dtype == torch.float8_e4m3fn, ( + "Only fp8 kv cache is supported for xqa_module_mla" + ) + assert not use_sliding_window, "Sliding window is not supported for xqa_module_mla" + + flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=2"] + + if page_size not in [16, 32, 64, 128]: + raise ValueError( + f"Invalid page_size: {page_size}, only 16, 32, 64, 128 are supported" + ) + flag_tokens_per_page = [f"-DTOKENS_PER_PAGE={page_size}"] + + flag_head_dim = [f"-DHEAD_ELEMS={head_dim}"] + + flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"] + + flag_sliding_window = ["-DSLIDING_WINDOW=0"] + + if sm_version == 120: + sm_nvcc_flags = sm120a_nvcc_flags + elif sm_version == 121: + sm_nvcc_flags = sm121a_nvcc_flags + + flag_mla_wrapper = ["-DMLA_WRAPPER=1"] + + return gen_jit_spec( + f"xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}", + [ + jit_env.FLASHINFER_CSRC_DIR / "xqa/mla_sm120.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp", jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu", ], extra_cuda_cflags=xqa_nvcc_flags - + sm90a_nvcc_flags + + sm_nvcc_flags + flag_tokens_per_page - + flag_head_size - + flag_use_fp16 - + flag_head_grp_size - + flag_sliding_window, + + flag_head_dim + + flag_kv_cache_dtype + + flag_head_group_ratio + + flag_sliding_window + + flag_mla_wrapper, + extra_ldflags=["-lcuda"], # Add CUDA Driver API library + extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"], ) diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 726002741e..ae2de86758 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -17,11 +17,12 @@ import functools from types import SimpleNamespace from typing import Optional - import torch -from .jit.xqa import gen_xqa_module +from .jit.xqa import gen_xqa_module, gen_xqa_module_mla +from .jit.utils import filename_safe_dtype_map from .utils import ( + get_device_sm_count, register_custom_op, register_fake_op, get_compute_capability, @@ -30,74 +31,88 @@ @functools.cache def get_xqa_module( - use_fp16: bool, - token_per_page: int, - head_size: int, - head_grp_size: int, + input_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + page_size: int, + head_dim: int, + head_group_ratio: int, use_sliding_window: bool, + sm_version: int = 90, ): module = gen_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + input_dtype, + kv_cache_dtype, + page_size, + head_dim, + head_group_ratio, + use_sliding_window, + sm_version, ).build_and_load() 
@register_custom_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", - mutates_args=("output", "scratch"), + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}", + mutates_args=("output", "workspace_buffer"), ) def xqa( - multiProcessorCount: int, - nbKHeads: int, - slidingWinSize: int, - qScale: float, + run_sm90_fp8_mha: bool, + sm_count: int, + num_kv_heads: int, + sliding_win_size: int, + q_scale: float, output: torch.Tensor, q: torch.Tensor, - attentionSinks: Optional[torch.Tensor], - pool: torch.Tensor, - kvCachePageList: torch.Tensor, - maxSeqLen: int, - seqLen: torch.Tensor, - batchSize: int, - kvCacheScale: torch.Tensor, + sinks: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + max_seq_len: int, + seq_lens: torch.Tensor, + batch_size: int, + kv_scale: torch.Tensor, semaphores: torch.Tensor, - scratch: torch.Tensor, + workspace_buffer: torch.Tensor, ) -> None: module.xqa_wrapper( - multiProcessorCount, - nbKHeads, - slidingWinSize, - qScale, + run_sm90_fp8_mha, + sm_count, + num_kv_heads, + sliding_win_size, + q_scale, output, q, - attentionSinks, - pool, - kvCachePageList, - maxSeqLen, - seqLen, - batchSize, - kvCacheScale, + sinks, + k_cache, + v_cache, + page_table, + max_seq_len, + seq_lens, + batch_size, + kv_scale, semaphores, - scratch, + workspace_buffer, ) @register_fake_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}" + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}" ) def _fake_xqa( - multiProcessorCount: int, - nbKHeads: int, - slidingWinSize: int, - qScale: float, + run_sm90_fp8_mha: bool, + sm_count: int, + num_kv_heads: int, + sliding_win_size: int, + q_scale: float, output: torch.Tensor, q: torch.Tensor, - attentionSinks: Optional[torch.Tensor], - pool: torch.Tensor, - kvCachePageList: torch.Tensor, - maxSeqLen: int, - seqLen: torch.Tensor, - batchSize: int, - kvCacheScale: torch.Tensor, + sinks: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + max_seq_len: int, + seq_lens: torch.Tensor, + batch_size: int, + kv_scale: torch.Tensor, semaphores: torch.Tensor, - scratch: torch.Tensor, + workspace_buffer: torch.Tensor, ) -> None: pass @@ -107,46 +122,343 @@ def _fake_xqa( def xqa( - use_fp16: bool, - token_per_page: int, - head_size: int, - head_grp_size: int, - use_sliding_window: bool, - sliding_win_size: int, - multiProcessorCount: int, - nbKHeads: int, - qScale: float, - output: torch.Tensor, q: torch.Tensor, - attentionSinks: Optional[torch.Tensor], - pool: torch.Tensor, - kvCachePageList: torch.Tensor, - maxSeqLen: int, - seqLen: torch.Tensor, - batchSize: int, - kvCacheScale: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + seq_lens: torch.Tensor, + output: torch.Tensor, + workspace_buffer: torch.Tensor, semaphores: torch.Tensor, - scratch: torch.Tensor, + num_kv_heads: int, + page_size: 
int, + sinks: Optional[torch.Tensor] = None, + q_scale: float = 1.0, + kv_scale: Optional[torch.Tensor] = None, + sliding_win_size: int = 0, + sm_count: Optional[int] = None, ) -> None: - if get_compute_capability(torch.device(device="cuda"))[0] != 9: - raise RuntimeError("XQA is only supported on SM90 GPUs") + r"""Apply attention with paged KV cache using XQA kernel. + Parameters + ---------- + q : torch.Tensor + Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Data type should be torch.float16 or torch.bfloat16. + Now only beam_width 1 is supported. + k_cache: torch.Tensor + Paged K cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. + Should be the same data type as v_cache. + v_cache: torch.Tensor + Paged V cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. + Should be the same data type as k_cache. + page_table : torch.Tensor + Page table tensor with shape ``batch_size, nb_pages_per_seq``. + Data type should be torch.uint32. + K and V share the same table. + seq_lens : torch.Tensor + Sequence lengths tensor with shape ``[batch_size, beam_width]``. + Data type should be torch.uint32. + output : torch.Tensor + Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Data type should match query tensor. This tensor will be modified in-place. + workspace_buffer : torch.Tensor + Workspace buffer for temporary computations. + Data type should be torch.uint8. + semaphores : torch.Tensor + Semaphore buffer for synchronization. + Data type should be torch.uint32. + num_kv_heads : int + Number of key-value heads in the attention mechanism. + page_size : int + Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128]. + sinks : Optional[torch.Tensor], default=None + Attention sink values with shape ``[num_kv_heads, head_group_ratio]``. + Data type should be torch.float32. + If None, no attention sinks are used. + q_scale : float, default=1.0 + Scale factor for query tensor. + kv_scale : Optional[torch.Tensor], default=None + Scale factor for KV cache with shape ``[1]``. + Data type should be torch.float32. + If None, defaults to 1.0. + sliding_win_size : int, default=0 + Sliding window size for attention. If 0, no sliding window is used. + sm_count : Optional[int], default=None + Number of streaming multiprocessors to use. + If None, will be inferred from the device. 
+ Note + ---- + The function automatically infers several parameters from tensor shapes: + - batch_size from q.shape[0] + - num_q_heads from q.shape[2] + - head_dim from q.shape[-1] + - input_dtype from q.dtype + - kv_cache_dtype from k.dtype + - head_group_ratio from num_q_heads // num_kv_heads + """ + # Handle optional parameters + if sm_count is None: + sm_count = get_device_sm_count(q.device) + + if kv_scale is None: + kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) + + # Infer parameters from tensors + batch_size = q.shape[0] + num_q_heads = q.shape[2] + head_dim = q.shape[-1] + + # Calculate head_group_ratio + head_group_ratio = num_q_heads // num_kv_heads + + # Calculate max_seq_len from page_table and page_size + num_pages_per_seq = page_table.shape[-1] + max_seq_len = num_pages_per_seq * page_size + + # Determine if sliding window is used + use_sliding_window = sliding_win_size > 0 + + assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype" + + if ( + k_cache.dtype == torch.float8_e4m3fn + and get_compute_capability(torch.device(device="cuda"))[0] == 9 + ): + run_sm90_fp8_mha = True + else: + run_sm90_fp8_mha = False + + if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]: + raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs") + sm_version = int( + get_compute_capability(torch.device(device="cuda"))[0] * 10 + + get_compute_capability(torch.device(device="cuda"))[1] + ) + xqa_module = get_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + q.dtype, + k_cache.dtype, + page_size, + head_dim, + head_group_ratio, + use_sliding_window, + sm_version, ) xqa_module.xqa( - multiProcessorCount, - nbKHeads, + run_sm90_fp8_mha, + sm_count, + num_kv_heads, sliding_win_size if use_sliding_window else 0, - qScale, + q_scale, + output, + q, + sinks, + k_cache, + v_cache, + page_table, + max_seq_len, + seq_lens, + batch_size, + kv_scale, + semaphores, + workspace_buffer, + ) + + +@functools.cache +def get_xqa_module_mla( + input_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + page_size: int, + head_dim: int, + head_group_ratio: int, + use_sliding_window: bool = False, + sm_version: int = 120, +): + module = gen_xqa_module_mla( + input_dtype, + kv_cache_dtype, + page_size, + head_dim, + head_group_ratio, + use_sliding_window, + sm_version, + ).build_and_load() + + @register_custom_op( + f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}", + mutates_args=("output", "workspace_buffer"), + ) + def xqa_mla( + sm_count: int, + q_scale: float, + output: torch.Tensor, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + max_seq_len: int, + seq_lens: torch.Tensor, + batch_size: int, + kv_scale: torch.Tensor, + semaphores: torch.Tensor, + workspace_buffer: torch.Tensor, + ) -> None: + module.xqa_wrapper_mla( + sm_count, + q_scale, + output, + q, + k_cache, + v_cache, + page_table, + max_seq_len, + seq_lens, + batch_size, + kv_scale, + semaphores, + workspace_buffer, + ) + + @register_fake_op( + f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_sm_{sm_version}" + ) 
+ def _fake_xqa_mla( + sm_count: int, + q_scale: float, + output: torch.Tensor, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + max_seq_len: int, + seq_lens: torch.Tensor, + batch_size: int, + kv_scale: torch.Tensor, + semaphores: torch.Tensor, + workspace_buffer: torch.Tensor, + ) -> None: + pass + + return SimpleNamespace( + xqa_mla=xqa_mla, + ) + + +def xqa_mla( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + seq_lens: torch.Tensor, + output: torch.Tensor, + workspace_buffer: torch.Tensor, + semaphores: torch.Tensor, + page_size: int, + q_scale: float = 1.0, + kv_scale: Optional[torch.Tensor] = None, + sm_count: Optional[int] = None, +) -> None: + r"""Apply attention with paged KV cache using XQA kernel. + Parameters + ---------- + q : torch.Tensor + Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Data type should be torch.float16 or torch.bfloat16. + Now only beam_width 1 is supported. + k_cache: torch.Tensor + Paged K cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. + Should be the same data type as v_cache. + v_cache: torch.Tensor + Paged V cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. + Should be the same data type as k_cache. + page_table : torch.Tensor + Page table tensor with shape ``batch_size, nb_pages_per_seq``. + Data type should be torch.uint32. + K and V share the same table. + seq_lens : torch.Tensor + Sequence lengths tensor with shape ``[batch_size, beam_width]``. + Data type should be torch.uint32. + output : torch.Tensor + Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Data type should match query tensor. This tensor will be modified in-place. + workspace_buffer : torch.Tensor + Workspace buffer for temporary computations. + Data type should be torch.uint8. + semaphores : torch.Tensor + Semaphore buffer for synchronization. + Data type should be torch.uint32. + page_size : int + Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128]. + q_scale : float, default=1.0 + Scale factor for query tensor. + kv_scale : Optional[torch.Tensor], default=None + Scale factor for KV cache with shape ``[1]``. + Data type should be torch.float32. + If None, defaults to 1.0. + sm_count : Optional[int], default=None + Number of streaming multiprocessors to use. + If None, will be inferred from the device. 
+ Note + ---- + The function automatically infers several parameters from tensor shapes: + - batch_size from q.shape[0] + - head_dim from q.shape[-1] + - input_dtype from q.dtype + - kv_cache_dtype from k.dtype + """ + # Handle optional parameters + if sm_count is None: + sm_count = get_device_sm_count(q.device) + + if kv_scale is None: + kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) + + # Infer parameters from tensors + batch_size = q.shape[0] + head_dim = q.shape[-1] + + # Calculate head_group_ratio + head_group_ratio = 128 + + # Calculate max_seq_len from page_table and page_size + num_pages_per_seq = page_table.shape[-1] + max_seq_len = num_pages_per_seq * page_size + + assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype" + + if get_compute_capability(torch.device(device="cuda"))[0] not in [12]: + raise RuntimeError("XQA is only supported on SM120 GPUs") + sm_version = int( + get_compute_capability(torch.device(device="cuda"))[0] * 10 + + get_compute_capability(torch.device(device="cuda"))[1] + ) + + xqa_module = get_xqa_module_mla( + q.dtype, + k_cache.dtype, + page_size, + head_dim, + head_group_ratio, + False, + sm_version, + ) + xqa_module.xqa_mla( + sm_count, + q_scale, output, q, - attentionSinks, - pool, - kvCachePageList, - maxSeqLen, - seqLen, - batchSize, - kvCacheScale, + k_cache, + v_cache, + page_table, + max_seq_len, + seq_lens, + batch_size, + kv_scale, semaphores, - scratch, + workspace_buffer, ) diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 2bdbb9e579..a99849b61c 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -4,7 +4,7 @@ import pytest import torch -from flashinfer import xqa +from flashinfer import xqa, xqa_mla from flashinfer.utils import get_compute_capability @@ -49,10 +49,11 @@ def __init__( def __getitem__(self, i: int) -> torch.Tensor: page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32) + # VLLM layout (PAGED_KV_CACHE_LAYOUT=1): [page_idx][token_in_page][nb_heads][head_dim] idx_head = ( - self.tokens_per_page * self.nb_heads * page_idx - + self.tokens_per_page * self.idx_head - + i % self.tokens_per_page + page_idx * self.tokens_per_page * self.nb_heads + + (i % self.tokens_per_page) * self.nb_heads + + self.idx_head ) return self.pool[idx_head] @@ -68,23 +69,37 @@ def ref_attention( attention_sinks, sliding_win_size, valid_elems_per_head, + valid_elems_per_v_head=None, # Optional: for MLA where V dim != K dim ): + """ + For MLA: + - Q/K dimension: 576 (valid_elems_per_head) + - V dimension: 512 (valid_elems_per_v_head) + - Output dimension: matches valid_elems_per_head (576) but only first + valid_elems_per_v_head (512) elements are valid + """ head_grp_size = q.shape[0] rcp_x_scale = 1.0 / x_scale qk_scale = q_scale * kv_scale / math.sqrt(valid_elems_per_head) + # For MLA: V dimension may differ from K dimension + if valid_elems_per_v_head is None: + valid_elems_per_v_head = valid_elems_per_head + q_f32 = q.to(torch.float32) # [head_grp_size, valid_elems_per_head] k_cache_f32 = torch.zeros( seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda" ) + # V cache: load only valid_elems_per_v_head dimensions v_cache_f32 = torch.zeros( - seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda" + seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda" ) for j in range(seq_len): k_cache_f32[j] = k_cache_seq[j].to(torch.float32) - v_cache_f32[j] = v_cache_seq[j].to(torch.float32) + # For MLA: V cache storage is 576 
but only first 512 elements are valid + v_cache_f32[j] = v_cache_seq[j][:valid_elems_per_v_head].to(torch.float32) # q_f32: [head_grp_size, valid_elems_per_head] # k_cache_f32: [seq_len, valid_elems_per_head] @@ -125,14 +140,14 @@ def ref_attention( valid_x = x[:, seq_beg:seq_len] # [head_grp_size, valid_seq_len] valid_v_cache = v_cache_f32[ seq_beg:seq_len - ] # [valid_seq_len, valid_elems_per_head] + ] # [valid_seq_len, valid_elems_per_v_head] out = torch.matmul( valid_x, valid_v_cache - ) # [head_grp_size, valid_elems_per_head] + ) # [head_grp_size, valid_elems_per_v_head] else: out = torch.zeros( head_grp_size, - valid_elems_per_head, + valid_elems_per_v_head, dtype=torch.float32, device=q_f32.device, ) @@ -149,15 +164,16 @@ def ref_attention( @pytest.mark.skipif( - get_compute_capability(torch.device(device="cuda"))[0] != 9, - reason="XQA is only supported on SM90 GPUs", + get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12], + reason="XQA is only supported on SM90, SM100, SM120 GPUs", ) @pytest.mark.parametrize("use_sliding_window", [True, False]) -@pytest.mark.parametrize("use_fp16", [True, False]) +@pytest.mark.parametrize("fp16_input", [True, False]) +@pytest.mark.parametrize("fp8_kv_cache", [True, False]) @pytest.mark.parametrize("use_attention_sinks", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514]) @pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("nb_k_heads", [1, 4, 8]) +@pytest.mark.parametrize("nb_k_heads", [2, 4]) @pytest.mark.parametrize("tokens_per_page", [16, 64]) @pytest.mark.parametrize("valid_elems_per_head", [32, 128]) @pytest.mark.parametrize("head_grp_size", [8, 16]) @@ -166,18 +182,15 @@ def test_xqa( nb_k_heads, seq_len, tokens_per_page, - use_fp16, + fp16_input, + fp8_kv_cache, valid_elems_per_head, head_grp_size, use_attention_sinks, use_sliding_window, ): - compute_capability = get_compute_capability(torch.device(device="cuda")) - if compute_capability[0] != 9: - pytest.skip("XQA only supports on Hopper at this moment") set_random_seed(42) - nb_v_heads = nb_k_heads nb_q_heads = nb_k_heads * head_grp_size output = torch.zeros( @@ -185,7 +198,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) output.fill_(float("nan")) @@ -194,7 +207,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) q_heads.normal_(0, 1) @@ -213,25 +226,43 @@ def test_xqa( sliding_win_size = 0 max_seq_len = round_up(seq_len, tokens_per_page) - total_nb_cache_heads = ( - (nb_k_heads + nb_v_heads) * max_seq_len * beam_width * batch_size + nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) + # Layout 1: K and V share page indices + # Total cache heads = nb_k_heads * max_seq_len * batch_size + total_nb_cache_heads = nb_k_heads * max_seq_len * batch_size + + cache_k_heads = torch.zeros( + total_nb_cache_heads, + valid_elems_per_head, + dtype=torch.bfloat16 if not fp16_input else torch.float16, + device="cuda", ) - cache_heads = torch.zeros( + cache_k_heads.normal_(0, 1) + + cache_v_heads = torch.zeros( total_nb_cache_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) - cache_heads.normal_(0, 1) + cache_v_heads.normal_(0, 
1) - nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) - total_nb_pages = nb_pages_per_seq * 2 * beam_width * batch_size + if fp8_kv_cache: + # Scale down the cache heads to keep values within the representable range of FP8 + # and prevent overflow during computation. The factor 4.0 is chosen empirically. + cache_k_heads /= 4.0 + cache_v_heads /= 4.0 page_list_arg = torch.zeros( - batch_size, beam_width, 2, nb_pages_per_seq, dtype=torch.uint32, device="cuda" + batch_size, nb_pages_per_seq, dtype=torch.uint32, device="cuda" ) - page_list_arg.view(-1)[:total_nb_pages] = torch.arange( - total_nb_pages, dtype=torch.int32, device="cuda" - ).to(torch.uint32) + + # Initialize page list sequentially + page_idx = 0 + for batch in range(batch_size): + for page in range(nb_pages_per_seq): + page_list_arg[batch, page] = page_idx + page_idx += 1 + flattened = page_list_arg.flatten() indices = torch.randperm(flattened.numel()) shuffled_flat = flattened.to(torch.int32)[indices].to(torch.uint32) @@ -242,25 +273,24 @@ def cache_head_at( is_k, idx_kv_head, pos, - cache_heads, + cache_k_heads, + cache_v_heads, page_list, beam_width, nb_k_heads, tokens_per_page, ): - beam = 0 - kv = 0 if is_k else 1 - - page_idx = page_list_arg[batch][beam][kv][pos // tokens_per_page].to( - torch.int32 - ) + # Layout 1: K and V share page indices + page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) + # VLLM layout: [page_idx][token_in_page][nb_heads][head_dim] idx_head = ( - tokens_per_page * (nb_k_heads * page_idx + idx_kv_head) - + pos % tokens_per_page + page_idx * tokens_per_page * nb_k_heads + + (pos % tokens_per_page) * nb_k_heads + + idx_kv_head ) - return cache_heads[idx_head] + return cache_k_heads[idx_head] if is_k else cache_v_heads[idx_head] for batch in range(batch_size): for kv in range(2): @@ -271,7 +301,8 @@ def cache_head_at( kv == 0, idx_kv_head, pos, - cache_heads, + cache_k_heads, + cache_v_heads, page_list_arg, beam_width, nb_k_heads, @@ -295,63 +326,287 @@ def cache_head_at( scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda") xqa( - use_fp16, - tokens_per_page, - valid_elems_per_head, - head_grp_size, - use_sliding_window, - sliding_win_size, - sm_count, - nb_k_heads, - q_scale, - output, q_heads, - attention_sinks, - cache_heads, + cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads, + cache_v_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_v_heads, page_list_arg, - max_seq_len, seq_len_list, - batch_size, - kv_cache_scale, - semaphores, + output, scratch_buf, + semaphores, + nb_k_heads, + tokens_per_page, + sinks=attention_sinks, + q_scale=q_scale, + kv_scale=kv_cache_scale, + sliding_win_size=sliding_win_size, + sm_count=sm_count, ) for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): + # Layout 1: K and V use separate pools but share page indices k_cache_seq = CacheSeq( - pool=cache_heads, - page_indices=page_list_arg[req][b][0], + pool=cache_k_heads, + page_indices=page_list_arg[req], nb_heads=nb_k_heads, idx_head=idx_k_head, tokens_per_page=tokens_per_page, ) v_cache_seq = CacheSeq( - pool=cache_heads, - page_indices=page_list_arg[req][b][1], + pool=cache_v_heads, + page_indices=page_list_arg[req], nb_heads=nb_k_heads, idx_head=idx_k_head, tokens_per_page=tokens_per_page, ) - ref_output = ref_attention( - q=q_heads[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ], - k_cache_seq=k_cache_seq, - v_cache_seq=v_cache_seq, - seq_len=seq_len, - q_scale=q_scale, - 
kv_scale=kv_cache_scale[0], - x_scale=1.0, - attention_sinks=attention_sinks[idx_k_head, :] - if use_attention_sinks - else None, - sliding_win_size=sliding_win_size if use_sliding_window else 0, - valid_elems_per_head=valid_elems_per_head, + ref_output = ref_attention( + q=q_heads[req][b][ + idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size + ], + k_cache_seq=k_cache_seq, + v_cache_seq=v_cache_seq, + seq_len=seq_len, + q_scale=q_scale, + kv_scale=kv_cache_scale[0], + x_scale=1.0, + attention_sinks=attention_sinks[idx_k_head, :] + if use_attention_sinks + else None, + sliding_win_size=sliding_win_size if use_sliding_window else 0, + valid_elems_per_head=valid_elems_per_head, + ) + kernel_output = output[req][b][ + idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size + ].to(torch.float32) + if fp8_kv_cache: + atol = 0.05 + rtol = 0.05 + else: + atol = 0.01 + rtol = 0.01 + + diff_abs = torch.abs(ref_output - kernel_output) + diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) + + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + pass_ratio = within_tolerance.float().mean().item() + + required_ratio = 0.99 + assert pass_ratio >= required_ratio, ( + f"req={req}, b={b}, idx_k_head={idx_k_head}: " + f"Total {ref_output.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " + f"require at least {required_ratio:.1%}" + ) + + +@pytest.mark.skipif( + get_compute_capability(torch.device(device="cuda"))[0] not in [12], + reason="XQA mla is only supported on SM120 GPUs", +) +@pytest.mark.parametrize("seq_len", [2, 15, 256, 514, 2048]) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("tokens_per_page", [32, 64]) +def test_xqa_mla( + batch_size, + seq_len, + tokens_per_page, +): + set_random_seed(42) + + # MLA specific constants (fixed, not parameterized) + nb_k_heads = 1 # MLA only supports 1 K head + head_grp_size = 128 # Fixed for MLA + valid_elems_per_head_qk = 576 # Q and K dimension + valid_elems_per_head_v = 512 # V dimension (output dimension) + + nb_q_heads = nb_k_heads * head_grp_size + + output = torch.zeros( + batch_size, + beam_width, + nb_q_heads, + valid_elems_per_head_v, # Output dimension is 512 (V dimension) + dtype=torch.bfloat16, + device="cuda", + ) + output.fill_(float("nan")) + q_heads = torch.zeros( + batch_size, + beam_width, + nb_q_heads, + valid_elems_per_head_qk, # Q dimension is 576 + dtype=torch.float32, + device="cuda", + ) + q_heads.normal_(0, 1) + + max_seq_len = round_up(seq_len, tokens_per_page) + nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) + # Layout 1: K and V share page indices + # Total cache heads = nb_k_heads * max_seq_len * batch_size + total_nb_cache_heads = nb_k_heads * max_seq_len * batch_size + + cache_k_heads = torch.zeros( + total_nb_cache_heads, + valid_elems_per_head_qk, # K dimension is 576 + dtype=torch.float32, + device="cuda", + ) + cache_k_heads.normal_(0, 1) + + cache_v_heads = torch.zeros( + total_nb_cache_heads, + valid_elems_per_head_qk, # V storage is 576 (but only 512 used) + dtype=torch.float32, + device="cuda", + ) + cache_v_heads.normal_(0, 1) + + cache_k_heads /= 4.0 + cache_v_heads /= 4.0 + + page_list_arg = torch.zeros( + batch_size, nb_pages_per_seq, dtype=torch.uint32, device="cuda" + ) + + # Initialize page list sequentially + page_idx = 0 + for batch in range(batch_size): + for page in range(nb_pages_per_seq): + page_list_arg[batch, page] = page_idx + page_idx += 1 + + flattened = page_list_arg.flatten() + indices = torch.randperm(flattened.numel()) + 
shuffled_flat = flattened.to(torch.int32)[indices].to(torch.uint32) + page_list_arg = shuffled_flat.view(page_list_arg.shape) + + def cache_head_at( + batch, + is_k, + idx_kv_head, + pos, + cache_k_heads, + cache_v_heads, + page_list, + beam_width, + nb_k_heads, + tokens_per_page, + ): + # Layout 1: K and V share page indices + page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) + + # VLLM layout: [page_idx][token_in_page][nb_heads][head_dim] + idx_head = ( + page_idx * tokens_per_page * nb_k_heads + + (pos % tokens_per_page) * nb_k_heads + + idx_kv_head ) - kernel_output = output[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ].to(torch.float32) - assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01) + + return cache_k_heads[idx_head] if is_k else cache_v_heads[idx_head] + + for batch in range(batch_size): + for kv in range(2): + for idx_kv_head in range(nb_k_heads): + for pos in range(seq_len, max_seq_len): + cache_head = cache_head_at( + batch, + kv == 0, + idx_kv_head, + pos, + cache_k_heads, + cache_v_heads, + page_list_arg, + beam_width, + nb_k_heads, + tokens_per_page, + ) + cache_head.fill_(0.0) + + seq_len_list = torch.zeros( + batch_size, beam_width, dtype=torch.uint32, device="cuda" + ) + seq_len_list.fill_(seq_len) + + kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda") + + nb_seq = nb_k_heads * batch_size + nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2 + + semaphores = torch.zeros(nb_semaphores, dtype=torch.uint32, device="cuda") + + scratch_size = 256 << 20 + scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda") + + xqa_mla( + q_heads.to(torch.float8_e4m3fn), + cache_k_heads.to(torch.float8_e4m3fn), + cache_v_heads.to(torch.float8_e4m3fn), + page_list_arg, + seq_len_list, + output, + scratch_buf, + semaphores, + tokens_per_page, + q_scale=q_scale, + kv_scale=kv_cache_scale, + sm_count=sm_count, + ) + + for req in range(batch_size): + for b in range(beam_width): + for idx_k_head in range(nb_k_heads): + # Layout 1: K and V use separate pools but share page indices + k_cache_seq = CacheSeq( + pool=cache_k_heads, + page_indices=page_list_arg[req], + nb_heads=nb_k_heads, + idx_head=idx_k_head, + tokens_per_page=tokens_per_page, + ) + v_cache_seq = CacheSeq( + pool=cache_v_heads, + page_indices=page_list_arg[req], + nb_heads=nb_k_heads, + idx_head=idx_k_head, + tokens_per_page=tokens_per_page, + ) + + ref_output = ref_attention( + q=q_heads[req][b][ + idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size + ], + k_cache_seq=k_cache_seq, + v_cache_seq=v_cache_seq, + seq_len=seq_len, + q_scale=q_scale * math.sqrt(576), + kv_scale=kv_cache_scale[0], + x_scale=1.0, + attention_sinks=None, + sliding_win_size=0, + valid_elems_per_head=valid_elems_per_head_qk, # Q/K dimension (576) + valid_elems_per_v_head=valid_elems_per_head_v, # V dimension (512) + ).to(torch.float32) + kernel_output = output[req][b][ + idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size + ].to(torch.float32) + atol = 0.05 + rtol = 0.05 + + diff_abs = torch.abs(ref_output - kernel_output) + diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) + + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + pass_ratio = within_tolerance.float().mean().item() + + required_ratio = 0.95 + assert pass_ratio >= required_ratio, ( + f"req={req}, b={b}, idx_k_head={idx_k_head}: " + f"Total {ref_output.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " + f"require at least 
{required_ratio:.1%}" + )
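
Usage sketch for the new `flashinfer.xqa` entry point, assuming an SM90/SM100/SM120 GPU, fp16 inputs, and the VLLM paged-KV layout; shapes mirror `tests/attention/test_xqa.py`, while the semaphore and workspace sizes below are illustrative guesses rather than values prescribed by the API:

import torch
import flashinfer

batch_size, num_kv_heads, head_grp_size, head_dim = 2, 4, 8, 128
page_size, seq_len = 64, 256
num_q_heads = num_kv_heads * head_grp_size
pages_per_seq = (seq_len + page_size - 1) // page_size

q = torch.randn(batch_size, 1, num_q_heads, head_dim, dtype=torch.float16, device="cuda")
output = torch.empty_like(q)

# VLLM-layout paged KV cache (PAGED_KV_CACHE_LAYOUT=1): one row per (page, token, kv_head).
total_cache_heads = batch_size * pages_per_seq * page_size * num_kv_heads
k_cache = torch.randn(total_cache_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn(total_cache_heads, head_dim, dtype=torch.float16, device="cuda")

# K and V share one page table; pages here are simply assigned in order.
page_table = (
    torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
    .to(torch.uint32)
    .reshape(batch_size, pages_per_seq)
)
seq_lens = torch.zeros(batch_size, 1, dtype=torch.uint32, device="cuda")
seq_lens.fill_(seq_len)

# Buffer sizes are assumptions: a generous semaphore count and the 256 MiB scratch used in the test.
semaphores = torch.zeros(1024, dtype=torch.uint32, device="cuda")
workspace_buffer = torch.zeros(256 << 20, dtype=torch.uint8, device="cuda")

flashinfer.xqa(
    q, k_cache, v_cache, page_table, seq_lens, output,
    workspace_buffer, semaphores, num_kv_heads, page_size,
)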
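
A corresponding sketch for the new `flashinfer.xqa_mla` path (SM120 only; fp8 Q/KV, 1 KV head with 576-wide K rows, and a 512-wide bf16 output), with shapes taken from `test_xqa_mla` and the same illustrative buffer-size assumptions:

import torch
import flashinfer

batch_size, page_size, seq_len = 2, 64, 256
num_q_heads, head_dim_qk, head_dim_v = 128, 576, 512  # MLA: head group ratio 128
pages_per_seq = (seq_len + page_size - 1) // page_size

q = torch.randn(batch_size, 1, num_q_heads, head_dim_qk, device="cuda").to(torch.float8_e4m3fn)
output = torch.empty(batch_size, 1, num_q_heads, head_dim_v, dtype=torch.bfloat16, device="cuda")

# One KV head per token; V rows are stored 576 wide but only the first 512 elements are used.
total_cache_heads = batch_size * pages_per_seq * page_size
k_cache = torch.randn(total_cache_heads, head_dim_qk, device="cuda").to(torch.float8_e4m3fn)
v_cache = torch.randn(total_cache_heads, head_dim_qk, device="cuda").to(torch.float8_e4m3fn)

page_table = (
    torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
    .to(torch.uint32)
    .reshape(batch_size, pages_per_seq)
)
seq_lens = torch.zeros(batch_size, 1, dtype=torch.uint32, device="cuda")
seq_lens.fill_(seq_len)

semaphores = torch.zeros(1024, dtype=torch.uint32, device="cuda")            # assumed size
workspace_buffer = torch.zeros(256 << 20, dtype=torch.uint8, device="cuda")  # follows the test

flashinfer.xqa_mla(
    q, k_cache, v_cache, page_table, seq_lens, output,
    workspace_buffer, semaphores, page_size,
)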