@@ -219,20 +219,28 @@ def forward_kernel_causal_and_sparse(
219
219
220
220
if EVEN_N & EVEN_M :
221
221
if EVEN_HEADDIM :
222
- k = tl .load (k_ptrs )
222
+ k = tl .load (
223
+ k_ptrs ,
224
+ mask = (offs_n [:, None ] >= 0 ),
225
+ other = 0.
226
+ )
223
227
else :
224
- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
228
+ k = tl .load (
229
+ k_ptrs ,
230
+ mask = (offs_n [:, None ] >= 0 ) & (offs_d [None , :] < headdim ),
231
+ other = 0.0
232
+ )
225
233
else :
226
234
if EVEN_HEADDIM :
227
235
k = tl .load (
228
236
k_ptrs ,
229
- mask = offs_n [:, None ] < seqlen_k ,
237
+ mask = ( offs_n [:, None ] >= 0 ) & ( offs_n [:, None ] < seqlen_k ) ,
230
238
other = 0.0 ,
231
239
)
232
240
else :
233
241
k = tl .load (
234
242
k_ptrs ,
235
- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
243
+ mask = (offs_n [:, None ] >= 0 ) & ( offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
236
244
other = 0.0 ,
237
245
)
238
246
@@ -1229,19 +1237,36 @@ def backward_kernel_one_col_block_causal(
1229
1237
# if we just call tl.load(k_ptrs), we get the wrong output!
1230
1238
if EVEN_N & EVEN_M :
1231
1239
if EVEN_HEADDIM :
1232
- k = tl .load (k_ptrs )
1240
+ k = tl .load (
1241
+ k_ptrs ,
1242
+ mask = (offs_n [:, None ] >= 0 ),
1243
+ other = 0.
1244
+ )
1233
1245
v = tl .load (v_ptrs )
1234
1246
else :
1235
- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
1247
+ k = tl .load (
1248
+ k_ptrs ,
1249
+ mask = (offs_n [:, None ] >= 0 ) & (offs_d [None , :] < headdim ),
1250
+ other = 0.0
1251
+ )
1252
+
1236
1253
v = tl .load (v_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
1237
1254
else :
1238
1255
if EVEN_HEADDIM :
1239
- k = tl .load (k_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
1256
+ k = tl .load (
1257
+ k_ptrs ,
1258
+ mask = (offs_n [:, None ] >= 0 ) & (offs_n [:, None ] < seqlen_k ),
1259
+ other = 0.0
1260
+ )
1261
+
1240
1262
v = tl .load (v_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
1241
1263
else :
1242
1264
k = tl .load (
1243
- k_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
1265
+ k_ptrs ,
1266
+ mask = (offs_n [:, None ] >= 0 ) & (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
1267
+ other = 0.0
1244
1268
)
1269
+
1245
1270
v = tl .load (
1246
1271
v_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
1247
1272
)
@@ -1273,7 +1298,7 @@ def backward_kernel_one_col_block_causal(
1273
1298
1274
1299
if BLOCK != SEL_BLOCK :
1275
1300
block_diagonal_mask = (
1276
- (offs_n [None , :] >= 0. ) &
1301
+ (offs_n [None , :] >= 0 ) &
1277
1302
((offs_n [None , :] // SEL_BLOCK ) == (offs_m [:, None ] // SEL_BLOCK ))
1278
1303
)
1279
1304
0 commit comments