@@ -450,6 +450,7 @@ def forward_kernel(
450
450
kv_block_indices ,
451
451
kv_block_mask ,
452
452
Out ,
453
+ SlidingOut ,
453
454
Lse ,
454
455
softmax_scale ,
455
456
stride_qb ,
@@ -484,15 +485,22 @@ def forward_kernel(
484
485
QUERY_EXPAND_DIM : tl .constexpr ,
485
486
NUM_SEL_KV_BLOCKS : tl .constexpr ,
486
487
INCLUDE_BLOCK_CAUSAL : tl .constexpr ,
487
- SLIDING : tl .constexpr
488
+ RETURN_SLIDING_OUT : tl .constexpr
488
489
):
490
+ if RETURN_SLIDING_OUT :
491
+ sliding = tl .program_id (2 ) == 0
492
+ out_ptr = SlidingOut if sliding else Out
493
+ else :
494
+ sliding = False
495
+ out_ptr = Out
496
+
489
497
forward_kernel_causal_and_sparse (
490
498
Q ,
491
499
K ,
492
500
V ,
493
501
kv_block_indices ,
494
502
kv_block_mask ,
495
- Out ,
503
+ out_ptr ,
496
504
Lse ,
497
505
softmax_scale ,
498
506
stride_qb ,
@@ -527,7 +535,7 @@ def forward_kernel(
527
535
QUERY_EXPAND_DIM ,
528
536
NUM_SEL_KV_BLOCKS ,
529
537
INCLUDE_BLOCK_CAUSAL ,
530
- SLIDING
538
+ sliding
531
539
)
532
540
533
541
def native_sparse_attn_forward (
@@ -537,7 +545,8 @@ def native_sparse_attn_forward(
537
545
kv_block_indices ,
538
546
kv_block_mask ,
539
547
block_size = 128 ,
540
- include_block_causal = True
548
+ include_block_causal = True ,
549
+ return_sliding_window_out = False
541
550
):
542
551
q , k , v , kv_block_indices = [x if is_contiguous (x ) else x .contiguous () for x in (q , k , v , kv_block_indices )]
543
552
@@ -563,11 +572,16 @@ def native_sparse_attn_forward(
563
572
lse = torch .empty ((batch , nheads , seqlen_q_rounded ), device = device , dtype = torch .float32 )
564
573
565
574
o = torch .empty_like (q )
575
+ slide_o = torch .empty_like (q )
566
576
567
577
BLOCK_HEADDIM = max (triton .next_power_of_2 (dim ), 16 )
568
578
num_warps = 4 if dim <= 64 else 8
569
579
570
- grid = lambda META : (triton .cdiv (seqlen_q , META ["BLOCK" ]), batch * kv_heads ) # kv heads here, as grouped query heads all loaded, following the paper
580
+ grid = lambda META : (
581
+ triton .cdiv (seqlen_q , META ["BLOCK" ]),
582
+ batch * kv_heads ,
583
+ (2 if return_sliding_window_out else 1 )
584
+ ) # kv heads here, as grouped query heads all loaded, following the paper
571
585
572
586
forward_kernel [grid ](
573
587
q ,
@@ -576,6 +590,7 @@ def native_sparse_attn_forward(
576
590
kv_block_indices ,
577
591
kv_block_mask ,
578
592
o ,
593
+ slide_o ,
579
594
lse ,
580
595
softmax_scale ,
581
596
q .stride (0 ),
@@ -606,12 +621,12 @@ def native_sparse_attn_forward(
606
621
QUERY_HEAD_GROUPS = head_groups ,
607
622
NUM_SEL_KV_BLOCKS = num_selected_fine_blocks ,
608
623
INCLUDE_BLOCK_CAUSAL = include_block_causal ,
609
- SLIDING = False ,
624
+ RETURN_SLIDING_OUT = return_sliding_window_out , # must agree with the grid's third dimension; a hard-coded False leaves slide_o unwritten
610
625
num_warps = num_warps ,
611
626
num_stages = 1 ,
612
627
)
613
628
614
- return o , lse
629
+ return o , slide_o , lse
615
630
616
631
@triton .jit
617
632
def backward_preprocess_do_o_dot (
@@ -1648,7 +1663,8 @@ def forward(
1648
1663
fmask ,
1649
1664
sel_scale ,
1650
1665
include_block_causal ,
1651
- block_dk_dv_use_dot
1666
+ block_dk_dv_use_dot ,
1667
+ return_sliding_window_out
1652
1668
):
1653
1669
dtype = fq .dtype
1654
1670
@@ -1658,15 +1674,16 @@ def forward(
1658
1674
1659
1675
fq , fk , fv = tuple (t .half () for t in (fq , fk , fv ))
1660
1676
1661
- out , lse = native_sparse_attn_forward (
1677
+ out , slide_out , lse = native_sparse_attn_forward (
1662
1678
fq , fk , fv ,
1663
1679
selected_block_indices ,
1664
1680
fmask ,
1665
1681
block_size = block_size ,
1666
1682
include_block_causal = include_block_causal ,
1683
+ return_sliding_window_out = return_sliding_window_out
1667
1684
)
1668
1685
1669
- ctx .save_for_backward (fq , fk , fv , selected_block_indices , fmask , out , lse )
1686
+ ctx .save_for_backward (fq , fk , fv , selected_block_indices , fmask , out , slide_out , lse )
1670
1687
1671
1688
return_sel_grads = exists (sel_scale )
1672
1689
@@ -1678,23 +1695,32 @@ def forward(
1678
1695
head_groups ,
1679
1696
return_sel_grads ,
1680
1697
include_block_causal ,
1681
- block_dk_dv_use_dot
1698
+ block_dk_dv_use_dot ,
1699
+ return_sliding_window_out
1682
1700
)
1683
1701
1684
- return out .type (dtype ), lse
1702
+ return out .type (dtype ), slide_out . type ( dtype ), lse
1685
1703
1686
1704
@classmethod
1687
- def backward (self , ctx , do , _ ):
1705
+ def backward (self , ctx , do , do_sliding , _ ):
1688
1706
device = do .device
1689
1707
1690
- q , k , v , sel_block_indices , mask , out , lse = ctx .saved_tensors
1708
+ (
1709
+ q , k , v ,
1710
+ sel_block_indices ,
1711
+ mask ,
1712
+ out ,
1713
+ slide_out ,
1714
+ lse
1715
+ ) = ctx .saved_tensors
1691
1716
1692
1717
(
1693
1718
block_size ,
1694
1719
head_groups ,
1695
1720
return_sel_grads ,
1696
1721
include_block_causal ,
1697
- block_dk_dv_use_dot
1722
+ block_dk_dv_use_dot ,
1723
+ return_sliding_window_out
1698
1724
) = ctx ._saved_variables
1699
1725
1700
1726
do = do .half ()
@@ -1719,7 +1745,7 @@ def backward(self, ctx, do, _):
1719
1745
if return_sel_grads :
1720
1746
ret_sel_grads = sel_grads
1721
1747
1722
- return dq , dk , dv , None , None , None , ret_sel_grads , None , None
1748
+ return dq , dk , dv , None , None , None , ret_sel_grads , None , None , None
1723
1749
1724
1750
_native_sparse_attend = NSA .apply
1725
1751
@@ -1742,7 +1768,8 @@ def native_sparse_attend(
1742
1768
sel_scale : Float ['b kh n sel' ] | Float ['b qh n sel' ] | None = None ,
1743
1769
include_block_causal = True ,
1744
1770
return_lse = False ,
1745
- block_dk_dv_use_dot = False
1771
+ block_dk_dv_use_dot = False ,
1772
+ return_sliding_window_out = False
1746
1773
):
1747
1774
seq_len = fq .shape [- 2 ]
1748
1775
q_heads , kv_heads , sel_heads = fq .shape [1 ], fk .shape [1 ], selected_block_indices .shape [1 ]
@@ -1757,16 +1784,20 @@ def native_sparse_attend(
1757
1784
if kv_heads != sel_heads :
1758
1785
fk , fv = tuple (repeat (t , 'b h ... -> b (h gh) ...' , gh = q_heads // kv_heads ) for t in (fk , fv ))
1759
1786
1760
- out , lse = _native_sparse_attend (
1787
+ out , sliding_out , lse = _native_sparse_attend (
1761
1788
fq , fk , fv ,
1762
1789
block_size ,
1763
1790
selected_block_indices ,
1764
1791
fmask ,
1765
1792
sel_scale ,
1766
1793
include_block_causal ,
1767
- block_dk_dv_use_dot
1794
+ block_dk_dv_use_dot ,
1795
+ return_sliding_window_out
1768
1796
)
1769
1797
1798
+ if return_sliding_window_out :
1799
+ out = (out , sliding_out ) # pair the sparse-attention output with the sliding-window output, not with itself
1800
+
1770
1801
if not return_lse :
1771
1802
return out
1772
1803
0 commit comments