Commit c395e2a

parallelize gathering of gradients for causal and fine selected kv blocks
1 parent 6c24fa8 commit c395e2a

File tree: 3 files changed (+48 -46 lines)

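Before this commit, one program per (batch, kv head) ran the causal column pass and then looped serially over every fine-selected KV block inside the kernel, under a launch grid of (1, batch * kv_heads). The diffs below move that serial loop onto the first grid axis, so the causal pass and the gradient gathering for each selected KV block each get their own program. A minimal pure-Python sketch of the new partitioning, with causal_col_pass and sparse_col_pass as hypothetical stand-ins for backward_kernel_one_col_block_causal and backward_kernel_one_col_block_sparse:

# Illustrative sketch only: models the launch grid (num_sel_fine_blocks + 1, batch * kv_heads)
# introduced by this commit. causal_col_pass / sparse_col_pass are hypothetical stand-ins.
def run_backward_grid(batch, kv_heads, num_sel_fine_blocks, causal_col_pass, sparse_col_pass):
    for pid0 in range(num_sel_fine_blocks + 1):       # tl.program_id(0)
        for pid1 in range(batch * kv_heads):          # tl.program_id(1)
            off_b, off_h = divmod(pid1, kv_heads)     # off_b = pid1 // kv_heads, off_h = pid1 % kv_heads
            if pid0 == 0:                             # IS_CAUSAL
                causal_col_pass(off_b, off_h)
            else:                                     # one program per selected kv block
                sparse_col_pass(off_b, off_h, off_sel_kv_block = pid0 - 1)  # OFF_SEL_KV_BLOCKS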

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 46 additions & 44 deletions
@@ -1208,13 +1208,15 @@ def backward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
-    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
     off_h = off_hb % kv_heads
     off_qh = off_h * QUERY_HEAD_GROUPS
 
+    IS_CAUSAL = tl.program_id(0) == 0
+    OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
+
     # offset pointers for batch/head
 
     Q += off_b * stride_qb + off_qh * stride_qh
@@ -1244,46 +1246,47 @@ def backward_kernel(
 
     num_block_n = tl.cdiv(seqlen_k, BLOCK)
 
-    for start_n in range(0, num_block_n):
-        backward_kernel_one_col_block_causal(
-            start_n,
-            Q,
-            K,
-            V,
-            kv_block_indices,
-            kv_block_mask,
-            DO,
-            DQ,
-            DK,
-            DV,
-            LSE,
-            D,
-            softmax_scale,
-            stride_qm,
-            stride_kn,
-            stride_vn,
-            stride_dom,
-            stride_dqm,
-            stride_dkn,
-            stride_dvn,
-            stride_kvbl_m,
-            stride_qh,
-            stride_doh,
-            stride_dqh,
-            seqlen_q,
-            seqlen_k,
-            seqlen_q_rounded,
-            headdim,
-            BLOCK_HEADDIM = BLOCK_HEADDIM,
-            EVEN_M = EVEN_M,
-            EVEN_N = EVEN_N,
-            EVEN_HEADDIM = EVEN_HEADDIM,
-            BLOCK = BLOCK,
-            QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
-            QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-        )
-
-        for off_sel_kv_blocks in range(NUM_SEL_KV_BLOCKS):
+    if IS_CAUSAL:
+        for start_n in range(0, num_block_n):
+            backward_kernel_one_col_block_causal(
+                start_n,
+                Q,
+                K,
+                V,
+                kv_block_indices,
+                kv_block_mask,
+                DO,
+                DQ,
+                DK,
+                DV,
+                LSE,
+                D,
+                softmax_scale,
+                stride_qm,
+                stride_kn,
+                stride_vn,
+                stride_dom,
+                stride_dqm,
+                stride_dkn,
+                stride_dvn,
+                stride_kvbl_m,
+                stride_qh,
+                stride_doh,
+                stride_dqh,
+                seqlen_q,
+                seqlen_k,
+                seqlen_q_rounded,
+                headdim,
+                BLOCK_HEADDIM = BLOCK_HEADDIM,
+                EVEN_M = EVEN_M,
+                EVEN_N = EVEN_N,
+                EVEN_HEADDIM = EVEN_HEADDIM,
+                BLOCK = BLOCK,
+                QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
+                QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+            )
+    else:
+        for start_n in range(0, num_block_n):
             backward_kernel_one_col_block_sparse(
                 start_n,
                 Q,
@@ -1320,7 +1323,7 @@ def backward_kernel(
                 BLOCK = BLOCK,
                 QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
                 QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-                OFF_SEL_KV_BLOCKS = off_sel_kv_blocks
+                OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
             )
 
 def native_sparse_attn_backward(
@@ -1383,7 +1386,7 @@ def native_sparse_attn_backward(
         BLOCK_HEADDIM = BLOCK_HEADDIM,
     )
 
-    grid = lambda META: (1, batch * kv_heads)
+    grid = lambda META: (num_sel_fine_blocks + 1, batch * kv_heads)
 
     backward_kernel[grid](
         q,
@@ -1437,7 +1440,6 @@ def native_sparse_attn_backward(
         BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
         QUERY_EXPAND_DIM = 16 // head_groups,
-        NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         EVEN_M = divisible_by(seqlen_q, block_size),
         EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim
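
Read together, the new grid and the program-id decoding at the top of backward_kernel define the split along grid axis 0, which is also why NUM_SEL_KV_BLOCKS no longer needs to be passed as a constexpr: the kernel no longer loops over the selected blocks itself. A small sketch of the decoding, using num_sel_fine_blocks = 3 purely as an example value:

num_sel_fine_blocks = 3  # example value only, not from the commit

for pid0 in range(num_sel_fine_blocks + 1):
    is_causal = pid0 == 0         # IS_CAUSAL in the kernel
    off_sel_kv_blocks = pid0 - 1  # OFF_SEL_KV_BLOCKS; unused when is_causal is True
    print(pid0, is_causal, off_sel_kv_blocks)

# prints:
# 0 True -1    -> causal column pass
# 1 False 0    -> gradients for selected kv block 0
# 2 False 1    -> gradients for selected kv block 1
# 3 False 2    -> gradients for selected kv block 2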

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.54"
+version = "0.0.55"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_triton_nsa.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def regular_attend(
 k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
 v = torch.randn(batch, kv_heads, seq_len, 64).cuda()
 
-indices = torch.zeros(batch, kv_heads, seq_len, num_sel).long().cuda()
+indices = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).cuda()
 mask = torch.randint(0, 2, (batch, kv_heads, seq_len, num_sel)).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
