diff --git a/benchmarks/run.py b/benchmarks/run.py index 5ccb5cdd4..d3c63e3d5 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -335,6 +335,11 @@ class RunResult: "examples.mamba2_chunk_state", "helion_mamba2_chunk_state_kernel", ), + "gdn_fwd_h": ( + "tritonbench.operators.gdn_fwd_h.operator", + "examples.gdn_fwd_h", + "helion_gdn_fwd_h_tb", + ), } @@ -651,6 +656,13 @@ class RunResult: "helion_mamba2_chunk_state_kernel_speedup": "helion_speedup", "helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy", }, + "gdn_fwd_h": { + "eager": "baseline", + "compile_speedup": "torch_compile_speedup", + "compile_accuracy": "torch_compile_accuracy", + "helion_gdn_fwd_h_speedup": "helion_speedup", + "helion_gdn_fwd_h_accuracy": "helion_accuracy", + }, } diff --git a/examples/gdn_fwd_h.py b/examples/gdn_fwd_h.py new file mode 100644 index 000000000..45e11f0b2 --- /dev/null +++ b/examples/gdn_fwd_h.py @@ -0,0 +1,210 @@ +""" +Gated Delta Net Fwd H Kernel +============================ + +This code implements a fwd_h kernel as used in gated delta net +""" + +# %% +# Imports +# ------- +from __future__ import annotations + +import math +from typing import Callable + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +import helion.language as hl + + +# %% +# Helion Kernel Implementation +# ---------------------------- +@helion.kernel() +def helion_gdn_fwd_h( + k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int +) -> torch.Tensor: + """ + Argument: + k: (batch, seqlen, nheads, dhead) + w: (batch, seqlen, nheads, dhead) + u: (batch, seqlen, nheads, expand_v*dhead) + g: (batch, seqlen, nheads) + chunk_size: int + Return: + h: (batch, nchunks, nheads, dhead, expand_v*dhead) + """ + + batch, seqlen, nheads, dhead = k.shape + dhead = hl.specialize(dhead) + chunk_size = hl.specialize(chunk_size) + dstate = u.shape[-1] + + acc_dtype = torch.float32 + dtype = k.dtype + + nchunks = (seqlen + chunk_size - 1) // chunk_size + h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device) + block_v = hl.register_block_size(dstate) + + for tile_b, tile_h, tile_v in hl.tile( + [batch, nheads, dstate], block_size=[1, 1, block_v] + ): + b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype) + for t_i in hl.tile(seqlen, block_size=chunk_size): + h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype) + b_w = w[tile_b.begin, t_i, tile_h.begin, :] + c_h = b_h.to(dtype) + b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype) + p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype) + b_v = p_v - b_v + m_t = t_i.index < seqlen + t_i_last = min(t_i.begin + chunk_size, seqlen) - 1 + b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype) + b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype) + b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None] + b_g_last = torch.exp(b_g_last) + b_h *= b_g_last + b_v = b_v.to(dtype) + p_k = k[tile_b.begin, t_i, tile_h.begin, :] + b_h = hl.dot(p_k.T, b_v, acc=b_h) + return h + + +def helion_gdn_fwd_h_tb( + tb_obj: object, + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor, + chunk_size: int, +) -> Callable[[], torch.Tensor]: + """ + Argument: + k: (batch, seqlen, nheads, dhead) + w: (batch, seqlen, nheads, dhead) + u: (batch, seqlen, nheads, expand_v*dhead) + g: (batch, seqlen, nheads) + chunk_size: int + Return: + h: (batch, nchunks, nheads, dhead, expand_v*dhead) + """ + return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size) + + +# %% +# Reference Function +# ------------- +def ref_gdn_fwd_h( + k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int +) -> torch.Tensor: + """ + Argument: + k: (batch, seqlen, nheads, dhead) + w: (batch, seqlen, nheads, dhead) + u: (batch, seqlen, nheads, expand_v*dhead) + g: (batch, seqlen, nheads) + chunk_size: int + Return: + h: (batch, nchunks, nheads, dhead, expand_v*dhead) + """ + + batch, seqlen, nheads, dhead = k.shape + expand_v = u.shape[-1] // dhead + nchunks = (seqlen + chunk_size - 1) // chunk_size + + acc_dtype = torch.float32 + dtype = k.dtype + + h = torch.empty( + batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device + ) + b_h = torch.zeros( + batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device + ) + + k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead) + w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead) + u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead) + g_c = g.reshape(batch, nchunks, chunk_size, nheads) + for i_t in range(nchunks): + h[:, i_t, :, :, :] = b_h.to(dtype) + b_w = w_c[:, i_t, :, :, :].to(acc_dtype) + c_h = b_h.to(dtype).to(acc_dtype) + b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h) + p_v = u_c[:, i_t, :, :, :].to(acc_dtype) + b_v = p_v - b_v + last_idx = min((i_t + 1) * chunk_size, seqlen) - 1 + m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen + b_g_last = g[:, last_idx, :].to(acc_dtype) + b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads + b_v *= torch.where( + m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0 + ).unsqueeze(-1) + b_g_last = torch.exp(b_g_last) + b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1) + b_v = b_v.to(dtype).to(acc_dtype) + p_k = k_c[:, i_t, :, :, :].to(acc_dtype) + b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v) + return h + + +# %% +# Testing Function +# ------------- +def test( + batch: int, + nheads: int, + seqlen: int, + chunk_size: int, + dhead: int, + dstate: int, + dtype: torch.dtype = torch.float16, +) -> None: + k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE) + k = torch.nn.functional.rms_norm(k, (dhead,)) + w = torch.randn( + batch, + seqlen // chunk_size, + chunk_size, + nheads, + dhead, + dtype=torch.float32, + device=DEVICE, + ) + # w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,)) + wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False) + w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv) + w = ( + w.permute(0, 1, 3, 2, 4) + .reshape(batch, seqlen, nheads, dhead) + .to(torch.bfloat16) + ) + u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE) + u = torch.nn.functional.rms_norm(u, (dstate,)) + g = torch.cumsum( + 0.5 + * math.log(1 / dhead) + * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE), + dim=1, + ) + args = (k, w, u, g, chunk_size) + run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args) + + +# %% +# Main Function +# ----------- +def main() -> None: + """ + Main entry point that runs the attention kernel test with specific parameters. + """ + test(8, 80, 4096, 256, 64, 128) + + +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index fc2b53f4f..415f1e313 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1708,6 +1708,130 @@ def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, # src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum() return (loss / student_logits.shape[0]).sum() +--- assertExpectedJournal(TestExamples.test_gdn_fwd_h) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_helion_gdn_fwd_h(h, w, u, g, k, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + # src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile( + # src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v] + # src[gdn_fwd_h.py:N]: ): + num_blocks_0 = 8 + num_blocks_1 = 80 + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) + offset_1 = pid_0 + offset_2 = pid_1 + offset_0 = pid_2 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32) + # src[gdn_fwd_h.py:N]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype) + b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32) + # src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size): + # src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype) + # src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :] + # src[gdn_fwd_h.py:N-N]: ... + for offset_4 in tl.range(0, 4096, _BLOCK_SIZE_3): + indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + b_h_copy = b_h + b_h_copy_0 = b_h_copy + # src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype) + v_0 = tl.cast(b_h_copy_0, tl.bfloat16) + tile_id = offset_4 // _BLOCK_SIZE_3 + tl.store(h + (offset_1 * 10485760 + tile_id * 655360 + offset_2 * 8192 + indices_5[:, None] * 128 + indices_0[None, :] * 1), v_0, None) + # src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :] + b_w = tl.load(w + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None) + # src[gdn_fwd_h.py:N]: c_h = b_h.to(dtype) + v_1 = tl.cast(b_h_copy_0, tl.bfloat16) + # src[gdn_fwd_h.py:N]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype) + b_v = tl.dot(tl.cast(b_w, tl.bfloat16), tl.cast(v_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32) + # src[gdn_fwd_h.py:N]: p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype) + load_1 = tl.load(u + (offset_1 * 41943040 + indices_4[:, None] * 10240 + offset_2 * 128 + indices_0[None, :] * 1), None) + v_2 = tl.cast(load_1, tl.float32) + # src[gdn_fwd_h.py:N]: b_v = p_v - b_v + v_3 = v_2 - b_v + # src[gdn_fwd_h.py:N]: m_t = t_i.index < seqlen + v_4 = tl.full([], 4096, tl.int32) + v_5 = indices_4 < v_4 + # src[gdn_fwd_h.py:N]: t_i_last = min(t_i.begin + chunk_size, seqlen) - 1 + sub_1 = -1 + (4096 * (4096 <= 256 + offset_4) + (256 + offset_4) * (256 + offset_4 < 4096)) + # src[gdn_fwd_h.py:N]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype) + b_g_last = tl.load(g + (offset_1 * 327680 + sub_1 * 80 + offset_2 * 1), None) + # src[gdn_fwd_h.py:N]: b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype) + b_g = tl.load(g + (offset_1 * 327680 + indices_4 * 80 + offset_2 * 1), None) + # src[gdn_fwd_h.py:N]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None] + v_6 = b_g_last[None] + v_7 = v_6 - b_g + v_8 = libdevice.exp(v_7) + v_9 = 0.0 + v_10 = v_9[None] + v_11 = tl.where(v_5, v_8, v_10) + subscript = v_11[:, None] + v_12 = v_3 * subscript + # src[gdn_fwd_h.py:N]: b_g_last = torch.exp(b_g_last) + v_13 = libdevice.exp(b_g_last) + # src[gdn_fwd_h.py:N]: b_h *= b_g_last + v_14 = v_13[None, None] + v_15 = b_h_copy_0 * v_14 + # src[gdn_fwd_h.py:N]: b_v = b_v.to(dtype) + v_16 = tl.cast(v_12, tl.bfloat16) + # src[gdn_fwd_h.py:N]: p_k = k[tile_b.begin, t_i, tile_h.begin, :] + p_k = tl.load(k + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None) + # src[gdn_fwd_h.py:N]: b_h = hl.dot(p_k.T, b_v, acc=b_h) + permute = tl.permute(p_k, [1, 0]) + b_h = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(v_16, tl.bfloat16), acc=v_15, input_precision='tf32', out_dtype=tl.float32) + +def helion_gdn_fwd_h(k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher): + """ + Argument: + k: (batch, seqlen, nheads, dhead) + w: (batch, seqlen, nheads, dhead) + u: (batch, seqlen, nheads, expand_v*dhead) + g: (batch, seqlen, nheads) + chunk_size: int + Return: + h: (batch, nchunks, nheads, dhead, expand_v*dhead) + """ + # src[gdn_fwd_h.py:N]: batch, seqlen, nheads, dhead = k.shape + batch, seqlen, nheads, dhead = k.shape + # src[gdn_fwd_h.py:N]: dhead = hl.specialize(dhead) + dhead = 64 + # src[gdn_fwd_h.py:N]: chunk_size = hl.specialize(chunk_size) + chunk_size = 256 + # src[gdn_fwd_h.py:N]: dstate = u.shape[-1] + dstate = u.shape[-1] + # src[gdn_fwd_h.py:N]: acc_dtype = torch.float32 + acc_dtype = torch.float32 + # src[gdn_fwd_h.py:N]: dtype = k.dtype + dtype = k.dtype + # src[gdn_fwd_h.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size + nchunks = (seqlen + chunk_size - 1) // chunk_size + # src[gdn_fwd_h.py:N]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device) + h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device) + # src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile( + # src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v] + # src[gdn_fwd_h.py:N]: ): + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_4 = 64 + # src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size): + # src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype) + # src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :] + # src[gdn_fwd_h.py:N-N]: ... + _BLOCK_SIZE_3 = 256 + # src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile( + # src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v] + # src[gdn_fwd_h.py:N]: ): + # src[gdn_fwd_h.py:N-N]: ... + _launcher(_helion_helion_gdn_fwd_h, (8 * 80 * triton.cdiv(128, _BLOCK_SIZE_0),), h, w, u, g, k, _BLOCK_SIZE_0, _RDIM_SIZE_4, _BLOCK_SIZE_3, num_warps=4, num_stages=1) + # src[gdn_fwd_h.py:N]: return h + return h + --- assertExpectedJournal(TestExamples.test_geglu) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 31a796a5f..f2cb4f830 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1822,6 +1822,63 @@ def test_grpo_loss_bwd(self): ) ) + def test_gdn_fwd_h(self): + """Test gated delta net forward h kernel.""" + import math + + batch = 8 + nheads = 80 + seqlen = 4096 + chunk_size = 256 + dhead = 64 + dstate = 128 + + k = torch.randn( + batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE + ) + k = torch.nn.functional.rms_norm(k, (dhead,)) + w = torch.randn( + batch, + seqlen // chunk_size, + chunk_size, + nheads, + dhead, + dtype=torch.float32, + device=DEVICE, + ) + wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False) + w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv) + w = ( + w.permute(0, 1, 3, 2, 4) + .reshape(batch, seqlen, nheads, dhead) + .to(torch.bfloat16) + ) + u = torch.randn( + batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE + ) + u = torch.nn.functional.rms_norm(u, (dstate,)) + g = torch.cumsum( + 0.5 + * math.log(1 / dhead) + * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE), + dim=1, + ) + + args = (k, w, u, g, chunk_size) + + # Import and use the reference implementation + mod = import_path(EXAMPLES_DIR / "gdn_fwd_h.py") + expected = mod.ref_gdn_fwd_h(*args) + + self.assertExpectedJournal( + check_example( + "gdn_fwd_h", + args, + expected, + fn_name="helion_gdn_fwd_h", + ) + ) + if __name__ == "__main__": unittest.main()