// See the License for the specific language governing permissions and
// limitations under the License.

- // The file has been adapted from DeepSeek DeepEP project
+ // The file has been adapted from DeepSeek DeepGEMM project
// Copyright (c) 2025 DeepSeek
- // Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
+ // Licensed under the MIT License - https://github.com/deepseek-ai/DeepGEMM/blob/main/LICENSE

+ #pragma once
#include "utils.cuh"

namespace deep_gemm {
@@ -41,13 +42,16 @@ struct Scheduler {
    // For normal GEMM
    // Maybe not used in the masked grouped GEMM
    uint32_t num_blocks;
+   uint32_t num_blocks_in_group;
+   bool is_peer_cta_alive = true;

    // For grouped GEMM
    int* grouped_layout;
+
    // Only used for masked layout
    uint32_t curr_group_idx, curr_cumsum;

-   __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
+   __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
                                                  int* grouped_layout = nullptr) {
        num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
        if constexpr (kGemmType == GemmType::Normal) {
@@ -61,39 +65,77 @@ struct Scheduler {
        }
    }

-   __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
+   // ReSharper disable once CppNotAllPathsReturnValue
+   __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
+       if constexpr (kGemmType == GemmType::Normal) {
+           return true;
+       } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+           return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
+       } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+           return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
+       }
+   }
+
+   __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
+       if (num_blocks_in_group == 1)
+           return false;
+       if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
+           return true;
+       } else {
+           DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
+           if constexpr (kIsTMAMulticastOnA) {
+               return true;
+           } else {
+               auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+               auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+               return group_idx == peer_group_idx;
+           }
+       }
+   }
+
+   __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
+                                                          uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usages
-       // TODO: unify these 2 branches
+       auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
+       auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
+       auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
+       auto group_idx = block_idx / num_blocks_per_group;
+       auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
+       auto in_group_idx = block_idx % num_blocks_per_group;
+       num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
+
+       // Fix unaligned TMA multicast
+       if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
+           if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
+               num_blocks_in_group = num_blocks_in_group ^ 1;
+           } else {
+               in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
+               first_block_idx += num_blocks_in_group ^ 1;
+               num_blocks_in_group = 1;
+           }
+       }
+
+       // Convert to final M/N block indices
        if constexpr (kIsTMAMulticastOnA) {
-           auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-           auto group_idx = block_idx / num_blocks_per_group;
-           auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
-           auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-           auto in_group_idx = block_idx % num_blocks_per_group;
-           m_block_idx = in_group_idx / num_n_blocks_in_group;
-           n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+           m_block_idx = in_group_idx / num_blocks_in_group;
+           n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
        } else {
-           auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
-           auto group_idx = block_idx / num_blocks_per_group;
-           auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
-           auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
-           auto in_group_idx = block_idx % num_blocks_per_group;
-           m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
-           n_block_idx = in_group_idx / num_m_blocks_in_group;
+           m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
+           n_block_idx = in_group_idx / num_blocks_in_group;
        }
    }

    template <bool kIgnoreGroupedForGroupedContiguous = true>
-   __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+   __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
                                                        const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
-       } else if (kGemmType == GemmType::GroupedContiguous) {
+       } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
            return offset * shape_dim + block_idx * block_size;
-       } else if (kGemmType == GemmType::GroupedMasked) {
+       } else if constexpr (kGemmType == GemmType::GroupedMasked) {
            return curr_group_idx * shape_dim + block_idx * block_size;
        }
    }
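
For reference, the unified swizzle above replaces the two old per-branch computations with one primary/secondary formulation: the primary dimension (N when multicasting on A, M otherwise) is walked in groups of kNum1DBlocksPerGroup while the secondary dimension streams, so consecutive CTAs keep hitting the same few tiles in L2. Below is a minimal host-side sketch of the mapping, illustrative only and not part of the patch: the sizes are made up, it mirrors only the multicast-on-A branch, and it omits the unaligned-multicast fix.

// Hypothetical host-side illustration (not part of the patch): mirrors the
// multicast-on-A branch of get_swizzled_block_idx with made-up sizes.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t num_m_blocks = 4, num_n_blocks = 6;  // stand-ins for num_m_blocks / kNumNBlocks
    const uint32_t group_width = 2;                     // stand-in for kNum1DBlocksPerGroup
    for (uint32_t block_idx = 0; block_idx < num_m_blocks * num_n_blocks; ++block_idx) {
        const uint32_t num_blocks_per_group = num_m_blocks * group_width;        // secondary * group width
        const uint32_t group_idx = block_idx / num_blocks_per_group;
        const uint32_t first_n_block = group_idx * group_width;
        const uint32_t in_group_idx = block_idx % num_blocks_per_group;
        const uint32_t num_blocks_in_group = std::min(group_width, num_n_blocks - first_n_block);
        const uint32_t m_block = in_group_idx / num_blocks_in_group;
        const uint32_t n_block = first_n_block + in_group_idx % num_blocks_in_group;
        std::printf("block %2u -> (m=%u, n=%u)\n", block_idx, m_block, n_block);
    }
    return 0;
}

With these numbers, blocks 0-7 only touch n in {0, 1}, blocks 8-15 only {2, 3}, and so on: the L2-friendly ordering the "Swizzle for better L2 usages" comment refers to.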
@@ -108,7 +150,7 @@ struct Scheduler {
            if (curr_group_idx == kNumGroups)
                return false;

-           // Within current group
+           // Within the current group
            num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
            auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
            if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
@@ -123,6 +165,10 @@ struct Scheduler {
            if (next_block_idx >= num_blocks)
                return false;

+           // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+           is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or          // Always aligned on N (constant bypass)
+                               num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
+                               (next_block_idx ^ 1) < num_blocks;              // Peer CTA in bound
            get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
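
Both the unaligned-multicast fix and the new `is_peer_cta_alive` computation rely on XOR with 1: for a block or CTA index it gives the multicast peer (0 and 1 swap, 2 and 3 swap, and so on), and for an odd group size it rounds down to the nearest even value. A tiny standalone check of that behaviour, illustrative only and not part of the patch:

// Hypothetical standalone check (not part of the patch) of the `x ^ 1` idiom
// used for peer indices and for shrinking an odd group size to an even one.
#include <cassert>
#include <cstdint>

int main() {
    // Peer pairing within a 2-CTA cluster: even/odd indices swap.
    for (uint32_t idx = 0; idx < 8; ++idx)
        assert((idx ^ 1u) == (idx % 2 == 0 ? idx + 1 : idx - 1));

    // Rounding down to even, as in the unaligned-multicast fix; `^ 1` only
    // acts as "round down" when the size is odd, which is the guarded case.
    assert((5u ^ 1u) == 4u);
    assert((7u ^ 1u) == 6u);
    return 0;
}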