Commit 2c01cac

more guards, addressing #15 for head size <= 64
1 parent 76d5507 commit 2c01cac
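
For context: the guards added below all follow one pattern. Every tl.load of k (and v) now masks with offs_n >= 0 and substitutes other = 0., because in the sparse path offs_n can apparently carry negative offsets for key blocks that were not selected, and with head size <= 64 the previously unmasked EVEN_HEADDIM branches read through those negative pointers. Below is a minimal, hypothetical PyTorch sketch of that masking idea (plain tensors rather than Triton; sel_block_idx, block_size, and the shapes are illustrative, not taken from the repo):

import torch

# Illustrative sketch of the guard this commit adds: selected key blocks are
# addressed through offsets that can be a negative sentinel (here -1) when a
# block is unselected. Reading through such an offset must be masked out,
# otherwise garbage (or out-of-bounds memory, in the Triton kernel) is loaded.

block_size = 4
k = torch.randn(16, 64)                          # keys: (seq_len_k, head dim 64)

sel_block_idx = torch.tensor([2, 0, -1])         # -1 marks "no block selected"
offs_n = sel_block_idx[:, None] * block_size + torch.arange(block_size)[None, :]
offs_n = offs_n.reshape(-1)                      # per-position key offsets, may be negative

mask = offs_n >= 0                               # the new guard
safe_offs = offs_n.clamp(min = 0)                # never index with a negative offset
k_loaded = torch.where(mask[:, None], k[safe_offs], torch.zeros(()))   # other = 0.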

File tree

2 files changed, +36 -11 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 34 additions & 9 deletions
@@ -219,20 +219,28 @@ def forward_kernel_causal_and_sparse(
 
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
-            k = tl.load(k_ptrs)
+            k = tl.load(
+                k_ptrs,
+                mask = (offs_n[:, None] >= 0),
+                other = 0.
+            )
         else:
-            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            k = tl.load(
+                k_ptrs,
+                mask = (offs_n[:, None] >= 0) & (offs_d[None, :] < headdim),
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
             k = tl.load(
                 k_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
+                mask = (offs_n[:, None] >= 0) & (offs_n[:, None] < seqlen_k),
                 other = 0.0,
             )
         else:
             k = tl.load(
                 k_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                mask = (offs_n[:, None] >= 0) & (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                 other = 0.0,
             )
 
@@ -1229,19 +1237,36 @@ def backward_kernel_one_col_block_causal(
     # if we just call tl.load(k_ptrs), we get the wrong output!
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
-            k = tl.load(k_ptrs)
+            k = tl.load(
+                k_ptrs,
+                mask = (offs_n[:, None] >= 0),
+                other = 0.
+            )
             v = tl.load(v_ptrs)
         else:
-            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            k = tl.load(
+                k_ptrs,
+                mask = (offs_n[:, None] >= 0) & (offs_d[None, :] < headdim),
+                other = 0.0
+            )
+
             v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
     else:
         if EVEN_HEADDIM:
-            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
+            k = tl.load(
+                k_ptrs,
+                mask = (offs_n[:, None] >= 0) & (offs_n[:, None] < seqlen_k),
+                other = 0.0
+            )
+
             v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
         else:
             k = tl.load(
-                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
+                k_ptrs,
+                mask= (offs_n[:, None] >= 0) & (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0
             )
+
             v = tl.load(
                 v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
             )
@@ -1273,7 +1298,7 @@ def backward_kernel_one_col_block_causal(
 
     if BLOCK != SEL_BLOCK:
         block_diagonal_mask = (
-            (offs_n[None, :] >= 0.) &
+            (offs_n[None, :] >= 0) &
             ((offs_n[None, :] // SEL_BLOCK) == (offs_m[:, None] // SEL_BLOCK))
        )
 

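As an aside, the last hunk above only switches the sentinel comparison from >= 0. to >= 0 (an integer rather than a float literal). Here is a small, hypothetical PyTorch sketch of how that block-diagonal mask combines the non-negativity guard with the same-block check (shapes and values are made up for illustration):

import torch

# Sketch of block_diagonal_mask from the hunk above, with toy values.
# A (query, key) pair survives only if the key offset is non-negative
# (i.e. belongs to a selected block) and both positions fall inside the
# same SEL_BLOCK-sized block.

SEL_BLOCK = 4
offs_m = torch.arange(8)                                 # query positions
offs_n = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1])      # key offsets, -1 = unselected

block_diagonal_mask = (
    (offs_n[None, :] >= 0) &
    ((offs_n[None, :] // SEL_BLOCK) == (offs_m[:, None] // SEL_BLOCK))
)
# block_diagonal_mask has shape (8, 8): queries 0..3 pair with keys 0..3,
# while every position holding a -1 sentinel is masked out.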
test_triton_nsa.py

Lines changed: 2 additions & 2 deletions
@@ -136,10 +136,10 @@ def regular_attend(
 
 batch = 4
 seq_len = 507
-q_heads = 4
+q_heads = 8
 kv_heads = 2
 fine_block_size = 32
-num_sel = 6
+num_sel = 2
 dim_head = 64
 fused_sliding_window = False
 block_dk_dv_use_dot = False # need sufficient shared memory, A100 works
