@@ -112,6 +112,7 @@ def forward_kernel_causal_and_sparse(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr
 ):
@@ -346,6 +347,7 @@ def forward_kernel_causal_and_sparse(
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)

     for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
+
         block_indices = tl.load(
             kv_block_indices_ptrs + off_sel_kv_block,
             mask = offs_m < seqlen_q,
@@ -358,86 +360,91 @@ def forward_kernel_causal_and_sparse(
             other = False
         )

-        blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+        for off_blocks_per_sel in range(NUM_BLOCKS_PER_SEL):

-        block_k_ptrs = (
-            K +
-            off_b * stride_kb +
-            off_h * stride_kh +
-            blocks_offs_n[:, :, None] * stride_kn +
-            offs_d[None, None, :]
-        )
+            blocks_offs_n = (
+                block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) +
+                tl.arange(0, BLOCK)[None, :] + (off_blocks_per_sel * BLOCK)
+            )

-        block_v_ptrs = (
-            V +
-            off_b * stride_vb +
-            off_h * stride_vh +
-            blocks_offs_n[:, :, None] * stride_vn +
-            offs_d[None, None, :]
-        )
+            block_k_ptrs = (
+                K +
+                off_b * stride_kb +
+                off_h * stride_kh +
+                blocks_offs_n[:, :, None] * stride_kn +
+                offs_d[None, None, :]
+            )

-        # load k of shape (m, n, d), sparsely selected by each query
+            block_v_ptrs = (
+                V +
+                off_b * stride_vb +
+                off_h * stride_vh +
+                blocks_offs_n[:, :, None] * stride_vn +
+                offs_d[None, None, :]
+            )

-        k_block = tl.load(
-            block_k_ptrs,
-            mask = blocks_offs_n[:, :, None] < seqlen_k,
-            other = 0.
-        )
+            # load k of shape (m, n, d), sparsely selected by each query

-        # similarities
+            k_block = tl.load(
+                block_k_ptrs,
+                mask = blocks_offs_n[:, :, None] < seqlen_k,
+                other = 0.
+            )

-        block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-        sel_qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)
+            # similarities

-        k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
-        k_block = k_block.permute(0, 2, 1)
+            block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
+            sel_qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)

-        block_qk += tl.dot(q, k_block)
-        block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
-        block_qk = tl.reduce(block_qk, 2, reduce_avg)
+            k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
+            k_block = k_block.permute(0, 2, 1)

-        sel_qk += block_qk
-        sel_qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
+            block_qk += tl.dot(q, k_block)
+            block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+            block_qk = tl.reduce(block_qk, 2, reduce_avg)

-        # attention
+            sel_qk += block_qk
+            sel_qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

-        m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, lse_i)
-        block_p = tl.exp(sel_qk * softmax_scale - m_ij[:, :, None])
+            # attention

-        l_ij = tl.sum(block_p, 2)
+            m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, lse_i)
+            block_p = tl.exp(sel_qk * softmax_scale - m_ij[:, :, None])

-        # renormalize the running output
+            l_ij = tl.sum(block_p, 2)

-        acc_o_scale = tl.exp(m_i - m_ij)
-        acc_o = acc_o * acc_o_scale[:, :, None]
+            # renormalize the running output

-        # aggregate values
+            acc_o_scale = tl.exp(m_i - m_ij)
+            acc_o = acc_o * acc_o_scale[:, :, None]

-        v_block = tl.load(
-            block_v_ptrs,
-            mask = blocks_offs_n[:, :, None] < seqlen_k,
-            other = 0.
-        )
+            # aggregate values

-        v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+            v_block = tl.load(
+                block_v_ptrs,
+                mask = blocks_offs_n[:, :, None] < seqlen_k,
+                other = 0.
+            )

-        block_p = block_p.to(v_block.dtype)
-        p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
-        p_expanded = tl.expand_dims(p_expanded, 2)
-        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
-        p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)
+            v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

-        block_acc_o = tl.dot(p_expanded, v_block)
-        block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
-        block_acc_o = tl.reduce(block_acc_o, 2, reduce_avg)
+            block_p = block_p.to(v_block.dtype)
+            p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+            p_expanded = tl.expand_dims(p_expanded, 2)
+            p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+            p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

-        acc_o += block_acc_o
+            block_acc_o = tl.dot(p_expanded, v_block)
+            block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
+            block_acc_o = tl.reduce(block_acc_o, 2, reduce_avg)

-        # -- update statistics
+            acc_o += block_acc_o

-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+            # -- update statistics
+
+            m_i = m_ij
+            l_i_new = tl.exp(lse_i - m_ij) + l_ij
+            lse_i = m_ij + tl.log(l_i_new)

     # normalize accumulated out

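Note (reviewer's sketch, not part of the diff): the new inner loop replaces the old host-side index expansion with per-tile offsets computed in the kernel. A minimal NumPy check of that arithmetic, reusing the hunk's variable names and assuming BLOCK = 16 as in the kernel:

```python
# Reviewer's illustration: each selected coarse block of width
# BLOCK * NUM_BLOCKS_PER_SEL is visited in NUM_BLOCKS_PER_SEL tiles of
# width BLOCK, reproducing the blocks_offs_n arithmetic above.
import numpy as np

BLOCK = 16
NUM_BLOCKS_PER_SEL = 4                 # block_size // 16 for block_size = 64
block_indices = np.array([2, 0, 5])    # selected coarse block per query row

tiles = []
for off_blocks_per_sel in range(NUM_BLOCKS_PER_SEL):
    blocks_offs_n = (
        block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) +
        np.arange(BLOCK)[None, :] + (off_blocks_per_sel * BLOCK)
    )
    tiles.append(blocks_offs_n)

covered = np.sort(np.concatenate(tiles, axis = -1), axis = -1)
expected = block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) + np.arange(BLOCK * NUM_BLOCKS_PER_SEL)
assert (covered == expected).all()     # tiles exactly cover each selected block
```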
@@ -528,6 +535,7 @@ def forward_kernel(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     RETURN_SLIDING_OUT: tl.constexpr
 ):
@@ -583,6 +591,7 @@ def forward_kernel(
         QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM,
         num_sel_kv_blocks,
+        NUM_BLOCKS_PER_SEL,
         INCLUDE_BLOCK_CAUSAL,
         sliding
     )
@@ -607,10 +616,6 @@ def native_sparse_attn_forward(
     assert divisible_by(block_size, 16)

     num_blocks_per_sel = block_size // 16
-    if num_blocks_per_sel > 1:
-        kv_block_indices = einx.add('... sel, r -> ... (sel r)', kv_block_indices * num_blocks_per_sel, arange(num_blocks_per_sel, device = device))
-        kv_block_mask = repeat(kv_block_mask, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
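Note (reviewer's sketch, not part of the diff): the deleted lines rescaled each selected coarse index into fine 16-wide blocks and repeated the mask on the host; the kernel's NUM_BLOCKS_PER_SEL loop now does this implicitly, so kv_block_indices and kv_block_mask keep their original `... sel` shape. What the removed einx/repeat expansion computed is roughly the following plain PyTorch (a hypothetical standalone helper, for reference only):

```python
import torch

def expand_selected_blocks(kv_block_indices, kv_block_mask, num_blocks_per_sel):
    r = num_blocks_per_sel
    offsets = torch.arange(r, device = kv_block_indices.device)
    # '... sel, r -> ... (sel r)': rescale coarse block ids to fine ids, then offset
    indices = (kv_block_indices.unsqueeze(-1) * r + offsets).flatten(-2)
    # '... sel -> ... (sel r)': each selected block's mask repeats r times
    mask = kv_block_mask.unsqueeze(-1).expand(*kv_block_mask.shape, r).flatten(-2)
    return indices, mask
```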
@@ -679,6 +684,7 @@ def native_sparse_attn_forward(
         SEL_BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
+        NUM_BLOCKS_PER_SEL = num_blocks_per_sel,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         RETURN_SLIDING_OUT = return_sliding_window_out,
         num_warps = num_warps,
@@ -821,6 +827,8 @@ def backward_kernel_one_col_block_sparse(
     QUERY_EXPAND_DIM: tl.constexpr,
     RETURN_SEL_GRADS: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
+    OFF_BLOCK_PER_SEL: tl.constexpr,
     BLOCK_DV_USE_DOT: tl.constexpr,
     BLOCK_DK_USE_DOT: tl.constexpr,
 ):
@@ -962,8 +970,8 @@ def backward_kernel_one_col_block_sparse(
     )

     blocks_offs_n = (
-        block_indices[:, None] * BLOCK +
-        tl.arange(0, BLOCK)[None, :]
+        block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) +
+        tl.arange(0, BLOCK)[None, :] + (OFF_BLOCK_PER_SEL * BLOCK)
     )

     block_k_ptrs = (
@@ -1135,8 +1143,6 @@ def backward_kernel_one_col_block_causal(
     Q,
     K,
     V,
-    kv_block_indices,
-    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -1513,6 +1519,7 @@ def backward_kernel(
     RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     BLOCK_DV_USE_DOT: tl.constexpr,
     BLOCK_DK_USE_DOT: tl.constexpr,
 ):
@@ -1545,7 +1552,8 @@ def backward_kernel(
         lse = SLIDE_LSE
         delta = SLIDE_D

-    OFF_SEL_KV_BLOCKS = block_id
+    OFF_SEL_KV_BLOCKS = block_id // NUM_BLOCKS_PER_SEL
+    OFF_BLOCK_PER_SEL = block_id % NUM_BLOCKS_PER_SEL

     # offset pointers for batch/head

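Note (reviewer's sketch, not part of the diff): with the first grid axis enlarged to num_sel_fine_blocks * num_blocks_per_sel sparse programs (see the grid change further down), each sparse program id splits into a selected-block index and a 16-wide tile index within that block. A minimal illustration, assuming block_id has already had the causal/sliding programs peeled off as in the kernel:

```python
def decompose_block_id(block_id, num_blocks_per_sel):
    off_sel_kv_blocks = block_id // num_blocks_per_sel   # which selected block
    off_block_per_sel = block_id % num_blocks_per_sel    # which 16-wide tile inside it
    return off_sel_kv_blocks, off_block_per_sel

# with num_blocks_per_sel = 4, ids 0..3 cover selected block 0, ids 4..7 cover block 1
assert decompose_block_id(5, 4) == (1, 1)
```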
@@ -1585,8 +1593,6 @@ def backward_kernel(
             Q,
             K,
             V,
-            kv_block_indices,
-            kv_block_mask,
             do,
             DQ,
             DK,
@@ -1659,6 +1665,8 @@ def backward_kernel(
             QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
             RETURN_SEL_GRADS = RETURN_SEL_GRADS,
             OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
+            NUM_BLOCKS_PER_SEL = NUM_BLOCKS_PER_SEL,
+            OFF_BLOCK_PER_SEL = OFF_BLOCK_PER_SEL,
             BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
             BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
         )
@@ -1703,11 +1711,6 @@ def native_sparse_attn_backward(

     orig_kv_block_grads = kv_block_grads

-    if num_blocks_per_sel > 1:
-        kv_block_indices = einx.add('... sel, r -> ... (sel r)', kv_block_indices * num_blocks_per_sel, arange(num_blocks_per_sel, device = device))
-        kv_block_mask = repeat(kv_block_mask, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-        kv_block_grads = repeat(kv_block_grads, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-
     num_sel_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape

@@ -1767,7 +1770,7 @@ def native_sparse_attn_backward(
     )

     grid = lambda META: (
-        int(include_block_causal) + int(sliding) + num_sel_fine_blocks,
+        int(include_block_causal) + int(sliding) + (num_sel_fine_blocks * num_blocks_per_sel),
         batch * kv_heads,
         triton.cdiv(seqlen_k, META['BLOCK'])
     )
@@ -1834,17 +1837,15 @@ def native_sparse_attn_backward(
         RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         SLIDING = sliding,
+        NUM_BLOCKS_PER_SEL = num_blocks_per_sel,
         BLOCK_DV_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1),
         BLOCK_DK_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
     )

-    if num_blocks_per_sel > 1:
-        orig_kv_block_grads.copy_(reduce(kv_block_grads, '... (sel r) -> ... sel', 'sum', r = num_blocks_per_sel))
-
-    return delta
+    return delta, slide_delta

 # native sparse attention function
