Commit 8b2c265
Write back LSE (#5209)
Summary:
Pull Request resolved: #5209
X-link: https://github.com/facebookresearch/FBGEMM/pull/2204
* **Python interface**: Modifies `fmha_gen_fwd` to return the LSE (log-sum-exp) tensor instead of creating a dummy one
* **CUDA implementation**: Adds LSE tensor allocation and computation logic
* **Epilogue**: Adds LSE computation and storage in the epilogue
* **Mainloop**: Updates `correction_epilogue` to compute and write LSE values
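
For readers unfamiliar with the quantity being written back, the following is a minimal sketch of how a per-row LSE (log-sum-exp) value falls out of an online-softmax attention loop. It is plain Python for illustration only and is not the kernel code from this commit: the function name `attention_row_with_lse`, the `block` parameter, and the list-based layout are hypothetical, and the real CUTLASS kernel may apply a softmax scale or use base-2 exponentials. The sketch assumes at least one finite score per row.

```python
import math


def attention_row_with_lse(scores, values, block=2):
    """Online softmax over one query row (hypothetical helper).

    scores: list of attention logits for this row (assumed finite).
    values: list of value vectors, one per score.
    Returns (output_vector, lse) where lse = log(sum(exp(scores))).
    """
    m = float("-inf")              # running row maximum
    l = 0.0                        # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # unnormalized output accumulator

    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]

        m_new = max(m, max(s_blk))
        # Rescale earlier partial results to the new maximum.
        # On the first block this is exp(-inf) == 0.0.
        scale = math.exp(m - m_new)
        l *= scale
        acc = [a * scale for a in acc]

        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new

    out = [a / l for a in acc]     # normalize by the softmax denominator
    lse = m + math.log(l)          # the value a forward pass can store per row
    return out, lse
```

Storing `lse` alongside the output lets a later pass (for example, a backward pass) recover the softmax probabilities as `exp(score - lse)` without recomputing the row maximum and sum.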
Reviewed By: jsisometa
Differential Revision: D86949420
fbshipit-source-id: bf6fd9fa616d91c3b758b8a47a933690a88a9b80
1 parent 2aa8cd0 commit 8b2c265
File tree
6 files changed, +82 -47 lines changed
- fbgemm_gpu/experimental/gen_ai
- gen_ai/attention/cutlass_blackwell_fmha
- src/attention/cuda/cutlass_blackwell_fmha
- collective
- kernel
Lines changed: 1 addition & 24 deletions
Lines changed: 37 additions & 15 deletions
Lines changed: 1 addition & 1 deletion
Lines changed: 8 additions & 2 deletions
Lines changed: 30 additions & 5 deletions
Lines changed: 5 additions & 0 deletions