
Commit 2089049

make it right for selection block sizes > 16
1 parent d03b9a3 commit 2089049

File tree

native_sparse_attention_pytorch/triton_native_sparse_attention.py
pyproject.toml

2 files changed: +81 -80 lines


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 80 additions & 79 deletions
@@ -112,6 +112,7 @@ def forward_kernel_causal_and_sparse(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr
 ):
@@ -346,6 +347,7 @@ def forward_kernel_causal_and_sparse(
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
 
     for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
+
         block_indices = tl.load(
             kv_block_indices_ptrs + off_sel_kv_block,
             mask = offs_m < seqlen_q,
@@ -358,86 +360,91 @@ def forward_kernel_causal_and_sparse(
             other = False
         )
 
-        blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+        for off_blocks_per_sel in range(NUM_BLOCKS_PER_SEL):
 
-        block_k_ptrs = (
-            K +
-            off_b * stride_kb +
-            off_h * stride_kh +
-            blocks_offs_n[:, :, None] * stride_kn +
-            offs_d[None, None, :]
-        )
+            blocks_offs_n = (
+                block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) +
+                tl.arange(0, BLOCK)[None, :] + (off_blocks_per_sel * BLOCK)
+            )
 
-        block_v_ptrs = (
-            V +
-            off_b * stride_vb +
-            off_h * stride_vh +
-            blocks_offs_n[:, :, None] * stride_vn +
-            offs_d[None, None, :]
-        )
+            block_k_ptrs = (
+                K +
+                off_b * stride_kb +
+                off_h * stride_kh +
+                blocks_offs_n[:, :, None] * stride_kn +
+                offs_d[None, None, :]
+            )
 
-        # load k of shape (m, n, d), sparsely selected by each query
+            block_v_ptrs = (
+                V +
+                off_b * stride_vb +
+                off_h * stride_vh +
+                blocks_offs_n[:, :, None] * stride_vn +
+                offs_d[None, None, :]
+            )
 
-        k_block = tl.load(
-            block_k_ptrs,
-            mask = blocks_offs_n[:, :, None] < seqlen_k,
-            other = 0.
-        )
+            # load k of shape (m, n, d), sparsely selected by each query
 
-        # similarities
+            k_block = tl.load(
+                block_k_ptrs,
+                mask = blocks_offs_n[:, :, None] < seqlen_k,
+                other = 0.
+            )
 
-        block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-        sel_qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)
+            # similarities
 
-        k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
-        k_block = k_block.permute(0, 2, 1)
+            block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
+            sel_qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)
 
-        block_qk += tl.dot(q, k_block)
-        block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
-        block_qk = tl.reduce(block_qk, 2, reduce_avg)
+            k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
+            k_block = k_block.permute(0, 2, 1)
 
-        sel_qk += block_qk
-        sel_qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
+            block_qk += tl.dot(q, k_block)
+            block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+            block_qk = tl.reduce(block_qk, 2, reduce_avg)
 
-        # attention
+            sel_qk += block_qk
+            sel_qk += tl.where(block_masks[:, None, None], 0, float("-inf"))
 
-        m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, lse_i)
-        block_p = tl.exp(sel_qk * softmax_scale - m_ij[:, :, None])
+            # attention
 
-        l_ij = tl.sum(block_p, 2)
+            m_ij = tl.maximum(tl.max(sel_qk, 2) * softmax_scale, lse_i)
+            block_p = tl.exp(sel_qk * softmax_scale - m_ij[:, :, None])
 
-        # renormalize the running output
+            l_ij = tl.sum(block_p, 2)
 
-        acc_o_scale = tl.exp(m_i - m_ij)
-        acc_o = acc_o * acc_o_scale[:, :, None]
+            # renormalize the running output
 
-        # aggregate values
+            acc_o_scale = tl.exp(m_i - m_ij)
+            acc_o = acc_o * acc_o_scale[:, :, None]
 
-        v_block = tl.load(
-            block_v_ptrs,
-            mask = blocks_offs_n[:, :, None] < seqlen_k,
-            other = 0.
-        )
+            # aggregate values
 
-        v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+            v_block = tl.load(
+                block_v_ptrs,
+                mask = blocks_offs_n[:, :, None] < seqlen_k,
+                other = 0.
+            )
 
-        block_p = block_p.to(v_block.dtype)
-        p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
-        p_expanded = tl.expand_dims(p_expanded, 2)
-        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
-        p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)
+            v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
 
-        block_acc_o = tl.dot(p_expanded, v_block)
-        block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
-        block_acc_o = tl.reduce(block_acc_o, 2, reduce_avg)
+            block_p = block_p.to(v_block.dtype)
+            p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+            p_expanded = tl.expand_dims(p_expanded, 2)
+            p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+            p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)
 
-        acc_o += block_acc_o
+            block_acc_o = tl.dot(p_expanded, v_block)
+            block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
+            block_acc_o = tl.reduce(block_acc_o, 2, reduce_avg)
 
-        # -- update statistics
+            acc_o += block_acc_o
 
-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+            # -- update statistics
+
+            m_i = m_ij
+            l_i_new = tl.exp(lse_i - m_ij) + l_ij
+            lse_i = m_ij + tl.log(l_i_new)
 
     # normalize accumulated out
 
@@ -528,6 +535,7 @@ def forward_kernel(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     RETURN_SLIDING_OUT: tl.constexpr
 ):
@@ -583,6 +591,7 @@ def forward_kernel(
         QUERY_HEAD_GROUPS,
         QUERY_EXPAND_DIM,
         num_sel_kv_blocks,
+        NUM_BLOCKS_PER_SEL,
         INCLUDE_BLOCK_CAUSAL,
         sliding
     )
@@ -607,10 +616,6 @@ def native_sparse_attn_forward(
     assert divisible_by(block_size, 16)
 
     num_blocks_per_sel = block_size // 16
-    if num_blocks_per_sel > 1:
-        kv_block_indices = einx.add('... sel, r -> ... (sel r)', kv_block_indices * num_blocks_per_sel, arange(num_blocks_per_sel, device = device))
-        kv_block_mask = repeat(kv_block_mask, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
 
@@ -679,6 +684,7 @@ def native_sparse_attn_forward(
         SEL_BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
+        NUM_BLOCKS_PER_SEL = num_blocks_per_sel,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         RETURN_SLIDING_OUT = return_sliding_window_out,
         num_warps = num_warps,
@@ -821,6 +827,8 @@ def backward_kernel_one_col_block_sparse(
     QUERY_EXPAND_DIM: tl.constexpr,
     RETURN_SEL_GRADS: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
+    OFF_BLOCK_PER_SEL: tl.constexpr,
     BLOCK_DV_USE_DOT: tl.constexpr,
     BLOCK_DK_USE_DOT: tl.constexpr,
 ):
@@ -962,8 +970,8 @@ def backward_kernel_one_col_block_sparse(
     )
 
     blocks_offs_n = (
-        block_indices[:, None] * BLOCK +
-        tl.arange(0, BLOCK)[None, :]
+        block_indices[:, None] * (BLOCK * NUM_BLOCKS_PER_SEL) +
+        tl.arange(0, BLOCK)[None, :] + (OFF_BLOCK_PER_SEL * BLOCK)
     )
 
     block_k_ptrs = (
@@ -1135,8 +1143,6 @@ def backward_kernel_one_col_block_causal(
     Q,
     K,
     V,
-    kv_block_indices,
-    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -1513,6 +1519,7 @@ def backward_kernel(
     RETURN_SEL_GRADS: tl.constexpr,
     INCLUDE_BLOCK_CAUSAL: tl.constexpr,
     SLIDING: tl.constexpr,
+    NUM_BLOCKS_PER_SEL: tl.constexpr,
     BLOCK_DV_USE_DOT: tl.constexpr,
     BLOCK_DK_USE_DOT: tl.constexpr,
 ):
@@ -1545,7 +1552,8 @@ def backward_kernel(
         lse = SLIDE_LSE
         delta = SLIDE_D
 
-    OFF_SEL_KV_BLOCKS = block_id
+    OFF_SEL_KV_BLOCKS = block_id // NUM_BLOCKS_PER_SEL
+    OFF_BLOCK_PER_SEL = block_id % NUM_BLOCKS_PER_SEL
 
     # offset pointers for batch/head
 
@@ -1585,8 +1593,6 @@ def backward_kernel(
         Q,
         K,
         V,
-        kv_block_indices,
-        kv_block_mask,
         do,
         DQ,
         DK,
@@ -1659,6 +1665,8 @@ def backward_kernel(
         QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
         RETURN_SEL_GRADS = RETURN_SEL_GRADS,
         OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS,
+        NUM_BLOCKS_PER_SEL = NUM_BLOCKS_PER_SEL,
+        OFF_BLOCK_PER_SEL = OFF_BLOCK_PER_SEL,
         BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT,
         BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT,
     )
@@ -1703,11 +1711,6 @@ def native_sparse_attn_backward(
 
     orig_kv_block_grads = kv_block_grads
 
-    if num_blocks_per_sel > 1:
-        kv_block_indices = einx.add('... sel, r -> ... (sel r)', kv_block_indices * num_blocks_per_sel, arange(num_blocks_per_sel, device = device))
-        kv_block_mask = repeat(kv_block_mask, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-        kv_block_grads = repeat(kv_block_grads, '... sel -> ... (sel r)', r = num_blocks_per_sel)
-
     num_sel_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
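The backward wrapper drops the same expansion, including the kv_block_grads repeat, and further down the matching re-reduction that summed per-sub-block selection gradients back to one value per selected block is removed as well. For reference, that removed reduction was equivalent to roughly the following illustrative torch with toy shapes:

import torch

num_blocks_per_sel = 4
kv_block_grads = torch.randn(2, 8, 16 * num_blocks_per_sel)   # (..., sel * r), toy shape

# reduce('... (sel r) -> ... sel', 'sum', r = num_blocks_per_sel)
per_sel = kv_block_grads.reshape(2, 8, 16, num_blocks_per_sel).sum(dim = -1)   # (..., sel)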

@@ -1767,7 +1770,7 @@ def native_sparse_attn_backward(
     )
 
     grid = lambda META: (
-        int(include_block_causal) + int(sliding) + num_sel_fine_blocks,
+        int(include_block_causal) + int(sliding) + (num_sel_fine_blocks * num_blocks_per_sel),
         batch * kv_heads,
         triton.cdiv(seqlen_k, META['BLOCK'])
     )
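A quick worked example of the first grid axis under assumed toy settings (block causal on, no sliding window, 16 selected blocks per query, selection block size 64):

include_block_causal, sliding = True, False
num_sel_fine_blocks, num_blocks_per_sel = 16, 4   # selection block size 64 -> 4 sub-blocks

old_axis0 = int(include_block_causal) + int(sliding) + num_sel_fine_blocks                         # 17
new_axis0 = int(include_block_causal) + int(sliding) + (num_sel_fine_blocks * num_blocks_per_sel)  # 65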
@@ -1834,17 +1837,15 @@ def native_sparse_attn_backward(
         RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
         SLIDING = sliding,
+        NUM_BLOCKS_PER_SEL = num_blocks_per_sel,
         BLOCK_DV_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1),
         BLOCK_DK_USE_DOT = default(block_dk_dv_use_dot, head_groups > 1)
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
     )
 
-    if num_blocks_per_sel > 1:
-        orig_kv_block_grads.copy_(reduce(kv_block_grads, '... (sel r) -> ... sel', 'sum', r = num_blocks_per_sel))
-
-    return delta
+    return delta, slide_delta
 
 # native sparse attention function
 
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.1.24"
+version = "0.1.25"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
