
Commit 1fb7393

Merge branch 'dsv3_dev' into rewrite_fp8_utils
2 parents 7e4c254 + 8a0986e commit 1fb7393

25 files changed: +2269 -1141 lines

ops/csrc/fp8/deep_gemm/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -23,8 +23,10 @@
     get_col_major_tma_aligned_tensor,
     get_m_alignment_for_contiguous_layout,
     get_num_sms,
+    k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     set_num_sms,
+    wgrad_gemm_fp8_fp8_fp32_nt,
 )
-from .utils import bench, calc_diff, get_cuda_home
+from .utils import calc_diff

ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh

Lines changed: 74 additions & 187 deletions
Large diffs are not rendered by default.

ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh

Lines changed: 381 additions & 0 deletions
Large diffs are not rendered by default.

ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh

Lines changed: 30 additions & 2 deletions
@@ -12,13 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-// The file has been adapted from DeepSeek DeepEP project
+// The file has been adapted from DeepSeek DeepGEMM project
 // Copyright (c) 2025 DeepSeek
-// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
+// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE

 #pragma once

+#ifndef __CUDACC_RTC__
 #include <cuda.h>
+#endif

 #include <cute/arch/mma_sm90_gmma.hpp>
 #include <cute/arch/mma_sm90_gmma_ext.hpp>
@@ -84,6 +86,12 @@ __device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
     return ret;
 }

+__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
+    float2 ret;
+    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
+    return ret;
+}
+
 __device__ __forceinline__ void st_shared(const float* ptr, float val) {
     asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
 }
@@ -92,6 +100,10 @@ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
     asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
 }

+__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
+    asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
+}
+
 template <int N>
 __device__ void warpgroup_wait() {
     DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
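Note: the two new float2 overloads of ld_shared / st_shared move a pair of 32-bit floats per shared-memory instruction (ld.shared.v2.f32 / st.shared.v2.f32) instead of issuing two scalar accesses, halving the number of shared-memory transactions on that path. A minimal standalone sketch of the same pattern, reusing the overloads exactly as defined above (the demo kernel and its values are illustrative, not part of the diff; it also mirrors the diff's convention of passing the generic shared-memory pointer through the "l" constraint):

    // Illustrative demo (not part of the diff): exercise the new float2 overloads.
    #include <cstdio>
    #include <cuda_runtime.h>

    __device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
        asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
    }

    __device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
        float2 ret;
        asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
        return ret;
    }

    __global__ void scale_pair_demo(float* out) {
        __shared__ float2 smem_scales;
        if (threadIdx.x == 0)
            st_shared(&smem_scales, make_float2(0.5f, 2.0f));  // one 64-bit vectorized store
        __syncthreads();
        float2 s = ld_shared(&smem_scales);                    // one 64-bit vectorized load
        if (threadIdx.x == 0)
            out[0] = s.x * s.y;                                // expect 1.0f
    }

    int main() {
        float* out = nullptr;
        cudaMalloc(&out, sizeof(float));
        scale_pair_demo<<<1, 32>>>(out);
        float host_out = 0.0f;
        cudaMemcpy(&host_out, out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("%f\n", host_out);  // 1.000000
        cudaFree(out);
        return 0;
    }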
@@ -186,6 +198,7 @@ struct FP8MMASelector {
         if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
+        if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
         if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
@@ -199,4 +212,19 @@ struct FP8MMASelector {
     using type = decltype(select_type());
 };

+enum class Layout {
+    RowMajor,
+    ColMajor
+};
+
+__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
+    return block_m == 64 ? 1 : 2;
+}
+
+template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
+__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
+    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
+    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
+}
+
 } // namespace deep_gemm
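Note: the new launch-bound helpers make the per-SM thread count a compile-time function of BLOCK_M: one 128-thread math warpgroup when BLOCK_M == 64, two otherwise, plus the TMA warpgroup. A quick host-side constexpr check of the arithmetic, under the assumption that the kernels instantiate it with kNumTMAThreads == 128 (the assert macro from the diff is dropped here so the sketch is self-contained):

    // Worked example (assumptions: kNumTMAThreads == 128, kNumMathThreadsPerGroup == 128).
    constexpr int get_num_math_warpgroups(int block_m) {
        return block_m == 64 ? 1 : 2;
    }

    template <unsigned kNumTMAThreads, unsigned kNumMathThreadsPerGroup>
    constexpr int get_num_threads_per_sm(int block_m) {
        return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
    }

    static_assert(get_num_threads_per_sm<128, 128>(64)  == 256, "1 math warpgroup + TMA warpgroup");
    static_assert(get_num_threads_per_sm<128, 128>(128) == 384, "2 math warpgroups + TMA warpgroup");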
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// The file has been adapted from DeepSeek DeepGEMM project
+// Copyright (c) 2025 DeepSeek
+// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE
+
+#pragma once
+
+#ifdef __CUDACC_RTC__
+
+using int8_t = signed char;
+using uint8_t = unsigned char;
+using int16_t = signed short;
+using uint16_t = unsigned short;
+using int32_t = signed int;
+using uint32_t = unsigned int;
+using int64_t = signed long long;
+using uint64_t = unsigned long long;
+using cuuint64_t = unsigned long long;
+
+#ifndef CU_TENSOR_MAP_NUM_QWORDS
+#define CU_TENSOR_MAP_NUM_QWORDS 16
+
+struct CUtensorMap_st {
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+    alignas(64)
+#elif __STDC_VERSION__ >= 201112L
+    _Alignas(64)
+#endif
+    cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
+};
+
+using CUtensorMap = CUtensorMap_st;
+#endif
+
+namespace std {
+
+template <class T, T v> struct integral_constant {
+    static constexpr T value = v;
+
+    using value_type = T;
+    using type = integral_constant;
+
+    __device__ constexpr operator value_type() const noexcept { return value; }
+
+    __device__ constexpr value_type operator()() const noexcept { return value; }
+};
+
+using false_type = integral_constant<bool, false>;
+using true_type = integral_constant<bool, true>;
+
+template <class T, class U> struct is_same : false_type {};
+
+template <class T> struct is_same<T, T> : true_type {};
+
+template <class T, class U>
+inline constexpr bool is_same_v = is_same<T, U>::value;
+
+namespace index_sequence_impl {
+
+// Based on https://stackoverflow.com/a/32223343/11717224
+template <size_t... Ints> struct index_sequence {
+    using type = index_sequence;
+    using value_type = size_t;
+    static constexpr size_t size() noexcept { return sizeof...(Ints); }
+};
+
+template <class Sequence1, class Sequence2> struct _merge_and_renumber;
+
+template <size_t... I1, size_t... I2>
+struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
+    : index_sequence<I1..., (sizeof...(I1) + I2)...> {};
+
+template <size_t N>
+struct make_index_sequence
+    : _merge_and_renumber<typename make_index_sequence<N / 2>::type,
+                          typename make_index_sequence<N - N / 2>::type> {};
+
+template <> struct make_index_sequence<0> : index_sequence<> {};
+template <> struct make_index_sequence<1> : index_sequence<0> {};
+
+} // namespace index_sequence_impl
+
+template <size_t... Ns>
+using index_sequence = index_sequence_impl::index_sequence<Ns...>;
+
+template <size_t N>
+using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
+
+} // namespace std
+
+#endif
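Note: this new header is only active under NVRTC (__CUDACC_RTC__), where neither the CUDA driver header nor the host standard library is available to runtime-compiled kernels; it supplies the fixed-width integer aliases, a CUtensorMap definition, and just enough of <type_traits> / <utility> (integral_constant, is_same, index_sequence) for the templated kernels to compile. The same concern is why the <cuda.h> include in mma_utils.cuh is now guarded by #ifndef __CUDACC_RTC__. Under NVCC the real headers are used, so compile-time code like the sketch below builds the same way in both paths (the sketch is illustrative, not from the diff):

    // Illustrative compile-time usage of the shimmed traits; under NVRTC these
    // resolve to the minimal definitions above, under NVCC to the standard library.
    #include <cstddef>
    #include <cstdio>
    #include <type_traits>
    #include <utility>

    template <typename T, std::size_t... Is>
    void print_sequence(std::index_sequence<Is...>) {
        static_assert(std::is_same_v<T, float>, "demo instantiated with float only");
        ((printf("%zu ", Is)), ...);  // C++17 fold over the index pack
        printf("\n");
    }

    int main() {
        print_sequence<float>(std::make_index_sequence<4>{});  // prints: 0 1 2 3
        return 0;
    }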

ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh

Lines changed: 69 additions & 23 deletions
@@ -12,10 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-// The file has been adapted from DeepSeek DeepEP project
+// The file has been adapted from DeepSeek DeepGEMM project
 // Copyright (c) 2025 DeepSeek
-// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
+// Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE

+#pragma once
 #include "utils.cuh"

 namespace deep_gemm {
@@ -41,13 +42,16 @@ struct Scheduler {
     // For normal GEMM
     // Maybe not used in the masked grouped GEMM
     uint32_t num_blocks;
+    uint32_t num_blocks_in_group;
+    bool is_peer_cta_alive = true;

     // For grouped GEMM
     int* grouped_layout;
+
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;

-    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
+    __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
                                                   int* grouped_layout = nullptr) {
         num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
@@ -61,39 +65,77 @@
         }
     }

-    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
+    // ReSharper disable once CppNotAllPathsReturnValue
+    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
+        }
+    }
+
+    __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
+        if (num_blocks_in_group == 1)
+            return false;
+        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
+            return true;
+        } else {
+            DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
+            if constexpr (kIsTMAMulticastOnA) {
+                return true;
+            } else {
+                auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+                auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+                return group_idx == peer_group_idx;
+            }
+        }
+    }
+
+    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
+                                                           uint32_t& m_block_idx, uint32_t& n_block_idx) {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

         // Swizzle for better L2 usages
-        // TODO: unify these 2 branches
+        auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
+        auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
+        auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
+        auto group_idx = block_idx / num_blocks_per_group;
+        auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
+        auto in_group_idx = block_idx % num_blocks_per_group;
+        num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
+
+        // Fix unaligned TMA multicast
+        if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
+            if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
+                num_blocks_in_group = num_blocks_in_group ^ 1;
+            } else {
+                in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
+                first_block_idx += num_blocks_in_group ^ 1;
+                num_blocks_in_group = 1;
+            }
+        }
+
+        // Convert to final M/N block indices
         if constexpr (kIsTMAMulticastOnA) {
-            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = in_group_idx / num_n_blocks_in_group;
-            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+            m_block_idx = in_group_idx / num_blocks_in_group;
+            n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
         } else {
-            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
-            n_block_idx = in_group_idx / num_m_blocks_in_group;
+            m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
+            n_block_idx = in_group_idx / num_blocks_in_group;
         }
     }

     template <bool kIgnoreGroupedForGroupedContiguous=true>
-    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
                                                        const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
         if constexpr (kGemmType == GemmType::Normal) {
             return block_idx * block_size;
-        } else if (kGemmType == GemmType::GroupedContiguous) {
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
             auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
             return offset * shape_dim + block_idx * block_size;
-        } else if (kGemmType == GemmType::GroupedMasked) {
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return curr_group_idx * shape_dim + block_idx * block_size;
         }
     }
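Note: the rewrite collapses the two swizzling branches into one by treating whichever dimension is grouped for L2 reuse as the "primary" dimension, and it stores num_blocks_in_group as a member so is_tma_multicast_valid can see the actual group width; the "Fix unaligned TMA multicast" branch trims an odd-sized group so 2-CTA multicast pairs stay aligned to the group. A host-side mirror of the basic index math, handy for sanity-checking the mapping (the constants and the single-CTA simplification are assumptions for the demo, not values from the diff):

    // Host-side mirror of the new swizzle (assumptions: kIsTMAMulticastOnA == true,
    // kNumTMAMulticast == 1 so the unaligned-group fix-up never triggers,
    // kNum1DBlocksPerGroup == 16, kNumNBlocks == 32).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    constexpr uint32_t kNum1DBlocksPerGroup = 16, kNumNBlocks = 32;

    void swizzle(uint32_t num_m_blocks, uint32_t block_idx,
                 uint32_t& m_block_idx, uint32_t& n_block_idx) {
        // Primary dimension is N (multicast on A), secondary is M.
        const uint32_t primary_num_blocks = kNumNBlocks;
        const uint32_t secondary_num_blocks = num_m_blocks;
        const uint32_t num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
        const uint32_t group_idx = block_idx / num_blocks_per_group;
        const uint32_t first_block_idx = group_idx * kNum1DBlocksPerGroup;
        const uint32_t in_group_idx = block_idx % num_blocks_per_group;
        const uint32_t num_blocks_in_group =
            std::min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
        m_block_idx = in_group_idx / num_blocks_in_group;
        n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
    }

    int main() {
        // With 4 M-blocks, the first 4 * 16 = 64 linear block indices all map into
        // N-blocks [0, 16), so consecutive CTAs keep hitting the same stripe of B in L2.
        for (uint32_t idx : {0u, 1u, 15u, 16u, 63u, 64u}) {
            uint32_t m = 0, n = 0;
            swizzle(/*num_m_blocks=*/4, idx, m, n);
            printf("block %2u -> (m_block=%u, n_block=%2u)\n", idx, m, n);
        }
        return 0;
    }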
@@ -108,7 +150,7 @@ struct Scheduler {
         if (curr_group_idx == kNumGroups)
             return false;

-        // Within current group
+        // Within the current group
         num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
         auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
         if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
@@ -123,6 +165,10 @@
             if (next_block_idx >= num_blocks)
                 return false;

+            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+            is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or          // Always aligned on N (constant bypass)
+                                num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
+                                (next_block_idx ^ 1) < num_blocks;              // Peer CTA in bound
             get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
         }
         return true;
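Note: is_peer_cta_alive leans on XOR-with-1 pairing: with 2-CTA TMA multicast, blocks 2k and 2k+1 are peers, so (next_block_idx ^ 1) is exactly the partner's linear index, and the bound check only matters when neither dimension is a multiple of the multicast width. A tiny standalone illustration of the pairing (not from the diff):

    // XOR-with-1 pairing of peer CTAs: 0 <-> 1, 2 <-> 3, 4 <-> 5, ...
    #include <cstdio>

    int main() {
        const unsigned num_blocks = 5;  // odd count: the last block's peer is out of bounds
        for (unsigned idx = 0; idx < num_blocks; ++idx) {
            const unsigned peer = idx ^ 1;
            printf("block %u: peer %u, peer in bounds = %d\n", idx, peer, peer < num_blocks);
        }
        return 0;
    }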
