
Commit 8a75a40

seq parallel for backwards nsa
1 parent 922a633 commit 8a75a40

File tree

2 files changed: 85 additions & 86 deletions


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 84 additions & 85 deletions
@@ -1526,93 +1526,91 @@ def backward_kernel(
         off_qh * seqlen_q_rounded
     )
 
-    num_block_n = tl.cdiv(seqlen_k, BLOCK)
+    start_n = tl.program_id(2)
 
     if IS_CAUSAL:
-        for start_n in range(0, num_block_n):
-            backward_kernel_one_col_block_causal(
-                start_n,
-                Q,
-                K,
-                V,
-                kv_block_indices,
-                kv_block_mask,
-                DO,
-                DQ,
-                DK,
-                DV,
-                LSE,
-                D,
-                softmax_scale,
-                stride_qm,
-                stride_kn,
-                stride_vn,
-                stride_dom,
-                stride_dqm,
-                stride_dkn,
-                stride_dvn,
-                stride_kvbl_m,
-                stride_qh,
-                stride_doh,
-                stride_dqh,
-                seqlen_q,
-                seqlen_k,
-                seqlen_q_rounded,
-                headdim,
-                BLOCK_HEADDIM = BLOCK_HEADDIM,
-                EVEN_M = EVEN_M,
-                EVEN_N = EVEN_N,
-                EVEN_HEADDIM = EVEN_HEADDIM,
-                BLOCK = BLOCK,
-                SEL_BLOCK = SEL_BLOCK,
-                QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
-                QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-                SLIDING = SLIDING
-            )
+        backward_kernel_one_col_block_causal(
+            start_n,
+            Q,
+            K,
+            V,
+            kv_block_indices,
+            kv_block_mask,
+            DO,
+            DQ,
+            DK,
+            DV,
+            LSE,
+            D,
+            softmax_scale,
+            stride_qm,
+            stride_kn,
+            stride_vn,
+            stride_dom,
+            stride_dqm,
+            stride_dkn,
+            stride_dvn,
+            stride_kvbl_m,
+            stride_qh,
+            stride_doh,
+            stride_dqh,
+            seqlen_q,
+            seqlen_k,
+            seqlen_q_rounded,
+            headdim,
+            BLOCK_HEADDIM = BLOCK_HEADDIM,
+            EVEN_M = EVEN_M,
+            EVEN_N = EVEN_N,
+            EVEN_HEADDIM = EVEN_HEADDIM,
+            BLOCK = BLOCK,
+            SEL_BLOCK = SEL_BLOCK,
+            QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
+            QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+            SLIDING = SLIDING
+        )
     else:
-        for start_n in range(0, num_block_n):
-            backward_kernel_one_col_block_sparse(
-                start_n,
-                Q,
-                K,
-                V,
-                kv_block_indices,
-                kv_block_mask,
-                kv_block_grads,
-                DO,
-                DQ,
-                DK,
-                DV,
-                LSE,
-                D,
-                softmax_scale,
-                stride_qm,
-                stride_kn,
-                stride_vn,
-                stride_dom,
-                stride_dqm,
-                stride_dkn,
-                stride_dvn,
-                stride_kvbl_m,
-                stride_qh,
-                stride_doh,
-                stride_dqh,
-                seqlen_q,
-                seqlen_k,
-                seqlen_q_rounded,
-                headdim,
-                BLOCK_HEADDIM = BLOCK_HEADDIM,
-                EVEN_M = EVEN_M,
-                EVEN_N = EVEN_N,
-                EVEN_HEADDIM = EVEN_HEADDIM,
-                BLOCK = BLOCK,
-                QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
-                QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
-                RETURN_SEL_GRADS = RETURN_SEL_GRADS,
-                OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
-                BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
-                BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
-            )
+        backward_kernel_one_col_block_sparse(
+            start_n,
+            Q,
+            K,
+            V,
+            kv_block_indices,
+            kv_block_mask,
+            kv_block_grads,
+            DO,
+            DQ,
+            DK,
+            DV,
+            LSE,
+            D,
+            softmax_scale,
+            stride_qm,
+            stride_kn,
+            stride_vn,
+            stride_dom,
+            stride_dqm,
+            stride_dkn,
+            stride_dvn,
+            stride_kvbl_m,
+            stride_qh,
+            stride_doh,
+            stride_dqh,
+            seqlen_q,
+            seqlen_k,
+            seqlen_q_rounded,
+            headdim,
+            BLOCK_HEADDIM = BLOCK_HEADDIM,
+            EVEN_M = EVEN_M,
+            EVEN_N = EVEN_N,
+            EVEN_HEADDIM = EVEN_HEADDIM,
+            BLOCK = BLOCK,
+            QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
+            QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+            RETURN_SEL_GRADS = RETURN_SEL_GRADS,
+            OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
+            BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
+            BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
+        )
 
 def native_sparse_attn_backward(
     do,
@@ -1692,7 +1690,8 @@ def native_sparse_attn_backward(
 
     grid = lambda META: (
        num_sel_fine_blocks + int(include_block_causal),
-        batch * kv_heads
+        batch * kv_heads,
+        triton.cdiv(seqlen_k, META['BLOCK'])
    )
 
    backward_kernel[grid](
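
Note (not part of the commit): the diff above drops the in-kernel loop over key/value column blocks and instead launches one program per column block, with tl.program_id(2) selecting the block via the new third grid dimension. Below is a minimal, hypothetical Triton sketch of that same pattern, standalone and not code from this repository; the kernel name col_block_kernel and the doubling operation are made up purely for illustration.

import torch
import triton
import triton.language as tl

@triton.jit
def col_block_kernel(x_ptr, out_ptr, seqlen_k, BLOCK: tl.constexpr):
    # grid axis 2 selects the column block, replacing an in-kernel
    # `for start_n in range(tl.cdiv(seqlen_k, BLOCK))` loop
    start_n = tl.program_id(2)
    offs_n = start_n * BLOCK + tl.arange(0, BLOCK)
    mask = offs_n < seqlen_k
    x = tl.load(x_ptr + offs_n, mask = mask)
    tl.store(out_ptr + offs_n, x * 2., mask = mask)

seqlen_k = 1000
x = torch.randn(seqlen_k, device = 'cuda')
out = torch.empty_like(x)

# mirrors the commit's grid change: the third dimension spans all column blocks
grid = lambda META: (1, 1, triton.cdiv(seqlen_k, META['BLOCK']))
col_block_kernel[grid](x, out, seqlen_k, BLOCK = 128)

assert torch.allclose(out, x * 2)

Launching a program per column block trades the serial loop for parallelism across the key/value sequence dimension, which is presumably what the commit title ("seq parallel for backwards nsa") refers to.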

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.20"
+version = "0.1.21"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
