Commit 06f2eae

allow for block causal to be turned off for the triton kernel, prepping for the encoder variant
1 parent da547e7 commit 06f2eae
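
To illustrate the change, here is a minimal usage sketch (not part of the commit) of the new flag on native_sparse_attend. The tensor shapes follow the type annotations visible in the diff below; the CUDA device, fp16 dtype, block size, and the all-zero selected block indices are illustrative assumptions, not requirements stated by this commit.

import torch
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

batch, q_heads, kv_heads, seq_len, dim = 1, 4, 1, 512, 64
block_size, num_sel = 64, 2

# grouped query attention: more query heads than key / value heads
fq = torch.randn(batch, q_heads, seq_len, dim, device = 'cuda', dtype = torch.float16)
fk = torch.randn(batch, kv_heads, seq_len, dim, device = 'cuda', dtype = torch.float16)
fv = torch.randn(batch, kv_heads, seq_len, dim, device = 'cuda', dtype = torch.float16)

# per query position and kv head: which fine blocks are selected, plus a validity mask
sel_indices = torch.zeros(batch, kv_heads, seq_len, num_sel, device = 'cuda', dtype = torch.long)
sel_mask = torch.ones(batch, kv_heads, seq_len, num_sel, device = 'cuda', dtype = torch.bool)

# default behavior: the block causal diagonal is attended to, as before this commit
out = native_sparse_attend(fq, fk, fv, block_size, sel_indices, sel_mask)

# new in this commit: skip the block causal branch, prepping for the encoder (non-causal) variant
out = native_sparse_attend(fq, fk, fv, block_size, sel_indices, sel_mask, include_block_causal = False)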

File tree

2 files changed: +108 -87 lines changed


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 107 additions & 86 deletions
@@ -112,7 +112,8 @@ def forward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
-    NUM_SEL_KV_BLOCKS: tl.constexpr
+    NUM_SEL_KV_BLOCKS: tl.constexpr,
+    INCLUDE_BLOCK_CAUSAL: tl.constexpr
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
@@ -134,22 +135,6 @@ def forward_kernel(
         offs_d[None, None, :]
     )
 
-    k_ptrs = (
-        K +
-        off_b * stride_kb +
-        off_h * stride_kh +
-        offs_n[:, None] * stride_kn +
-        offs_d[None, :]
-    )
-
-    v_ptrs = (
-        V +
-        off_b * stride_vb +
-        off_h * stride_vh +
-        offs_n[:, None] * stride_vn +
-        offs_d[None, :]
-    )
-
     # maximum
 
     m_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
@@ -202,82 +187,99 @@ def forward_kernel(
                 other = 0.0
             )
 
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            k = tl.load(k_ptrs)
-        else:
-            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-    else:
-        if EVEN_HEADDIM:
-            k = tl.load(
-                k_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
-            )
-        else:
-            k = tl.load(
-                k_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
-            )
+    if INCLUDE_BLOCK_CAUSAL:
+        k_ptrs = (
+            K +
+            off_b * stride_kb +
+            off_h * stride_kh +
+            offs_n[:, None] * stride_kn +
+            offs_d[None, :]
+        )
 
-    qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype=tl.float32)
+        v_ptrs = (
+            V +
+            off_b * stride_vb +
+            off_h * stride_vh +
+            offs_n[:, None] * stride_vn +
+            offs_d[None, :]
+        )
 
-    q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+        if EVEN_N & EVEN_M:
+            if EVEN_HEADDIM:
+                k = tl.load(k_ptrs)
+            else:
+                k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+        else:
+            if EVEN_HEADDIM:
+                k = tl.load(
+                    k_ptrs,
+                    mask = offs_n[:, None] < seqlen_k,
+                    other = 0.0,
+                )
+            else:
+                k = tl.load(
+                    k_ptrs,
+                    mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other = 0.0,
+                )
 
-    qk += tl.dot(q, tl.trans(k))
+        qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype=tl.float32)
 
-    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+        q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
 
-    if not EVEN_N:
-        qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))
+        qk += tl.dot(q, tl.trans(k))
 
-    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+        qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
 
-    qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))
+        if not EVEN_N:
+            qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))
 
-    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, :, None])
+        qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
 
-    l_ij = tl.sum(p, 2)
+        qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))
 
-    acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o *= acc_o_scale[:, :, None]
+        m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, :, None])
 
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            v = tl.load(v_ptrs)
-        else:
-            v = tl.load(
-                v_ptrs,
-                mask = offs_d[None, :] < headdim,
-                other = 0.0
-            )
-    else:
-        if EVEN_HEADDIM:
-            v = tl.load(
-                v_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
-            )
+        l_ij = tl.sum(p, 2)
+
+        acc_o_scale = tl.exp(m_i - m_ij)
+        acc_o *= acc_o_scale[:, :, None]
+
+        if EVEN_N & EVEN_M:
+            if EVEN_HEADDIM:
+                v = tl.load(v_ptrs)
+            else:
+                v = tl.load(
+                    v_ptrs,
+                    mask = offs_d[None, :] < headdim,
+                    other = 0.0
+                )
         else:
-            v = tl.load(
-                v_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
-            )
+            if EVEN_HEADDIM:
+                v = tl.load(
+                    v_ptrs,
+                    mask = offs_n[:, None] < seqlen_k,
+                    other = 0.0,
+                )
+            else:
+                v = tl.load(
+                    v_ptrs,
+                    mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other = 0.0,
+                )
 
-    p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
+        p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
 
-    causal_o = tl.dot(p, v)
+        causal_o = tl.dot(p, v)
 
-    acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+        acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
 
-    # -- update statistics
+        # -- update statistics
 
-    m_i = m_ij
-    l_i_new = tl.exp(lse_i - m_ij) + l_ij
-    lse_i = m_ij + tl.log(l_i_new)
+        m_i = m_ij
+        l_i_new = tl.exp(lse_i - m_ij) + l_ij
+        lse_i = m_ij + tl.log(l_i_new)
 
     # # take care of the selected kv blocks
 
@@ -419,7 +421,8 @@ def native_sparse_attn_forward(
     v,
     kv_block_indices,
     kv_block_mask,
-    block_size = 128
+    block_size = 128,
+    include_block_causal = True
 ):
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
@@ -488,6 +491,7 @@ def native_sparse_attn_forward(
         QUERY_HEAD_GROUPS = head_groups,
         QUERY_EXPAND_DIM = 16 // head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
+        INCLUDE_BLOCK_CAUSAL = include_block_causal,
         num_warps = num_warps,
         num_stages = 1,
     )
@@ -1184,14 +1188,19 @@ def backward_kernel(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    INCLUDE_BLOCK_CAUSAL: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
     off_h = off_hb % kv_heads
     off_qh = off_h * QUERY_HEAD_GROUPS
 
-    IS_CAUSAL = tl.program_id(0) == 0
-    OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
+    if INCLUDE_BLOCK_CAUSAL:
+        IS_CAUSAL = tl.program_id(0) == 0
+        OFF_SEL_KV_BLOCKS = tl.program_id(0) - 1
+    else:
+        IS_CAUSAL = False
+        OFF_SEL_KV_BLOCKS = tl.program_id(0)
 
     # offset pointers for batch/head
 
@@ -1310,7 +1319,8 @@ def native_sparse_attn_backward(
     o,
     lse,
     dq, dk, dv,
-    block_size = 128
+    block_size = 128,
+    include_block_causal = True
 ):
     device = do.device
 
@@ -1362,7 +1372,10 @@ def native_sparse_attn_backward(
         BLOCK_HEADDIM = BLOCK_HEADDIM,
     )
 
-    grid = lambda META: (num_sel_fine_blocks + 1, batch * kv_heads)
+    grid = lambda META: (
+        num_sel_fine_blocks + int(include_block_causal),
+        batch * kv_heads
+    )
 
     backward_kernel[grid](
         q,
@@ -1418,7 +1431,8 @@ def native_sparse_attn_backward(
        QUERY_EXPAND_DIM = 16 // head_groups,
        EVEN_M = divisible_by(seqlen_q, block_size),
        EVEN_N = divisible_by(seqlen_k, block_size),
-        EVEN_HEADDIM = BLOCK_HEADDIM == dim
+        EVEN_HEADDIM = BLOCK_HEADDIM == dim,
+        INCLUDE_BLOCK_CAUSAL = include_block_causal,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,
        # num_stages=1,
@@ -1440,6 +1454,7 @@ def forward(
        block_size,
        selected_block_indices,
        fmask,
+        include_block_causal
    ):
        dtype = fq.dtype
 
@@ -1453,14 +1468,16 @@ def forward(
            fq, fk, fv,
            selected_block_indices,
            fmask,
-            block_size = block_size
+            block_size = block_size,
+            include_block_causal = include_block_causal
        )
 
        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 
        ctx._saved_variables = (
            block_size,
-            head_groups
+            head_groups,
+            include_block_causal
        )
 
        return out.type(dtype), lse
@@ -1473,7 +1490,8 @@ def backward(self, ctx, do, _):
 
        (
            block_size,
-            head_groups
+            head_groups,
+            include_block_causal
        ) = ctx._saved_variables
 
        do = do.half()
@@ -1485,7 +1503,8 @@ def backward(self, ctx, do, _):
            do, q, k, v,
            sel_block_indices, mask,
            out, lse, dq, dk, dv,
-            block_size = block_size
+            block_size = block_size,
+            include_block_causal = include_block_causal
        )
 
        return dq, dk, dv, None, None, None, None
@@ -1508,6 +1527,7 @@ def native_sparse_attend(
    block_size: int,
    selected_block_indices: Int['b qh n sel'] | Int['b kh n sel'],
    fmask: Bool['b qh n sel'] | Bool['b kh n sel'],
+    include_block_causal = True,
    return_lse = False
 ):
    seq_len = fq.shape[-2]
@@ -1526,6 +1546,7 @@ def native_sparse_attend(
        block_size,
        selected_block_indices,
        fmask,
+        include_block_causal
    )
 
    if not return_lse:
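
Two details of the diff above are worth spelling out. In forward_kernel, the block causal diagonal branch is now guarded by INCLUDE_BLOCK_CAUSAL; since it is a tl.constexpr, the kernel is specialized per value of the flag rather than branching at runtime. In the backward pass, the first grid dimension becomes num_sel_fine_blocks + int(include_block_causal): program 0 handles the causal diagonal only when the flag is on, otherwise every program maps directly to a selected kv block. A small Python sketch (not from the commit, names are illustrative) of that dispatch:

def backward_grid_dim0(num_sel_fine_blocks: int, include_block_causal: bool) -> int:
    # one extra program is reserved for the block causal diagonal only when it is enabled
    return num_sel_fine_blocks + int(include_block_causal)

def dispatch(program_id_0: int, include_block_causal: bool):
    # mirrors the INCLUDE_BLOCK_CAUSAL branch added to backward_kernel
    if include_block_causal:
        is_causal = program_id_0 == 0
        off_sel_kv_blocks = program_id_0 - 1
    else:
        is_causal = False
        off_sel_kv_blocks = program_id_0
    return is_causal, off_sel_kv_blocks

assert backward_grid_dim0(2, True) == 3   # programs: causal diagonal, sel block 0, sel block 1
assert backward_grid_dim0(2, False) == 2  # programs: sel block 0, sel block 1
assert dispatch(0, False) == (False, 0)   # no program is spent on the causal diagonal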

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.58"
+version = "0.0.59"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
