Commit 8b2c265

Aya-ZIbra authored and meta-codesync[bot] committed
Write back LSE (#5209)
Summary:
Pull Request resolved: #5209
X-link: https://github.com/facebookresearch/FBGEMM/pull/2204

* **Python interface**: Modifies `fmha_gen_fwd` to return the LSE tensor instead of creating a dummy one
* **CUDA implementation**: Adds LSE tensor allocation and computation logic
* **Epilogue**: Adds LSE computation and storage in the epilogue
* **Mainloop**: Updates `correction_epilogue` to compute and write LSE values

Reviewed By: jsisometa

Differential Revision: D86949420

fbshipit-source-id: bf6fd9fa616d91c3b758b8a47a933690a88a9b80
1 parent 2aa8cd0 commit 8b2c265
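
For context on what is now written back: the LSE is the log-sum-exp of the scaled attention logits, i.e. the log of the softmax denominator, one value per query row and head. It is what a backward pass or a split-KV combine step needs in order to renormalize without rematerializing the score matrix. A minimal PyTorch reference for a single decode step (a sketch only; shapes and the softmax scale are illustrative assumptions, not the FBGEMM op's contract):

```python
import torch

def reference_decode_attention(q, k, v, scale=None):
    """One decode step: q is [B, H, D], k and v are [B, S, H, D].

    Returns the attention output together with the LSE that a fused kernel
    such as fmha_gen_fwd can now report (reference sketch, not FBGEMM code).
    """
    B, H, D = q.shape
    scale = D ** -0.5 if scale is None else scale
    # Scores of the single new query against the whole KV cache: [B, H, S]
    s = torch.einsum("bhd,bshd->bhs", q.float(), k.float()) * scale
    # LSE = log of the softmax denominator, one scalar per (batch, head).
    lse = torch.logsumexp(s, dim=-1)                   # [B, H]
    p = torch.exp(s - lse.unsqueeze(-1))               # softmax probabilities
    out = torch.einsum("bhs,bshd->bhd", p, v.float())  # [B, H, D]
    return out.to(q.dtype), lse
```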

6 files changed, +82 −47 lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 1 addition & 24 deletions

```diff
@@ -180,26 +180,6 @@ def _prepare_decode_inputs(
     return q, k, v, batch_size, needs_reshape_output, original_shape
 
 
-def _create_decode_lse(
-    out: torch.Tensor,
-    batch_size: int,
-    needs_reshape_output: bool,
-    q_shape: tuple[int, ...],
-) -> torch.Tensor:
-    """
-    Create dummy LSE tensor for decode output compatibility.
-    Gen kernel doesn't return LSE, so we create a zero tensor.
-    """
-    if needs_reshape_output:
-        # For varlen output format
-        lse_shape = [batch_size, q_shape[-1]]  # [B, H]
-    else:
-        # For batch output format
-        lse_shape = [batch_size, q_shape[-2], q_shape[1]]  # [B, H, 1]
-
-    return torch.zeros(*lse_shape, dtype=torch.float32, device=out.device)
-
-
 def cutlass_blackwell_fmha_decode_forward(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -233,7 +213,7 @@ def cutlass_blackwell_fmha_decode_forward(
         q, k, v
     )
     # Call the gen kernel (optimized for decode)
-    out = torch.ops.fbgemm.fmha_gen_fwd(
+    out, lse = torch.ops.fbgemm.fmha_gen_fwd(
         q,
         k,
         v,
@@ -248,9 +228,6 @@
     if needs_reshape_output:
         out = out.view(*original_shape)
 
-    # Create dummy LSE for compatibility
-    lse = _create_decode_lse(out, batch_size, needs_reshape_output, original_shape)
-
     return out, lse
 
 
```
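
With the gen kernel now returning a real LSE instead of the zero-filled placeholder that `_create_decode_lse` used to produce, callers can merge partial attention results exactly, for example when the KV cache is processed in separate chunks. A sketch of that standard combine rule (independent of the FBGEMM op; names and shapes are illustrative):

```python
import torch

def merge_partial_attention(out1, lse1, out2, lse2):
    """Combine attention computed over two disjoint KV partitions.

    out1, out2: [B, H, D] partial outputs; lse1, lse2: [B, H] their LSEs.
    With a dummy all-zero LSE this rescaling would be wrong; a real LSE,
    as returned above, makes the combine exact.
    """
    lse = torch.logaddexp(lse1, lse2)          # LSE over the union of keys
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # renormalization weights
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * out1 + w2 * out2, lse
```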

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 37 additions & 15 deletions

```diff
@@ -99,6 +99,7 @@ struct GenRunner {
   using StrideNewV = StrideNewK;
   using StrideCacheV = StrideCacheK;
   using StrideO = StrideQ;
+  using StrideLSE = Stride<int, int, _1>;
 
   using Mainloop =
       cutlass::fmha::collective::Sm100FmhaGenMainloopWarpspecialized<
@@ -117,7 +118,9 @@ struct GenRunner {
   using Epilogue =
       cutlass::fmha::collective::Sm100FmhaGenEpilogueWarpspecialized<
           ElementOut,
-          StrideO>;
+          StrideO,
+          ElementAcc,
+          StrideLSE>;
 
   using TileScheduler = std::conditional_t<
       kKernelType == KernelType::UMMA_P,
@@ -138,12 +141,14 @@ struct GenRunner {
   StrideCacheK stride_cache_k;
   StrideCacheV stride_cache_v;
   StrideO stride_o;
+  StrideLSE stride_lse;
 
   at::Tensor block_o;
+  at::Tensor block_lse;
   at::Tensor q, k, v, seqlen_kv;
   std::optional<at::Tensor> batch_idx;
 
-  at::Tensor fmha_fwd(
+  std::tuple<at::Tensor, at::Tensor> fmha_fwd(
       const at::Tensor& q_input,
       const at::Tensor& k_input,
       const at::Tensor& v_input,
@@ -177,7 +182,7 @@ struct GenRunner {
 
     run(options, hw_info);
 
-    return block_o;
+    return std::make_tuple(block_o, block_lse);
   }
 
   ProblemShape _initialize(const InputShape& options) {
@@ -210,13 +215,20 @@ struct GenRunner {
     stride_new_v = stride_new_k;
     stride_cache_v = stride_cache_k;
     stride_o = stride_q;
+    stride_lse = make_stride(options.h_k * h_r, h_r, cute::_1{});
 
     block_o = at::empty(
         q.sizes(),
         at::TensorOptions()
            .dtype(to_torch_type<ElementOut>())
            .device(at::Device(at::kCUDA, at::cuda::current_device())));
 
+    block_lse = at::empty(
+        {options.b, options.h, _1{}},
+        at::TensorOptions()
+           .dtype(at::kFloat)
+           .device(at::Device(at::kCUDA, at::cuda::current_device())));
+
     return result;
   }
 
@@ -241,6 +253,8 @@ struct GenRunner {
         stride_cache_v,
         static_cast<ElementOut*>(block_o.data_ptr()),
         stride_o,
+        static_cast<ElementAcc*>(block_lse.data_ptr()),
+        stride_lse,
         hw_info};
 
     Operation op;
@@ -306,7 +320,7 @@
   }()
 
 template <typename Element, KernelType KType, int HeadDim>
-at::Tensor run_gen_runner_fwd(
+std::tuple<at::Tensor, at::Tensor> run_gen_runner_fwd(
     const at::Tensor& q,
     const at::Tensor& k,
     const at::Tensor& v,
@@ -321,7 +335,7 @@ at::Tensor run_gen_runner_fwd(
   }
 }
 
-at::Tensor dispatch_fmha_gen_fwd(
+std::tuple<at::Tensor, at::Tensor> dispatch_fmha_gen_fwd(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
@@ -343,30 +357,38 @@ at::Tensor dispatch_fmha_gen_fwd(
   });
 }
 
-at::Tensor dispatch_fmha_gen_fwd_meta(
+std::tuple<at::Tensor, at::Tensor> dispatch_fmha_gen_fwd_meta(
     const at::Tensor& q,
     const at::Tensor& k,
     const at::Tensor& v,
     const at::Tensor& seqlen_kv,
     const std::optional<at::Tensor>& batch_idx,
     int64_t kernel_type
 ) {
-  return at::empty_like(q);
+  // Return tuple matching the operator signature: (output, lse)
+  at::Tensor output = at::empty_like(q);
+  // LSE should have shape [B, num_splits, H]
+  int b = q.size(0);
+  int h = q.size(2);
+  // For meta, just create a dummy LSE with single split
+  at::Tensor lse = at::empty(
+      {b, 1, h},
+      at::TensorOptions().dtype(at::kFloat).device(at::kMeta));
+  return std::make_tuple(output, lse);
 }
 
 // -------------------------------------------------------------------------------------------------
 // Op registration
 // -------------------------------------------------------------------------------------------------
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def("fmha_gen_fwd("
-      " Tensor query, "
-      " Tensor key, "
-      " Tensor value, "
-      " Tensor seqlen_kv, "
-      " Tensor? batch_idx = None,"
-      " int kernel_type = 0"
-      ") -> Tensor"
-  );
+      " Tensor query, "
+      " Tensor key, "
+      " Tensor value, "
+      " Tensor seqlen_kv, "
+      " Tensor? batch_idx = None,"
+      " int kernel_type = 0"
+      ") -> (Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
```
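
A note on the layout chosen above: `stride_lse = make_stride(options.h_k * h_r, h_r, _1{})` addresses the LSE buffer as (B, H_K, H_R), i.e. batch, then KV head, then query head within the GQA group, which lines up with the contiguous [B, H] storage of `block_lse` when H = H_K * H_R. A small index-arithmetic check of that assumption (the B, H_K, H_R values here are made up for illustration):

```python
import torch

B, H_K, H_R = 2, 4, 8          # batch, KV heads, query heads per KV head (assumed)
H = H_K * H_R

stride = (H_K * H_R, H_R, 1)   # mirrors make_stride(options.h_k * h_r, h_r, _1{})
flat = torch.arange(B * H)     # stand-in for the contiguous LSE buffer

def linear_index(b, h_k, h_r):
    # Same arithmetic as the write in correction_epilogue.
    return b * stride[0] + h_k * stride[1] + h_r * stride[2]

# Viewing the same buffer as a row-major [B, H_K, H_R] tensor agrees with it.
viewed = flat.view(B, H_K, H_R)
assert viewed[1, 2, 3] == flat[linear_index(1, 2, 3)]
```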

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -41,7 +41,7 @@ at::ScalarType to_torch_type() {
 }
 
 // Main dispatch function for the generation FMHA
-at::Tensor dispatch_fmha_gen_fwd(
+std::tuple<at::Tensor, at::Tensor> dispatch_fmha_gen_fwd(
     const at::Tensor& q,
     const at::Tensor& k,
     const at::Tensor& v,
```

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp

Lines changed: 8 additions & 2 deletions

```diff
@@ -38,7 +38,9 @@ namespace cutlass::fmha::collective {
 
 template<
   class Element_,
-  class StrideO_
+  class StrideO_,
+  class ElementAcc_,
+  class StrideLSE_
 >
 struct Sm100FmhaGenEpilogueWarpspecialized {
 
@@ -47,9 +49,11 @@ struct Sm100FmhaGenEpilogueWarpspecialized {
   using SmemLayoutO = Layout<Shape<_1, _1, _1>>;
   using SmemLayoutO_ = SmemLayoutO;
   using Element = Element_;
+  using ElementAcc = ElementAcc_;
   using StrideOOrig = StrideO_;
   using StrideO = decltype(replace<0>(StrideOOrig{}, 0));
-
+  using StrideLSE = StrideLSE_;
+
   struct TensorStorage {
 
     using SmemLayoutO = SmemLayoutO_;
@@ -60,6 +64,8 @@ struct Sm100FmhaGenEpilogueWarpspecialized {
   struct Arguments {
     Element* ptr_o;
     StrideO dO;
+    ElementAcc* ptr_LSE;
+    StrideLSE dLSE;
   };
 
   using Params = Arguments;
```

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 30 additions & 5 deletions

```diff
@@ -870,12 +870,12 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     ++pipeline_s_consumer_state;
   }
 
-  template<class Vector, class GTensor, class CTensor, class Shape, class Epilogue>
+  template<class Vector, class GTensor, class CTensor, class Shape, class Epilogue, class BlkCoord, class ProblemShape>
   CUTLASS_DEVICE auto
   correction_epilogue(
       float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1,
       GTensor& gO, CTensor const& cO, Shape const& g_shape,
-      Epilogue const& epilogue) {
+      Epilogue const& epilogue, BlkCoord const& blk_coord, ProblemShape const& problem_shape, int const row_idx) {
 
     using ElementOut = typename GTensor::value_type;
 
@@ -887,7 +887,6 @@
     const int kCorrectionTileSize = 32 / sizeof(ElementOut);
     // TODO: load all values
 
-
     // Choose TMEM OP based on
     // - TileM shape
     // - kCorrectionTileSize
@@ -933,6 +932,31 @@
     float scale0 = scale_out * adj0 / row_sum;
     float scale1 = scale_out * adj1 / row_sum;
 
+    // Compute and store LSE if requested
+    if (epilogue.params.ptr_LSE != nullptr) {
+      // LSE = log(row_sum) + scale_softmax * row_max
+      // scale_softmax_log2 is already in log2 space, convert to natural log
+      float lse = cutlass::fast_log(row_sum) + (scale_softmax_log2 / std::log2(std::exp(1.0f))) * row_max;
+      int h_r = row_idx;
+      int h_k = get<2, 0>(blk_coord);
+      int b = get<2, 1>(blk_coord);
+
+      // After problem_shape transformation in kernel:
+      // problem_shape = (H_R, Sk, D, ((1, H_K), B))
+      // So: get<0> = H_R, get<3,0,1> = H_K
+      int H_R = get<0>(problem_shape);
+
+      // Check bounds
+      if (thread_idx < H_R) {
+        // LSE tensor shape: [B, H_K, H_R]
+        // Use stride from epilogue.params.dLSE instead of hardcoding
+        int linear_idx = b * get<0>(epilogue.params.dLSE) +
+            h_k * get<1>(epilogue.params.dLSE) +
+            h_r * get<2>(epilogue.params.dLSE);
+        epilogue.params.ptr_LSE[linear_idx] = lse;
+      }
+    }
+
     float2 scale0_f32x2 = make_float2(scale0, scale0);
     float2 scale1_f32x2 = make_float2(scale1, scale1);
 
@@ -1223,8 +1247,9 @@
     auto g_shape = select<0,2>(problem_shape);
     auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), append<3>(select<0,1>(TileShapePV{}), get<3>(problem_shape)), epilogue.params.dO);
     auto gO = mO(_, _, get<2>(blk_coord));
-
-    correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1, gO, cO, g_shape, epilogue);
+    int row_idx = get<0>(tTMEM_LOADVcS(_0{}));
+    correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1,
+        gO, cO, g_shape, epilogue, blk_coord, problem_shape, row_idx);
 
     cutlass::arch::fence_view_async_tmem_load();
 
```
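The LSE expression added to `correction_epilogue` uses the usual streaming-softmax identity: if `row_max` is the running maximum of the raw scores and `row_sum` accumulates `exp2(scale_softmax_log2 * (s - row_max))`, then `log(sum exp(scale * s)) = log(row_sum) + scale * row_max` with `scale = scale_softmax_log2 / log2(e)`. A standalone PyTorch check of that identity (the tensors here are random stand-ins, not kernel state):

```python
import math
import torch

scores = torch.randn(128)                       # raw attention logits for one row
scale = 1.0 / math.sqrt(128)                    # softmax scale (assumed)
scale_softmax_log2 = scale * math.log2(math.e)  # scale folded into log2 space

row_max = scores.max()
# What the mainloop accumulates: base-2 exponentials of the shifted scores.
row_sum = torch.exp2(scale_softmax_log2 * (scores - row_max)).sum()

# LSE as written back: log(row_sum) + (scale_softmax_log2 / log2(e)) * row_max
lse_kernel = torch.log(row_sum) + (scale_softmax_log2 / math.log2(math.e)) * row_max
lse_reference = torch.logsumexp(scale * scores, dim=0)
assert torch.allclose(lse_kernel, lse_reference, atol=1e-5)
```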

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp

Lines changed: 5 additions & 0 deletions

```diff
@@ -168,6 +168,9 @@ struct Sm100FmhaGenKernelWarpspecialized {
     ElementOut* ptr_o; // 1 x D x (H x B)
     StrideOOrig dO;
 
+    ElementAcc* ptr_LSE; // (B, H_K, H_R)
+    cute::Stride<int, int, cute::_1> dLSE; // stride: (H_K*H_R, H_R, 1)
+
     cutlass::KernelHardwareInfo hw_info;
 
     ElementAcc scale_softmax = 0.0f;
@@ -227,6 +230,8 @@ struct Sm100FmhaGenKernelWarpspecialized {
 
     typename CollectiveEpilogue::Arguments epilogue_args {
       args.ptr_o, dO,
+      args.ptr_LSE,
+      args.dLSE,
     };
 
     return Params{
```
