@@ -587,7 +587,7 @@ def backward_kernel_one_col_block_sparse(
587
587
QUERY_EXPAND_DIM : tl .constexpr ,
588
588
OFF_SEL_KV_BLOCKS : tl .constexpr
589
589
):
590
- # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
590
+ # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
591
591
592
592
begin_m = ((start_n * BLOCK ) // BLOCK ) * BLOCK
593
593
@@ -609,73 +609,32 @@ def backward_kernel_one_col_block_sparse(
609
609
610
610
q_ptrs = (
611
611
Q +
612
- offs_g [:, None , None ] * stride_qh +
613
- offs_qm [None , : , None ] * stride_qm +
612
+ offs_g [None , : , None ] * stride_qh +
613
+ offs_qm [:, None , None ] * stride_qm +
614
614
offs_d [None , None , :]
615
615
)
616
616
617
617
do_ptrs = (
618
618
DO +
619
- offs_g [:, None , None ] * stride_doh +
620
- offs_qm [None , : , None ] * stride_dom +
619
+ offs_g [None , : , None ] * stride_doh +
620
+ offs_qm [:, None , None ] * stride_dom +
621
621
offs_d [None , None , :]
622
622
)
623
623
624
624
dq_ptrs = (
625
625
DQ +
626
- offs_g [:, None , None ] * stride_dqh +
627
- offs_qm [None , : , None ] * stride_dqm +
626
+ offs_g [None , : , None ] * stride_dqh +
627
+ offs_qm [:, None , None ] * stride_dqm +
628
628
offs_d [None , None , :]
629
629
)
630
630
631
- # initialize dv and dk
632
-
633
- dv = tl .zeros ([BLOCK , BLOCK_HEADDIM ], dtype = tl .float32 )
634
- dk = tl .zeros ([BLOCK , BLOCK_HEADDIM ], dtype = tl .float32 )
635
-
636
631
# There seems to be some problem with Triton pipelining that makes results wrong for
637
632
# headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
638
633
# may have zero step, and pipelining with the bias matrix could screw it up.
639
634
# So we just exit early.
640
635
641
636
if begin_m >= seqlen_q :
642
- dv_ptrs = DV + (offs_n [:, None ] * stride_dvn + offs_d [None , :])
643
- dk_ptrs = DK + (offs_n [:, None ] * stride_dkn + offs_d [None , :])
644
- backward_store_dk_dv (
645
- dk_ptrs ,
646
- dv_ptrs ,
647
- dk ,
648
- dv ,
649
- offs_n ,
650
- offs_d ,
651
- seqlen_k ,
652
- headdim ,
653
- EVEN_M = EVEN_M ,
654
- EVEN_N = EVEN_N ,
655
- EVEN_HEADDIM = EVEN_HEADDIM ,
656
- )
657
637
return
658
- # k and v stay in SRAM throughout
659
- # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
660
- # if we just call tl.load(k_ptrs), we get the wrong output!
661
- if EVEN_N & EVEN_M :
662
- if EVEN_HEADDIM :
663
- k = tl .load (k_ptrs )
664
- v = tl .load (v_ptrs )
665
- else :
666
- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
667
- v = tl .load (v_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
668
- else :
669
- if EVEN_HEADDIM :
670
- k = tl .load (k_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
671
- v = tl .load (v_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
672
- else :
673
- k = tl .load (
674
- k_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
675
- )
676
- v = tl .load (
677
- v_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
678
- )
679
638
680
639
# same block for block causal diagonal
681
640
@@ -685,28 +644,17 @@ def backward_kernel_one_col_block_sparse(
685
644
q = tl .load (q_ptrs )
686
645
else :
687
646
if EVEN_HEADDIM :
688
- q = tl .load (q_ptrs , mask = offs_m [None , :, None ] < seqlen_q , other = 0.0 )
647
+ q = tl .load (
648
+ q_ptrs ,
649
+ mask = offs_m [:, None , None ] < seqlen_q ,
650
+ other = 0.0
651
+ )
689
652
else :
690
653
q = tl .load (
691
654
q_ptrs ,
692
- mask = (offs_m [None , : , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
693
- other = 0.0 ,
655
+ mask = (offs_m [:, None , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
656
+ other = 0.0 ,
694
657
)
695
- # recompute p = softmax(qk, dim=-1).T
696
-
697
- q = q .reshape ([QUERY_HEAD_GROUPS * BLOCK , BLOCK_HEADDIM ])
698
-
699
- qk = tl .dot (q , tl .trans (k ))
700
-
701
- qk = qk .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK )
702
-
703
- # Trying to combine the two masks seem to make the result wrong
704
- if not EVEN_N : # Need to mask out otherwise the softmax is wrong
705
- qk = tl .where (offs_n [None , :] < seqlen_k , qk , float ("-inf" ))
706
-
707
- qk = tl .where (offs_m [:, None ] >= (offs_n [None , :]), qk , float ("-inf" ))
708
-
709
- qk = qk .reshape (QUERY_HEAD_GROUPS * BLOCK , BLOCK )
710
658
711
659
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
712
660
# Also wrong for headdim=64.
@@ -715,9 +663,7 @@ def backward_kernel_one_col_block_sparse(
715
663
tl .debug_barrier ()
716
664
717
665
lse_i = tl .load (LSE + offs_d_or_lse )
718
- lse_i = lse_i .reshape (QUERY_HEAD_GROUPS * BLOCK )
719
-
720
- p = tl .exp (qk * softmax_scale - lse_i [:, None ])
666
+ lse_i = tl .trans (lse_i ) # (m, h)
721
667
722
668
# compute dv
723
669
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
@@ -730,12 +676,10 @@ def backward_kernel_one_col_block_sparse(
730
676
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
731
677
do = tl .load (
732
678
do_ptrs ,
733
- mask = (offs_m [None , : , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
734
- other = 0.0 ,
679
+ mask = (offs_m [:, None , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
680
+ other = 0.0 ,
735
681
)
736
682
737
- do = do .reshape (QUERY_HEAD_GROUPS * BLOCK , BLOCK_HEADDIM )
738
-
739
683
# compute dp = dot(v, do)
740
684
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
741
685
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
@@ -752,12 +696,12 @@ def backward_kernel_one_col_block_sparse(
752
696
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
753
697
754
698
Di = tl .load (D + offs_d_or_lse )
755
- Di = Di . reshape ( QUERY_HEAD_GROUPS * BLOCK )
699
+ Di = tl . trans ( Di ) # (m, h )
756
700
757
701
# Converting ds to q.dtype here reduces register pressure and makes it much faster
758
702
# for BLOCK_HEADDIM=128
759
703
760
- dq = tl .zeros ([QUERY_HEAD_GROUPS * BLOCK , BLOCK_HEADDIM ], dtype = tl .float32 )
704
+ dq = tl .zeros ([BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM ], dtype = tl .float32 )
761
705
762
706
# handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor
763
707
@@ -774,30 +718,39 @@ def backward_kernel_one_col_block_sparse(
774
718
block_indices = tl .load (kv_block_indices_ptrs + OFF_SEL_KV_BLOCKS )
775
719
block_masks = tl .load (kv_block_mask_ptrs + OFF_SEL_KV_BLOCKS )
776
720
777
- blocks_offs_n = block_indices [:, None ] * BLOCK + tl .arange (0 , BLOCK )[None , :]
721
+ blocks_offs_n = (
722
+ block_indices [:, None ] * BLOCK +
723
+ tl .arange (0 , BLOCK )[None , :]
724
+ )
778
725
779
726
block_k_ptrs = (
780
- K + blocks_offs_n [:, :, None ] * stride_kn + offs_d [None , None , :]
727
+ K +
728
+ blocks_offs_n [:, :, None ] * stride_kn +
729
+ offs_d [None , None , :]
781
730
)
782
731
783
732
block_v_ptrs = (
784
- V + blocks_offs_n [:, :, None ] * stride_vn + offs_d [None , None , :]
733
+ V +
734
+ blocks_offs_n [:, :, None ] * stride_vn +
735
+ offs_d [None , None , :]
785
736
)
786
737
787
738
block_dv_ptrs = (
788
- DV + blocks_offs_n [:, :, None ] * stride_dvn + offs_d [None , None , :]
739
+ DV +
740
+ blocks_offs_n [:, :, None ] * stride_dvn +
741
+ offs_d [None , None , :]
789
742
)
790
743
791
744
block_dk_ptrs = (
792
- DK + blocks_offs_n [:, :, None ] * stride_dkn + offs_d [None , None , :]
745
+ DK +
746
+ blocks_offs_n [:, :, None ] * stride_dkn +
747
+ offs_d [None , None , :]
793
748
)
794
749
795
750
block_k = tl .load (block_k_ptrs )
796
751
block_v = tl .load (block_v_ptrs )
797
752
798
- q_expanded = q .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK_HEADDIM )
799
- q_expanded = q_expanded .permute (1 , 0 , 2 )
800
- q_expanded = tl .expand_dims (q_expanded , 2 )
753
+ q_expanded = tl .expand_dims (q , 2 )
801
754
q_expanded = tl .broadcast_to (q_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
802
755
q_expanded = q_expanded .reshape (BLOCK , 16 , BLOCK_HEADDIM )
803
756
@@ -806,84 +759,77 @@ def backward_kernel_one_col_block_sparse(
806
759
807
760
block_qk = block_qk .reshape (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK )
808
761
qk = tl .sum (block_qk , 2 ) / QUERY_EXPAND_DIM
809
- qk = qk .permute (1 , 0 , 2 )
810
-
811
- qk += tl .where (block_masks [None , :, None ], 0 , float ("-inf" ))
812
762
813
- qk = qk . reshape ( QUERY_HEAD_GROUPS * BLOCK , BLOCK )
763
+ qk += tl . where ( block_masks [:, None , None ], 0 , float ( "-inf" ) )
814
764
815
- p = tl .exp (qk * softmax_scale - lse_i [:, None ])
765
+ p = tl .exp (qk * softmax_scale - lse_i [:, :, None ])
816
766
817
767
# take care of block dv
818
768
819
- block_dv = p .to (do .dtype )[:, :, None ] * do [:, None , :]
769
+ block_dv = p .to (do .dtype )[:, :, :, None ] * do [:, :, None , :]
820
770
821
- block_dv = block_dv .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK , BLOCK_HEADDIM )
822
- block_dv = tl .sum (block_dv , 0 )
771
+ # block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
772
+ block_dv = tl .sum (block_dv , 1 )
823
773
824
774
tl .atomic_add (block_dv_ptrs , block_dv , sem = 'relaxed' )
825
775
826
776
# get dp
827
777
828
- do_expanded = do .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK_HEADDIM )
829
- do_expanded = do_expanded .permute (1 , 0 , 2 )
830
- do_expanded = tl .expand_dims (do_expanded , 2 )
831
- do_expanded = tl .broadcast_to (do_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
832
- do_expanded = do_expanded .reshape (BLOCK , 16 , BLOCK_HEADDIM )
778
+ do = tl .expand_dims (do , 2 )
779
+ do = tl .broadcast_to (do , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM ))
780
+ do = do .reshape (BLOCK , 16 , BLOCK_HEADDIM )
833
781
834
782
block_v = tl .permute (block_v , (0 , 2 , 1 ))
835
783
836
- dp = tl .dot (do_expanded , block_v )
784
+ dp = tl .dot (do , block_v )
837
785
838
786
dp = dp .reshape (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK )
839
787
dp = tl .sum (dp , 2 ) / QUERY_EXPAND_DIM
840
- dp = dp .permute (1 , 0 , 2 )
841
- dp = dp .reshape (QUERY_HEAD_GROUPS * BLOCK , BLOCK )
842
788
843
789
# ds
844
790
845
- ds = (p * (dp - Di [:, None ]) * softmax_scale )
791
+ ds = (p * (dp - Di [:, :, None ]) * softmax_scale )
846
792
ds = ds .to (q .dtype )
847
793
848
794
# block dk
849
795
850
- block_dk = ds [:, :, None ] * q [:, None , :]
851
- block_dk = block_dk .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK , BLOCK_HEADDIM )
852
- block_dk = tl .sum (block_dk , 0 )
796
+ block_dk = ds [:, :, :, None ] * q [:, :, None , :]
797
+ block_dk = tl .sum (block_dk , 1 )
853
798
854
799
tl .atomic_add (block_dk_ptrs , block_dk , sem = 'relaxed' )
855
800
856
801
# block dq
857
802
858
- ds_expanded = ds .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK )
859
- ds_expanded = ds_expanded .permute (1 , 0 , 2 )
860
- ds_expanded = tl .expand_dims (ds_expanded , 2 )
803
+ ds_expanded = tl .expand_dims (ds , 2 )
861
804
ds_expanded = tl .broadcast_to (ds_expanded , (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK ))
862
805
ds_expanded = ds_expanded .reshape (BLOCK , 16 , BLOCK )
863
806
864
807
block_dq = tl .dot (ds_expanded , block_k )
865
808
866
809
block_dq = block_dq .reshape (BLOCK , QUERY_HEAD_GROUPS , QUERY_EXPAND_DIM , BLOCK_HEADDIM )
867
810
block_dq = tl .sum (block_dq , 2 ) / QUERY_EXPAND_DIM
868
- block_dq = block_dq .permute (1 , 0 , 2 )
869
- block_dq = block_dq .reshape (QUERY_HEAD_GROUPS * BLOCK , BLOCK_HEADDIM )
870
811
871
812
dq += block_dq
872
813
873
814
# update dq
874
815
875
- dq = dq .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK_HEADDIM )
816
+ dq = dq .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
876
817
877
818
if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
878
819
tl .atomic_add (dq_ptrs , dq , sem = 'relaxed' )
879
820
else :
880
821
if EVEN_HEADDIM :
881
- tl .atomic_add (dq_ptrs , dq , mask = offs_m [None , :, None ] < seqlen_q , sem = 'relaxed' )
822
+ tl .atomic_add (
823
+ dq_ptrs ,
824
+ dq ,
825
+ mask = offs_m [:, None , None ] < seqlen_q ,
826
+ sem = 'relaxed'
827
+ )
882
828
else :
883
829
tl .atomic_add (
884
830
dq_ptrs ,
885
831
dq ,
886
- mask = (offs_m [None , : , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
832
+ mask = (offs_m [:, None , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim ),
887
833
sem = 'relaxed' ,
888
834
)
889
835
0 commit comments