
Commit da547e7

move query head group dimension to the right by one in forward triton kernel
1 parent 911baa5 commit da547e7
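
Note on the change: the per-program tiles in the forward kernel used to lead with the grouped-query-head axis, shaped (QUERY_HEAD_GROUPS, BLOCK, ...); after this commit they lead with the query-position axis, shaped (BLOCK, QUERY_HEAD_GROUPS, ...). A rough PyTorch sketch of what that does to the pointer offsets, with made-up strides and toy sizes (not code from the kernel):

    import torch

    BLOCK, GROUPS, DIM = 4, 2, 8           # toy tile sizes
    stride_qh, stride_qm = 1000, 10        # made-up element strides

    offs_qh = torch.arange(GROUPS)
    offs_m = torch.arange(BLOCK)
    offs_d = torch.arange(DIM)

    # before: group axis leads -> offset grid shaped (GROUPS, BLOCK, DIM)
    old = offs_qh[:, None, None] * stride_qh + offs_m[None, :, None] * stride_qm + offs_d[None, None, :]

    # after: query-position axis leads -> offset grid shaped (BLOCK, GROUPS, DIM)
    new = offs_qh[None, :, None] * stride_qh + offs_m[:, None, None] * stride_qm + offs_d[None, None, :]

    # same addresses, just transposed: (BLOCK, GROUPS, DIM) instead of (GROUPS, BLOCK, DIM)
    assert torch.equal(new, old.permute(1, 0, 2))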

File tree: 3 files changed (+93, -62 lines)

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 90 additions & 61 deletions
@@ -129,8 +129,8 @@ def forward_kernel(
     q_ptrs = (
         Q +
         off_b * stride_qb +
-        offs_qh[:, None, None] * stride_qh +
-        offs_m[None, :, None] * stride_qm +
+        offs_qh[None, :, None] * stride_qh +
+        offs_m[:, None, None] * stride_qm +
         offs_d[None, None, :]
     )

@@ -152,48 +152,56 @@ def forward_kernel(

     # maximum

-    m_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # lse

     lse_ptrs = (
         Lse +
         off_b * stride_lse_b +
-        offs_qh[:, None] * seqlen_q_rounded +
-        offs_m[None, :]
+        offs_qh[None, :] * seqlen_q_rounded +
+        offs_m[:, None]
     )

-    lse_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # output

     out_ptrs = (
         Out +
         off_b * stride_ob +
-        offs_qh[:, None, None] * stride_oh +
-        offs_m[None, :, None] * stride_om +
+        offs_qh[None, :, None] * stride_oh +
+        offs_m[:, None, None] * stride_om +
         offs_d[None, None, :]
     )

-    acc_o = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM], dtype = tl.float32)

     # load queries, keys, values

     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
         else:
-            q = tl.load(q_ptrs, mask=offs_d[None, None, :] < headdim, other=0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_d[None, None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask=offs_m[None, :, None] < seqlen_q, other=0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_m[:, None, None] < seqlen_q,
+                other = 0.0
+            )
         else:
             q = tl.load(
-                q_ptrs, mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim), other=0.0
+                q_ptrs,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                other = 0.0
             )

-    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
-
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
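
A side note on the masked loads above (a hedged PyTorch analogue with invented sizes, not the kernel itself): with the query-position axis now leading, the sequence mask broadcasts as offs_m[:, None, None], and any row past seqlen_q or head dimension past headdim reads back as the other= value, 0.0:

    import torch

    BLOCK, GROUPS, DIM = 8, 2, 16          # toy tile sizes
    seqlen_q, headdim = 6, 12              # real extents being masked against

    offs_m = torch.arange(BLOCK)
    offs_d = torch.arange(DIM)
    q_full = torch.randn(BLOCK, GROUPS, DIM)    # what the pointers would cover if fully in range

    mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
    q = torch.where(mask, q_full, torch.tensor(0.0))    # out-of-range positions become 0.0, as in tl.load(..., other = 0.0)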
@@ -203,65 +211,75 @@ def forward_kernel(
         if EVEN_HEADDIM:
             k = tl.load(
                 k_ptrs,
-                mask=offs_n[:, None] < seqlen_k,
-                other=0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             k = tl.load(
                 k_ptrs,
-                mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other=0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    qk = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK], dtype=tl.float32)
+    qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype=tl.float32)
+
+    q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+
     qk += tl.dot(q, tl.trans(k))

+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+
     if not EVEN_N:
         qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))

-    qk = qk.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK])
-
-    qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)

-    qk = qk.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK])
+    qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))

-    m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, None])
+    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+    p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-    l_ij = tl.sum(p, 1)
+    l_ij = tl.sum(p, 2)

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             v = tl.load(v_ptrs)
         else:
-            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            v = tl.load(
+                v_ptrs,
+                mask = offs_d[None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
             v = tl.load(
                 v_ptrs,
-                mask=offs_n[:, None] < seqlen_k,
-                other=0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             v = tl.load(
                 v_ptrs,
-                mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other=0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    p = p.to(v.dtype)
-    acc_o += tl.dot(p, v)
+    p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
+
+    causal_o = tl.dot(p, v)
+
+    acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     # -- update statistics

     m_i = m_ij
     l_i_new = tl.exp(lse_i - m_ij) + l_ij
     lse_i = m_ij + tl.log(l_i_new)

-    # take care of the selected kv blocks
+    # # take care of the selected kv blocks

     kv_block_indices_ptrs = (
         kv_block_indices +
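
The block above is the usual streaming-softmax update, now carried out with the key axis last. A plain PyTorch sketch of one such step, with assumed shapes and names, for illustration only:

    import torch

    def online_softmax_step(qk, v, m_i, lse_i, acc_o, scale):
        # qk:   (BLOCK, GROUPS, BLOCK_K) similarities against the current chunk of keys
        # v:    (BLOCK_K, DIM) values for that chunk
        # m_i, lse_i: (BLOCK, GROUPS) running max and running log-sum-exp
        # acc_o: (BLOCK, GROUPS, DIM) running, still-unnormalized output
        m_ij = torch.maximum(qk.amax(dim = 2) * scale, lse_i)
        p = torch.exp(qk * scale - m_ij[:, :, None])
        l_ij = p.sum(dim = 2)

        acc_o = acc_o * torch.exp(m_i - m_ij)[:, :, None]     # rescale what was accumulated so far
        acc_o = acc_o + torch.einsum('bgk,kd->bgd', p, v)     # add this chunk's contribution

        m_i = m_ij
        lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)
        return m_i, lse_i, acc_o

    # the kernel's final normalization corresponds to: out = acc_o * torch.exp(m_i - lse_i)[:, :, None]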
@@ -277,8 +295,7 @@ def forward_kernel(
         offs_m * stride_kvbl_m
     )

-    q = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    q = q.permute((1, 0, 2))
+    q = q.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
     q = tl.expand_dims(q, 2)
     q = tl.broadcast_to(q, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
@@ -290,11 +307,19 @@ def forward_kernel(
         blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]

         block_k_ptrs = (
-            K + off_b * stride_kb + off_h * stride_kh + (blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :])
+            K +
+            off_b * stride_kb +
+            off_h * stride_kh +
+            blocks_offs_n[:, :, None] * stride_kn +
+            offs_d[None, None, :]
         )

         block_v_ptrs = (
-            V + off_b * stride_vb + off_h * stride_vh + (blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :])
+            V +
+            off_b * stride_vb +
+            off_h * stride_vh +
+            blocks_offs_n[:, :, None] * stride_vn +
+            offs_d[None, None, :]
         )

         # load k of shape (m, n, d), sparsely selected by each query
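
The pointer arithmetic above amounts to a per-row gather: every one of the BLOCK query rows carries its own selected kv block index, and blocks_offs_n turns that index into the BLOCK key positions inside that block. A toy PyTorch equivalent with invented sizes:

    import torch

    BLOCK, NUM_BLOCKS, DIM = 4, 6, 8
    k = torch.randn(NUM_BLOCKS * BLOCK, DIM)                  # keys of one (batch, kv head), flattened over the sequence
    block_indices = torch.randint(0, NUM_BLOCKS, (BLOCK,))    # one selected block per query row

    blocks_offs_n = block_indices[:, None] * BLOCK + torch.arange(BLOCK)[None, :]   # (BLOCK, BLOCK) key positions
    k_block = k[blocks_offs_n]                                # (BLOCK, BLOCK, DIM): row m gets its own block of keys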
@@ -304,50 +329,44 @@ def forward_kernel(
         # similarities

         block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-        qk = tl.zeros([QUERY_HEAD_GROUPS, BLOCK, BLOCK], dtype = tl.float32)
+        qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)

         k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
         k_block = k_block.permute(0, 2, 1)

-        block_qk = tl.dot(q, k_block)
+        block_qk += tl.dot(q, k_block)
         block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
         block_qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
-        block_qk = block_qk.permute(1, 0, 2)

         qk += block_qk
-        qk += tl.where(block_masks[:, None], 0, float("-inf"))
-
-        qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+        qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

         # attention

-        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-        p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+        block_p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-        l_ij = tl.sum(p, 1)
+        l_ij = tl.sum(block_p, 2)

         # renormalize the running output

         acc_o_scale = tl.exp(m_i - m_ij)
-        acc_o = acc_o * acc_o_scale[:, None]
+        acc_o = acc_o * acc_o_scale[:, :, None]

         # aggregate values

         v_block = tl.load(block_v_ptrs)
         v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

-        p = p.to(v_block.dtype)
-        p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-        p_expanded = p_expanded.permute(1, 0, 2)
+        block_p = block_p.to(v_block.dtype)
+        p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
         p_expanded = tl.expand_dims(p_expanded, 2)
         p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
         p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

         block_acc_o = tl.dot(p_expanded, v_block)
         block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
         block_acc_o = tl.sum(block_acc_o, 2) / QUERY_EXPAND_DIM
-        block_acc_o = block_acc_o.permute(1, 0, 2)
-        block_acc_o = block_acc_o.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

         acc_o += block_acc_o
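
The QUERY_EXPAND_DIM broadcasting above is presumably there because tl.dot needs each matmul dimension to be at least 16: every grouped query is repeated QUERY_EXPAND_DIM times so that QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM == 16, one dot runs over the padded axis, and the duplicated results are averaged back out. A toy PyTorch check with made-up sizes:

    import torch

    BLOCK, GROUPS, EXPAND, DIM = 4, 2, 8, 16   # GROUPS * EXPAND == 16
    q = torch.randn(BLOCK, GROUPS, DIM)
    k = torch.randn(BLOCK, DIM, BLOCK)         # per-row selected key block, already transposed

    # direct per-group similarities
    direct = torch.einsum('bgd,bdn->bgn', q, k)

    # expand each group's query, run one batched matmul on the padded 16-wide axis,
    # then fold the copies back by summing and dividing
    q_exp = q[:, :, None, :].expand(BLOCK, GROUPS, EXPAND, DIM).reshape(BLOCK, GROUPS * EXPAND, DIM)
    scores = torch.bmm(q_exp, k).reshape(BLOCK, GROUPS, EXPAND, BLOCK).sum(dim = 2) / EXPAND

    assert torch.allclose(direct, scores, atol = 1e-5)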

@@ -360,28 +379,38 @@ def forward_kernel(
     # normalize accumulated out

     acc_o_scale = tl.exp(m_i - lse_i)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     # write back lse

-    lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
-    tl.store(lse_ptrs, lse_i, mask = offs_m[None, :] < seqlen_q)
+    lse_i = lse_i.reshape(BLOCK, QUERY_HEAD_GROUPS)
+    tl.store(lse_ptrs, lse_i, mask = offs_m[:, None] < seqlen_q)

     # write to output

-    acc_o = acc_o.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM])
+    acc_o = acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
         else:
-            tl.store(out_ptrs, acc_o, mask=offs_d[None, None, :] < headdim)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_d[None, None, :] < headdim
+            )
     else:
         if EVEN_HEADDIM:
-            tl.store(out_ptrs, acc_o, mask=offs_m[None, :, None] < seqlen_q)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_m[:, None, None] < seqlen_q
+            )
         else:
             tl.store(
-                out_ptrs, acc_o, mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
+                out_ptrs,
+                acc_o,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
             )

 def native_sparse_attn_forward(

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.57"
+version = "0.0.58"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 2 additions & 0 deletions
@@ -137,3 +137,5 @@ def regular_attend(
 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
 assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
+
+print('✅ outputs and gradients are same between pytorch native sparse attn and triton native sparse attn')

0 commit comments
