@@ -1526,93 +1526,91 @@ def backward_kernel(
1526
1526
off_qh * seqlen_q_rounded
1527
1527
)
1528
1528
1529
- num_block_n = tl .cdiv ( seqlen_k , BLOCK )
1529
+ start_n = tl .program_id ( 2 )
1530
1530
1531
1531
if IS_CAUSAL :
1532
- for start_n in range (0 , num_block_n ):
1533
- backward_kernel_one_col_block_causal (
1534
- start_n ,
1535
- Q ,
1536
- K ,
1537
- V ,
1538
- kv_block_indices ,
1539
- kv_block_mask ,
1540
- DO ,
1541
- DQ ,
1542
- DK ,
1543
- DV ,
1544
- LSE ,
1545
- D ,
1546
- softmax_scale ,
1547
- stride_qm ,
1548
- stride_kn ,
1549
- stride_vn ,
1550
- stride_dom ,
1551
- stride_dqm ,
1552
- stride_dkn ,
1553
- stride_dvn ,
1554
- stride_kvbl_m ,
1555
- stride_qh ,
1556
- stride_doh ,
1557
- stride_dqh ,
1558
- seqlen_q ,
1559
- seqlen_k ,
1560
- seqlen_q_rounded ,
1561
- headdim ,
1562
- BLOCK_HEADDIM = BLOCK_HEADDIM ,
1563
- EVEN_M = EVEN_M ,
1564
- EVEN_N = EVEN_N ,
1565
- EVEN_HEADDIM = EVEN_HEADDIM ,
1566
- BLOCK = BLOCK ,
1567
- SEL_BLOCK = SEL_BLOCK ,
1568
- QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1569
- QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1570
- SLIDING = SLIDING
1571
- )
1532
+ backward_kernel_one_col_block_causal (
1533
+ start_n ,
1534
+ Q ,
1535
+ K ,
1536
+ V ,
1537
+ kv_block_indices ,
1538
+ kv_block_mask ,
1539
+ DO ,
1540
+ DQ ,
1541
+ DK ,
1542
+ DV ,
1543
+ LSE ,
1544
+ D ,
1545
+ softmax_scale ,
1546
+ stride_qm ,
1547
+ stride_kn ,
1548
+ stride_vn ,
1549
+ stride_dom ,
1550
+ stride_dqm ,
1551
+ stride_dkn ,
1552
+ stride_dvn ,
1553
+ stride_kvbl_m ,
1554
+ stride_qh ,
1555
+ stride_doh ,
1556
+ stride_dqh ,
1557
+ seqlen_q ,
1558
+ seqlen_k ,
1559
+ seqlen_q_rounded ,
1560
+ headdim ,
1561
+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
1562
+ EVEN_M = EVEN_M ,
1563
+ EVEN_N = EVEN_N ,
1564
+ EVEN_HEADDIM = EVEN_HEADDIM ,
1565
+ BLOCK = BLOCK ,
1566
+ SEL_BLOCK = SEL_BLOCK ,
1567
+ QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1568
+ QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1569
+ SLIDING = SLIDING
1570
+ )
1572
1571
else :
1573
- for start_n in range (0 , num_block_n ):
1574
- backward_kernel_one_col_block_sparse (
1575
- start_n ,
1576
- Q ,
1577
- K ,
1578
- V ,
1579
- kv_block_indices ,
1580
- kv_block_mask ,
1581
- kv_block_grads ,
1582
- DO ,
1583
- DQ ,
1584
- DK ,
1585
- DV ,
1586
- LSE ,
1587
- D ,
1588
- softmax_scale ,
1589
- stride_qm ,
1590
- stride_kn ,
1591
- stride_vn ,
1592
- stride_dom ,
1593
- stride_dqm ,
1594
- stride_dkn ,
1595
- stride_dvn ,
1596
- stride_kvbl_m ,
1597
- stride_qh ,
1598
- stride_doh ,
1599
- stride_dqh ,
1600
- seqlen_q ,
1601
- seqlen_k ,
1602
- seqlen_q_rounded ,
1603
- headdim ,
1604
- BLOCK_HEADDIM = BLOCK_HEADDIM ,
1605
- EVEN_M = EVEN_M ,
1606
- EVEN_N = EVEN_N ,
1607
- EVEN_HEADDIM = EVEN_HEADDIM ,
1608
- BLOCK = BLOCK ,
1609
- QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1610
- QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1611
- RETURN_SEL_GRADS = RETURN_SEL_GRADS ,
1612
- OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS ,
1613
- BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT ,
1614
- BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT ,
1615
- )
1572
+ backward_kernel_one_col_block_sparse (
1573
+ start_n ,
1574
+ Q ,
1575
+ K ,
1576
+ V ,
1577
+ kv_block_indices ,
1578
+ kv_block_mask ,
1579
+ kv_block_grads ,
1580
+ DO ,
1581
+ DQ ,
1582
+ DK ,
1583
+ DV ,
1584
+ LSE ,
1585
+ D ,
1586
+ softmax_scale ,
1587
+ stride_qm ,
1588
+ stride_kn ,
1589
+ stride_vn ,
1590
+ stride_dom ,
1591
+ stride_dqm ,
1592
+ stride_dkn ,
1593
+ stride_dvn ,
1594
+ stride_kvbl_m ,
1595
+ stride_qh ,
1596
+ stride_doh ,
1597
+ stride_dqh ,
1598
+ seqlen_q ,
1599
+ seqlen_k ,
1600
+ seqlen_q_rounded ,
1601
+ headdim ,
1602
+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
1603
+ EVEN_M = EVEN_M ,
1604
+ EVEN_N = EVEN_N ,
1605
+ EVEN_HEADDIM = EVEN_HEADDIM ,
1606
+ BLOCK = BLOCK ,
1607
+ QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1608
+ QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1609
+ RETURN_SEL_GRADS = RETURN_SEL_GRADS ,
1610
+ OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS ,
1611
+ BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT ,
1612
+ BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT ,
1613
+ )
1616
1614
1617
1615
def native_sparse_attn_backward (
1618
1616
do ,
@@ -1692,7 +1690,8 @@ def native_sparse_attn_backward(
1692
1690
1693
1691
grid = lambda META : (
1694
1692
num_sel_fine_blocks + int (include_block_causal ),
1695
- batch * kv_heads
1693
+ batch * kv_heads ,
1694
+ triton .cdiv (seqlen_k , META ['BLOCK' ])
1696
1695
)
1697
1696
1698
1697
backward_kernel [grid ](
0 commit comments