Commit afc0e85

keep chipping away at it
1 parent 2c3ad83 commit afc0e85

3 files changed: +66 −30 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 57 additions & 23 deletions
@@ -742,7 +742,9 @@ def backward_kernel_one_col_block_sparse(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     RETURN_SEL_GRADS: tl.constexpr,
-    OFF_SEL_KV_BLOCKS: tl.constexpr
+    OFF_SEL_KV_BLOCKS: tl.constexpr,
+    BLOCK_DV_USE_DOT: tl.constexpr,
+    BLOCK_DK_USE_DOT: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
 
@@ -918,23 +920,33 @@ def backward_kernel_one_col_block_sparse(
 
     p = tl.exp(masked_qk * softmax_scale - lse_i[:, :, None])
 
+    # prepare do
+
+    do_expanded = tl.expand_dims(do, 2)
+    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    do_expanded = do_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
     # take care of block dv
 
-    block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]
+    if not BLOCK_DV_USE_DOT:
+        p_expanded = p.to(do.dtype)
+        p_expanded = tl.expand_dims(p_expanded, 2)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+        p_expanded = p_expanded.reshape(BLOCK, QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM, BLOCK)
+        p_expanded = tl.permute(p_expanded, (0, 2, 1))
 
-    block_dv = tl.sum(block_dv, 1)
+        block_dv = tl.dot(p_expanded, do_expanded) / QUERY_EXPAND_DIM
+    else:
+        block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]
+        block_dv = tl.sum(block_dv, 1)
 
     tl.atomic_add(block_dv_ptrs, block_dv, mask = block_masks[:, None, None], sem = 'relaxed')
 
     # get dp
 
-    do = tl.expand_dims(do, 2)
-    do = tl.broadcast_to(do, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
-    do = do.reshape(BLOCK, 16, BLOCK_HEADDIM)
-
     block_v = tl.permute(block_v, (0, 2, 1))
 
-    dp = tl.dot(do, block_v)
+    dp = tl.dot(do_expanded, block_v)
 
     dp = dp.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     dp = tl.sum(dp, 2) / QUERY_EXPAND_DIM
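Context for the hunk above: the dot-product formulation of block_dv computes the same quantity as the broadcast-multiply-and-sum it sits alongside. The head-group dimension is repeated QUERY_EXPAND_DIM times so the contraction width reaches 16 (the minimum tl.dot likes), and the repetition is undone by dividing by QUERY_EXPAND_DIM; the block_dk hunk further down uses the same identity with ds and q. A minimal PyTorch sketch of that identity, outside Triton, with illustrative shapes (not the kernel itself):

import torch

BLOCK, G, E, D = 32, 4, 4, 64   # queries per block, query head groups, expand dim (G * E = 16), head dim

p  = torch.randn(BLOCK, G, BLOCK)   # attention probabilities per (query, head group, key)
do = torch.randn(BLOCK, G, D)       # output gradient per (query, head group, feature)

# reference path: broadcast multiply, then reduce over the head-group dimension
ref = (p[:, :, :, None] * do[:, :, None, :]).sum(dim = 1)                 # (BLOCK, BLOCK, D)

# dot path: repeat the head-group dim E times so the contraction dim is G * E = 16,
# batch-matmul, then divide by E to undo the repetition
p_exp  = p[:, :, None, :].expand(BLOCK, G, E, BLOCK).reshape(BLOCK, G * E, BLOCK).transpose(1, 2)
do_exp = do[:, :, None, :].expand(BLOCK, G, E, D).reshape(BLOCK, G * E, D)
dot    = torch.bmm(p_exp, do_exp) / E                                     # (BLOCK, BLOCK, D)

assert torch.allclose(ref, dot, atol = 1e-5)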
@@ -963,10 +975,23 @@ def backward_kernel_one_col_block_sparse(
         sem = 'relaxed'
     )
 
+    # ds
+
+    ds_expanded = tl.expand_dims(ds, 2)
+    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+    ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
+
     # block dk
 
-    block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
-    block_dk = tl.sum(block_dk, 1)
+    if BLOCK_DK_USE_DOT:
+        q_expanded = tl.expand_dims(q, 2)
+        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+        q_expanded = q_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
+        block_dk = tl.dot(tl.permute(ds_expanded, (0, 2, 1)), q_expanded.to(ds.dtype)) / QUERY_EXPAND_DIM
+    else:
+        block_dk = ds[:, :, :, None] * q[:, :, None, :].to(ds.dtype)
+        block_dk = tl.sum(block_dk, 1)
 
     tl.atomic_add(
         block_dk_ptrs,
@@ -977,11 +1002,7 @@ def backward_kernel_one_col_block_sparse(
 
     # block dq
 
-    ds_expanded = tl.expand_dims(ds, 2)
-    ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
-    ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)
     ds_expanded = ds_expanded.to(block_k.dtype)
-
     block_dq = tl.dot(ds_expanded, block_k)
 
     block_dq = block_dq.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
@@ -1348,6 +1369,8 @@ def backward_kernel(
     RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr,
+    BLOCK_DV_USE_DOT: tl.constexpr,
+    BLOCK_DK_USE_DOT: tl.constexpr,
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
@@ -1467,7 +1490,9 @@ def backward_kernel(
         QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
         RETURN_SEL_GRADS = RETURN_SEL_GRADS,
-        OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
+        OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
+        BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
+        BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
     )
 
 def native_sparse_attn_backward(
@@ -1482,7 +1507,8 @@ def native_sparse_attn_backward(
     block_size = 128,
     include_block_causal = True,
     return_sel_grads = False,
-    sliding = False
+    sliding = False,
+    block_dk_dv_use_dot = None
 ):
     device = do.device
 
@@ -1596,7 +1622,9 @@ def native_sparse_attn_backward(
         EVEN_HEADDIM = BLOCK_HEADDIM == dim,
         RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
-        SLIDING = sliding
+        SLIDING = sliding,
+        BLOCK_DV_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1),
+        BLOCK_DK_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
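For reference, the hunk above makes the dot path the default whenever query heads are grouped. A hedged sketch of how the new keyword resolves, assuming the repo's usual default(val, d) helper (which returns val unless it is None); the concrete values are illustrative:

def default(val, d):
    return val if val is not None else d

head_groups = 4                    # query heads per kv head
block_dk_dv_use_dot = None         # new keyword; None means "decide automatically"

BLOCK_DV_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)   # True here
BLOCK_DK_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)   # True here
# passing an explicit True / False (as the updated test does) overrides the heuristic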
@@ -1619,7 +1647,8 @@ def forward(
         selected_block_indices,
         fmask,
         sel_scale,
-        include_block_causal
+        include_block_causal,
+        block_dk_dv_use_dot
     ):
         dtype = fq.dtype
 
@@ -1649,6 +1678,7 @@ def forward(
             head_groups,
             return_sel_grads,
             include_block_causal,
+            block_dk_dv_use_dot
         )
 
         return out.type(dtype), lse
@@ -1663,7 +1693,8 @@ def backward(self, ctx, do, _):
             block_size,
             head_groups,
             return_sel_grads,
-            include_block_causal
+            include_block_causal,
+            block_dk_dv_use_dot
         ) = ctx._saved_variables
 
         do = do.half()
@@ -1679,15 +1710,16 @@ def backward(self, ctx, do, _):
             out, lse, dq, dk, dv,
             block_size = block_size,
             include_block_causal = include_block_causal,
-            return_sel_grads = return_sel_grads
+            return_sel_grads = return_sel_grads,
+            block_dk_dv_use_dot = block_dk_dv_use_dot
         )
 
         ret_sel_grads = None
 
         if return_sel_grads:
             ret_sel_grads = sel_grads
 
-        return dq, dk, dv, None, None, None, ret_sel_grads, None
+        return dq, dk, dv, None, None, None, ret_sel_grads, None, None
 
 _native_sparse_attend = NSA.apply
 
@@ -1709,7 +1741,8 @@ def native_sparse_attend(
     fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
     sel_scale: Float['b kh n sel'] | Float['b qh n sel'] | None = None,
     include_block_causal = True,
-    return_lse = False
+    return_lse = False,
+    block_dk_dv_use_dot = False
 ):
     seq_len = fq.shape[-2]
     q_heads, kv_heads, sel_heads = fq.shape[1], fk.shape[1], selected_block_indices.shape[1]
@@ -1730,7 +1763,8 @@ def native_sparse_attend(
         selected_block_indices,
         fmask,
         sel_scale,
-        include_block_causal
+        include_block_causal,
+        block_dk_dv_use_dot
     )
 
     if not return_lse:
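Taken together, the public entry point now accepts the flag end to end. A hedged usage sketch mirroring the updated test below; shapes, dtype handling, and the module import path are assumptions, and a CUDA device with Triton is required:

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

batch, q_heads, kv_heads, seq_len, dim_head = 1, 4, 2, 512, 64
fine_block_size, num_sel = 16, 6

# half-precision inputs, since the backward path casts gradients to half
fq = torch.randn(batch, q_heads, seq_len, dim_head).cuda().half().requires_grad_()
fk = torch.randn(batch, kv_heads, seq_len, dim_head).cuda().half().requires_grad_()
fv = torch.randn(batch, kv_heads, seq_len, dim_head).cuda().half().requires_grad_()

# per-query selected block indices and their mask, as in the test
indices = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).cuda()
mask = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).bool().cuda()

# new in this commit: opt into the tl.dot path for the block dk / dv backward
out, lse = native_sparse_attend(fq, fk, fv, fine_block_size, indices, mask, return_lse = True, block_dk_dv_use_dot = True)
out.sum().backward()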

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.2"
+version = "0.1.4"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 8 additions & 6 deletions
@@ -140,11 +140,13 @@ def regular_attend(
 kv_heads = 2
 fine_block_size = 16
 num_sel = 6
+dim_head = 64
 fused_sliding_window = False
+block_dk_dv_use_dot = False
 
-q = torch.randn(batch, q_heads, seq_len, 64).cuda()
-k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
-v = torch.randn(batch, kv_heads, seq_len, 64).cuda()
+q = torch.randn(batch, q_heads, seq_len, dim_head).cuda()
+k = torch.randn(batch, kv_heads, seq_len, dim_head).cuda()
+v = torch.randn(batch, kv_heads, seq_len, dim_head).cuda()
 
 indices = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).cuda()
 mask = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).bool().cuda()
@@ -166,17 +168,17 @@ def regular_attend(
 
 # triton nsa forwards and backwards
 
-nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = True)
+nsa_out, nlse = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, sel_scale = nsel_scale, return_lse = True, block_dk_dv_use_dot = block_dk_dv_use_dot)
 nsa_out.sum().backward()
 
 # asserts
 
 assert torch.allclose(out, nsa_out, atol = 1e-2)
-assert torch.allclose(rlse, nlse, atol = 1e-2)
+assert torch.allclose(rlse, nlse, atol = 2e-2)
 
 assert torch.allclose(rsel_scale.grad, nsel_scale.grad, atol = 1e-2)
 assert torch.allclose(nv.grad, rv.grad, atol = 1e-2)
 assert torch.allclose(nq.grad, rq.grad, atol = 1e-2)
-assert torch.allclose(nk.grad, rk.grad, atol = 1e-2)
+assert torch.allclose(nk.grad, rk.grad, atol = 2e-2)
 
 print('✅ outputs and gradients are same between pytorch native sparse attn and triton native sparse attn')
