
Commit db38d04

move grouped query head dimension to the right by one and cleanup backwards for sparse kv blocks
1 parent 599f312 commit db38d04

3 files changed, +60 -114 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 57 additions & 111 deletions
@@ -587,7 +587,7 @@ def backward_kernel_one_col_block_sparse(
     QUERY_EXPAND_DIM: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr
 ):
-    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
+    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)

     begin_m = ((start_n * BLOCK) // BLOCK) * BLOCK

@@ -609,73 +609,32 @@ def backward_kernel_one_col_block_sparse(

     q_ptrs = (
         Q +
-        offs_g[:, None, None] * stride_qh +
-        offs_qm[None, :, None] * stride_qm +
+        offs_g[None, :, None] * stride_qh +
+        offs_qm[:, None, None] * stride_qm +
         offs_d[None, None, :]
     )

     do_ptrs = (
         DO +
-        offs_g[:, None, None] * stride_doh +
-        offs_qm[None, :, None] * stride_dom +
+        offs_g[None, :, None] * stride_doh +
+        offs_qm[:, None, None] * stride_dom +
         offs_d[None, None, :]
     )

     dq_ptrs = (
         DQ +
-        offs_g[:, None, None] * stride_dqh +
-        offs_qm[None, :, None] * stride_dqm +
+        offs_g[None, :, None] * stride_dqh +
+        offs_qm[:, None, None] * stride_dqm +
         offs_d[None, None, :]
     )

-    # initialize dv and dk
-
-    dv = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype=tl.float32)
-    dk = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype=tl.float32)
-
     # There seems to be some problem with Triton pipelining that makes results wrong for
     # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
     # may have zero step, and pipelining with the bias matrix could screw it up.
     # So we just exit early.

     if begin_m >= seqlen_q:
-        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
-        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-        backward_store_dk_dv(
-            dk_ptrs,
-            dv_ptrs,
-            dk,
-            dv,
-            offs_n,
-            offs_d,
-            seqlen_k,
-            headdim,
-            EVEN_M=EVEN_M,
-            EVEN_N=EVEN_N,
-            EVEN_HEADDIM=EVEN_HEADDIM,
-        )
         return
-    # k and v stay in SRAM throughout
-    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
-    # if we just call tl.load(k_ptrs), we get the wrong output!
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            k = tl.load(k_ptrs)
-            v = tl.load(v_ptrs)
-        else:
-            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-    else:
-        if EVEN_HEADDIM:
-            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
-            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
-        else:
-            k = tl.load(
-                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
-            )
-            v = tl.load(
-                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
-            )

     # same block for block causal diagonal

@@ -685,28 +644,17 @@ def backward_kernel_one_col_block_sparse(
         q = tl.load(q_ptrs)
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask=offs_m[None, :, None] < seqlen_q, other=0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_m[:, None, None] < seqlen_q,
+                other = 0.0
+            )
         else:
             q = tl.load(
                 q_ptrs,
-                mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
-                other=0.0,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                other = 0.0,
             )
-    # recompute p = softmax(qk, dim=-1).T
-
-    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
-
-    qk = tl.dot(q, tl.trans(k))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-
-    # Trying to combine the two masks seem to make the result wrong
-    if not EVEN_N: # Need to mask out otherwise the softmax is wrong
-        qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
-
-    qk = tl.where(offs_m[:, None] >= (offs_n[None, :]), qk, float("-inf"))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
     # Also wrong for headdim=64.
@@ -715,9 +663,7 @@ def backward_kernel_one_col_block_sparse(
         tl.debug_barrier()

     lse_i = tl.load(LSE + offs_d_or_lse)
-    lse_i = lse_i.reshape(QUERY_HEAD_GROUPS * BLOCK)
-
-    p = tl.exp(qk * softmax_scale - lse_i[:, None])
+    lse_i = tl.trans(lse_i) # (m, h)

     # compute dv
     # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
@@ -730,12 +676,10 @@ def backward_kernel_one_col_block_sparse(
     # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
     do = tl.load(
         do_ptrs,
-        mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
-        other=0.0,
+        mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+        other = 0.0,
     )

-    do = do.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)
-
     # compute dp = dot(v, do)
     # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
     # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
@@ -752,12 +696,12 @@ def backward_kernel_one_col_block_sparse(
     # Putting the subtraction after the dp matmul (instead of before) is slightly faster

     Di = tl.load(D + offs_d_or_lse)
-    Di = Di.reshape(QUERY_HEAD_GROUPS * BLOCK)
+    Di = tl.trans(Di) # (m, h)

     # Converting ds to q.dtype here reduces register pressure and makes it much faster
     # for BLOCK_HEADDIM=128

-    dq = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    dq = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM], dtype = tl.float32)

     # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor

@@ -774,30 +718,39 @@ def backward_kernel_one_col_block_sparse(
     block_indices = tl.load(kv_block_indices_ptrs + OFF_SEL_KV_BLOCKS)
     block_masks = tl.load(kv_block_mask_ptrs + OFF_SEL_KV_BLOCKS)

-    blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+    blocks_offs_n = (
+        block_indices[:, None] * BLOCK +
+        tl.arange(0, BLOCK)[None, :]
+    )

     block_k_ptrs = (
-        K + blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :]
+        K +
+        blocks_offs_n[:, :, None] * stride_kn +
+        offs_d[None, None, :]
     )

     block_v_ptrs = (
-        V + blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :]
+        V +
+        blocks_offs_n[:, :, None] * stride_vn +
+        offs_d[None, None, :]
     )

     block_dv_ptrs = (
-        DV + blocks_offs_n[:, :, None] * stride_dvn + offs_d[None, None, :]
+        DV +
+        blocks_offs_n[:, :, None] * stride_dvn +
+        offs_d[None, None, :]
     )

     block_dk_ptrs = (
-        DK + blocks_offs_n[:, :, None] * stride_dkn + offs_d[None, None, :]
+        DK +
+        blocks_offs_n[:, :, None] * stride_dkn +
+        offs_d[None, None, :]
     )

     block_k = tl.load(block_k_ptrs)
     block_v = tl.load(block_v_ptrs)

-    q_expanded = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    q_expanded = q_expanded.permute(1, 0, 2)
-    q_expanded = tl.expand_dims(q_expanded, 2)
+    q_expanded = tl.expand_dims(q, 2)
     q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
     q_expanded = q_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)

@@ -806,84 +759,77 @@ def backward_kernel_one_col_block_sparse(

     block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
-    qk = qk.permute(1, 0, 2)
-
-    qk += tl.where(block_masks[None, :, None], 0, float("-inf"))

-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+    qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

-    p = tl.exp(qk * softmax_scale - lse_i[:, None])
+    p = tl.exp(qk * softmax_scale - lse_i[:, :, None])

     # take care of block dv

-    block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
+    block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]

-    block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
-    block_dv = tl.sum(block_dv, 0)
+    # block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
+    block_dv = tl.sum(block_dv, 1)

     tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')

     # get dp

-    do_expanded = do.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    do_expanded = do_expanded.permute(1, 0, 2)
-    do_expanded = tl.expand_dims(do_expanded, 2)
-    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
-    do_expanded = do_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+    do = tl.expand_dims(do, 2)
+    do = tl.broadcast_to(do, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    do = do.reshape(BLOCK, 16, BLOCK_HEADDIM)

     block_v = tl.permute(block_v, (0, 2, 1))

-    dp = tl.dot(do_expanded, block_v)
+    dp = tl.dot(do, block_v)

     dp = dp.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     dp = tl.sum(dp, 2) / QUERY_EXPAND_DIM
-    dp = dp.permute(1, 0, 2)
-    dp = dp.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     # ds

-    ds = (p * (dp - Di[:, None]) * softmax_scale)
+    ds = (p * (dp - Di[:, :, None]) * softmax_scale)
     ds = ds.to(q.dtype)

     # block dk

-    block_dk = ds[:, :, None] * q[:, None, :]
-    block_dk = block_dk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
-    block_dk = tl.sum(block_dk, 0)
+    block_dk = ds[:, :, :, None] * q[:, :, None, :]
+    block_dk = tl.sum(block_dk, 1)

     tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')

     # block dq

-    ds_expanded = ds.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-    ds_expanded = ds_expanded.permute(1, 0, 2)
-    ds_expanded = tl.expand_dims(ds_expanded, 2)
+    ds_expanded = tl.expand_dims(ds, 2)
     ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
     ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)

     block_dq = tl.dot(ds_expanded, block_k)

     block_dq = block_dq.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
     block_dq = tl.sum(block_dq, 2) / QUERY_EXPAND_DIM
-    block_dq = block_dq.permute(1, 0, 2)
-    block_dq = block_dq.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

     dq += block_dq

     # update dq

-    dq = dq.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    dq = dq.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
         tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
     else:
         if EVEN_HEADDIM:
-            tl.atomic_add(dq_ptrs, dq, mask=offs_m[None, :, None] < seqlen_q, sem = 'relaxed')
+            tl.atomic_add(
+                dq_ptrs,
+                dq,
+                mask = offs_m[:, None, None] < seqlen_q,
+                sem = 'relaxed'
+            )
         else:
             tl.atomic_add(
                 dq_ptrs,
                 dq,
-                mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                 sem = 'relaxed',
             )

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.56"
+version = "0.0.57"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 2 additions & 2 deletions
@@ -103,9 +103,9 @@ def regular_attend(
 batch = 2
 seq_len = 511
 q_heads = 4
-kv_heads = 4
+kv_heads = 2
 fine_block_size = 16
-num_sel = 4
+num_sel = 6

 q = torch.randn(batch, q_heads, seq_len, 64).cuda()
 k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
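
With q_heads = 4 and kv_heads = 2 the test now exercises the grouped-query path (two query heads share each kv head), and num_sel = 6 selects more fine blocks per query. The sketch below is a rough illustration of what that head grouping means for a dense reference, assuming a generic repeat of kv heads across each query group; it is not the repository's regular_attend helper, just an illustration of the head layout.

import torch
import torch.nn.functional as F

# illustrative values mirroring the updated test configuration
batch, seq_len, dim_head = 2, 511, 64
q_heads, kv_heads = 4, 2
groups = q_heads // kv_heads  # 2 query heads per kv head

q = torch.randn(batch, q_heads, seq_len, dim_head)
k = torch.randn(batch, kv_heads, seq_len, dim_head)
v = torch.randn(batch, kv_heads, seq_len, dim_head)

# expand each kv head across its query group before a plain attention call
k_expanded = k.repeat_interleave(groups, dim = 1)  # (batch, q_heads, seq_len, dim_head)
v_expanded = v.repeat_interleave(groups, dim = 1)

out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal = True)
assert out.shape == (batch, q_heads, seq_len, dim_head)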
