@@ -431,11 +431,12 @@ def forward_kernel_causal_and_sparse(
431
431
)
432
432
433
433
@triton .heuristics (
434
- {
435
- "EVEN_M" : lambda args : divisible_by (args ["seqlen_q" ], args ["BLOCK" ]),
436
- "EVEN_N" : lambda args : divisible_by (args ["seqlen_k" ], args ["BLOCK" ]),
437
- "EVEN_HEADDIM" : lambda args : args ["headdim" ] == args ["BLOCK_HEADDIM" ],
438
- }
434
+ dict (
435
+ EVEN_M = lambda args : divisible_by (args ["seqlen_q" ], args ["BLOCK" ]),
436
+ EVEN_N = lambda args : divisible_by (args ["seqlen_k" ], args ["BLOCK" ]),
437
+ EVEN_HEADDIM = lambda args : args ["headdim" ] == args ["BLOCK_HEADDIM" ],
438
+ QUERY_EXPAND_DIM = lambda args : 16 // args ['QUERY_HEAD_GROUPS' ]
439
+ )
439
440
)
440
441
@triton .jit
441
442
def forward_kernel (
@@ -599,7 +600,6 @@ def native_sparse_attn_forward(
599
600
BLOCK_HEADDIM ,
600
601
BLOCK = block_size ,
601
602
QUERY_HEAD_GROUPS = head_groups ,
602
- QUERY_EXPAND_DIM = 16 // head_groups ,
603
603
NUM_SEL_KV_BLOCKS = num_selected_fine_blocks ,
604
604
INCLUDE_BLOCK_CAUSAL = include_block_causal ,
605
605
SLIDING = False ,
@@ -1275,6 +1275,11 @@ def backward_kernel_one_col_block_causal(
1275
1275
EVEN_HEADDIM = EVEN_HEADDIM ,
1276
1276
)
1277
1277
1278
+ @triton .heuristics (
1279
+ dict (
1280
+ QUERY_EXPAND_DIM = lambda args : 16 // args ['QUERY_HEAD_GROUPS' ]
1281
+ )
1282
+ )
1278
1283
@triton .jit
1279
1284
def backward_kernel (
1280
1285
Q ,
@@ -1576,7 +1581,6 @@ def native_sparse_attn_backward(
1576
1581
BLOCK_HEADDIM ,
1577
1582
BLOCK = block_size ,
1578
1583
QUERY_HEAD_GROUPS = head_groups ,
1579
- QUERY_EXPAND_DIM = 16 // head_groups ,
1580
1584
EVEN_M = divisible_by (seqlen_q , block_size ),
1581
1585
EVEN_N = divisible_by (seqlen_k , block_size ),
1582
1586
EVEN_HEADDIM = BLOCK_HEADDIM == dim ,
0 commit comments