
Commit 8fe8317

address #31 q heads to kv head ratio is high (gqa)
1 parent 7ebebca commit 8fe8317

File tree

2 files changed, +6 -6 lines changed


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 5 additions & 5 deletions
@@ -1064,10 +1064,12 @@ def backward_kernel_one_col_block_sparse(
                 offs_m * stride_kvbl_m
             )
 
-            sel_grads = ds * qk
+            sel_grads = ds * qk # (q block, q head group, k block)
+
             sel_grads = tl.where(block_masks[:, None, None], sel_grads, 0.)
-            sel_grads = sel_grads.reshape(BLOCK, QUERY_HEAD_GROUPS * BLOCK)
-            sel_grads = tl.sum(sel_grads, 1)
+
+            sel_grads = tl.sum(sel_grads, 2) # for k block
+            sel_grads = tl.sum(sel_grads, 1) # for q head groups
 
             tl.atomic_add(
                 kv_block_grads_ptrs + OFF_SEL_KV_BLOCKS,
@@ -1678,8 +1680,6 @@ def native_sparse_attn_backward(
 
     num_blocks_per_sel = block_size // 16
 
-    orig_kv_block_grads = kv_block_grads
-
     num_sel_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
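
To make the kernel change above concrete, here is a minimal PyTorch sketch of the two reductions (plain tensors rather than Triton tiles; the values of BLOCK and QUERY_HEAD_GROUPS below are toy numbers chosen only for illustration). Both versions collapse the q head group and k block axes of the (q block, q head group, k block) gradient tile and are numerically equivalent; the new code simply sums one axis at a time instead of first reshaping into a (BLOCK, QUERY_HEAD_GROUPS * BLOCK) tile, which the commit message suggests was the problem when the query-head to kv-head ratio (GQA) is high.

import torch

# toy stand-ins for the kernel's compile-time constants (assumed values)
BLOCK = 4               # block size along the query / key dimension
QUERY_HEAD_GROUPS = 8   # query heads per kv head, large under GQA (issue #31)

# gradient tile shaped (q block, q head group, k block), as the new inline comment notes
sel_grads = torch.randn(BLOCK, QUERY_HEAD_GROUPS, BLOCK)

# old reduction: flatten the head-group and k-block axes, then a single sum
old = sel_grads.reshape(BLOCK, QUERY_HEAD_GROUPS * BLOCK).sum(dim = 1)

# new reduction: sum over the k block axis, then over the q head group axis
new = sel_grads.sum(dim = 2).sum(dim = 1)

assert torch.allclose(old, new)  # same result, without the wide reshape

In the kernel itself the same two-step reduction is written as tl.sum(sel_grads, 2) followed by tl.sum(sel_grads, 1).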

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.2.1"
+version = "0.2.2"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
