Skip to content

Commit 9a038d1

Browse files
committed
small cleanup
1 parent 97aa4ae commit 9a038d1

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,12 @@ def forward_kernel_causal_and_sparse(
431431
)
432432

433433
@triton.heuristics(
434-
{
435-
"EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
436-
"EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
437-
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
438-
}
434+
dict(
435+
EVEN_M = lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
436+
EVEN_N = lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
437+
EVEN_HEADDIM = lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
438+
QUERY_EXPAND_DIM = lambda args: 16 // args['QUERY_HEAD_GROUPS']
439+
)
439440
)
440441
@triton.jit
441442
def forward_kernel(
@@ -599,7 +600,6 @@ def native_sparse_attn_forward(
599600
BLOCK_HEADDIM,
600601
BLOCK = block_size,
601602
QUERY_HEAD_GROUPS = head_groups,
602-
QUERY_EXPAND_DIM = 16 // head_groups,
603603
NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
604604
INCLUDE_BLOCK_CAUSAL = include_block_causal,
605605
SLIDING = False,
@@ -1275,6 +1275,11 @@ def backward_kernel_one_col_block_causal(
12751275
EVEN_HEADDIM = EVEN_HEADDIM,
12761276
)
12771277

1278+
@triton.heuristics(
1279+
dict(
1280+
QUERY_EXPAND_DIM = lambda args: 16 // args['QUERY_HEAD_GROUPS']
1281+
)
1282+
)
12781283
@triton.jit
12791284
def backward_kernel(
12801285
Q,
@@ -1576,7 +1581,6 @@ def native_sparse_attn_backward(
15761581
BLOCK_HEADDIM,
15771582
BLOCK = block_size,
15781583
QUERY_HEAD_GROUPS = head_groups,
1579-
QUERY_EXPAND_DIM = 16 // head_groups,
15801584
EVEN_M = divisible_by(seqlen_q, block_size),
15811585
EVEN_N = divisible_by(seqlen_k, block_size),
15821586
EVEN_HEADDIM = BLOCK_HEADDIM == dim,

0 commit comments

Comments
 (0)