@@ -1708,6 +1708,130 @@ def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float,
17081708 # src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum()
17091709 return (loss / student_logits.shape[0]).sum()
17101710
1711+ --- assertExpectedJournal(TestExamples.test_gdn_fwd_h)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher
1718+
@triton.jit
def _helion_helion_gdn_fwd_h(h, w, u, g, k, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    """Device kernel for ``helion_gdn_fwd_h``: emit the running state per chunk.

    The 1-D launch grid is decomposed into a batch index (8 values), a head
    index (80 values) and a tile of ``_BLOCK_SIZE_0`` value columns.  Each
    program keeps a float32 accumulator ``b_h`` of shape
    ``[64, _BLOCK_SIZE_0]``, walks the 4096 time steps in strides of
    ``_BLOCK_SIZE_3``, stores the bfloat16 copy of the state *before* the
    update for that step, then folds the chunk into the state.

    Every extent and stride below is a literal, so the kernel is only valid
    for the tensor layout those literals describe.  All loads and stores are
    issued without a mask.

    ``libdevice`` is resolved from module globals (Helion-generated modules
    import it from ``torch._inductor.runtime.triton_compat``); a Triton JIT
    function cannot import it locally.

    Args:
        h: output pointer, addressed as (batch, chunk, head, dhead, dstate)
            with strides (10485760, 655360, 8192, 128, 1).
        w: input pointer, addressed as (batch, time, head, dhead) with
            strides (20971520, 5120, 64, 1).
        u: input pointer, addressed as (batch, time, head, dstate) with
            strides (41943040, 10240, 128, 1).
        g: input pointer, addressed as (batch, time, head) with strides
            (327680, 80, 1).
        k: input pointer, addressed with the same strides as ``w``.
        _BLOCK_SIZE_0: number of value columns handled by one program.
        _RDIM_SIZE_4: length of the dhead index range; ``b_h`` is allocated
            with a literal 64 rows, so this is expected to be 64.
        _BLOCK_SIZE_3: time steps consumed per loop iteration; the last-row
            index below uses a literal 256, so this is expected to be 256.
    """
    # src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
    # src[gdn_fwd_h.py:N]:     [batch, nheads, dstate], block_size=[1, 1, block_v]
    # src[gdn_fwd_h.py:N]: ):
    num_blocks_0 = 8
    num_blocks_1 = 80
    # Split the flat program id: fastest-varying is batch, then head, then value tile.
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
    offset_1 = pid_0  # batch index
    offset_2 = pid_1  # head index
    offset_0 = pid_2 * _BLOCK_SIZE_0  # first value column of this tile
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
    # src[gdn_fwd_h.py:N]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
    b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
    # src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
    # src[gdn_fwd_h.py:N]:     h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
    # src[gdn_fwd_h.py:N]:     b_w = w[tile_b.begin, t_i, tile_h.begin, :]
    # src[gdn_fwd_h.py:N-N]: ...
    for offset_4 in tl.range(0, 4096, _BLOCK_SIZE_3):
        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
        # src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
        # src[gdn_fwd_h.py:N]: c_h = b_h.to(dtype)
        # One bfloat16 copy of the incoming state serves both the store and the dot below.
        c_h = tl.cast(b_h, tl.bfloat16)
        tile_id = offset_4 // _BLOCK_SIZE_3
        tl.store(h + (offset_1 * 10485760 + tile_id * 655360 + offset_2 * 8192 + indices_5[:, None] * 128 + indices_0[None, :] * 1), c_h, None)
        # src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
        b_w = tl.load(w + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
        # src[gdn_fwd_h.py:N]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
        b_v = tl.dot(tl.cast(b_w, tl.bfloat16), c_h, input_precision='tf32', out_dtype=tl.float32)
        # src[gdn_fwd_h.py:N]: p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
        load_1 = tl.load(u + (offset_1 * 41943040 + indices_4[:, None] * 10240 + offset_2 * 128 + indices_0[None, :] * 1), None)
        v_2 = tl.cast(load_1, tl.float32)
        # src[gdn_fwd_h.py:N]: b_v = p_v - b_v
        v_3 = v_2 - b_v
        # src[gdn_fwd_h.py:N]: m_t = t_i.index < seqlen
        v_4 = tl.full([], 4096, tl.int32)
        v_5 = indices_4 < v_4
        # src[gdn_fwd_h.py:N]: t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
        # min(offset_4 + 256, 4096) - 1, spelled with comparisons used as 0/1 multipliers.
        sub_1 = -1 + (4096 * (4096 <= 256 + offset_4) + (256 + offset_4) * (256 + offset_4 < 4096))
        # src[gdn_fwd_h.py:N]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
        b_g_last = tl.load(g + (offset_1 * 327680 + sub_1 * 80 + offset_2 * 1), None)
        # src[gdn_fwd_h.py:N]: b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
        b_g = tl.load(g + (offset_1 * 327680 + indices_4 * 80 + offset_2 * 1), None)
        # src[gdn_fwd_h.py:N]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
        v_6 = b_g_last[None]
        v_7 = v_6 - b_g
        v_8 = libdevice.exp(v_7)
        v_9 = 0.0
        v_10 = v_9[None]
        v_11 = tl.where(v_5, v_8, v_10)
        subscript = v_11[:, None]
        v_12 = v_3 * subscript
        # src[gdn_fwd_h.py:N]: b_g_last = torch.exp(b_g_last)
        v_13 = libdevice.exp(b_g_last)
        # src[gdn_fwd_h.py:N]: b_h *= b_g_last
        v_14 = v_13[None, None]
        v_15 = b_h * v_14
        # src[gdn_fwd_h.py:N]: b_v = b_v.to(dtype)
        v_16 = tl.cast(v_12, tl.bfloat16)
        # src[gdn_fwd_h.py:N]: p_k = k[tile_b.begin, t_i, tile_h.begin, :]
        p_k = tl.load(k + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
        # src[gdn_fwd_h.py:N]: b_h = hl.dot(p_k.T, b_v, acc=b_h)
        permute = tl.permute(p_k, [1, 0])
        # The decayed state v_15 is the accumulator, so this both decays and updates b_h.
        b_h = tl.dot(tl.cast(permute, tl.bfloat16), v_16, acc=v_15, input_precision='tf32', out_dtype=tl.float32)
1789+
def helion_gdn_fwd_h(k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
    """
    Allocate the per-chunk state tensor and launch ``_helion_helion_gdn_fwd_h``.

    The launch grid and the specialized values of ``dhead`` and
    ``chunk_size`` are literals in this function; the values passed by the
    caller for those two quantities are overwritten.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int (replaced by the specialized value 256)
        _launcher: callable that receives the kernel, the grid tuple, the
            tensor and block-size arguments, and the launch options
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    # src[gdn_fwd_h.py:N]: batch, seqlen, nheads, dhead = k.shape
    # Unpacking also enforces that k has exactly four dimensions.
    batch, seqlen, nheads, dhead = k.shape

    # src[gdn_fwd_h.py:N]: dhead = hl.specialize(dhead)
    # src[gdn_fwd_h.py:N]: chunk_size = hl.specialize(chunk_size)
    dhead, chunk_size = 64, 256

    # src[gdn_fwd_h.py:N]: dstate = u.shape[-1]
    dstate = u.shape[-1]
    # src[gdn_fwd_h.py:N]: acc_dtype = torch.float32
    acc_dtype = torch.float32
    # src[gdn_fwd_h.py:N]: dtype = k.dtype
    dtype = k.dtype

    # src[gdn_fwd_h.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
    # Ceiling division: number of chunk_size-long pieces needed to cover seqlen.
    nchunks = (seqlen + chunk_size - 1) // chunk_size

    # src[gdn_fwd_h.py:N]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
    out_shape = (batch, nchunks, nheads, dhead, dstate)
    h = torch.empty(out_shape, dtype=dtype, device=k.device)

    # Block sizes, in the order the kernel declares its constexpr parameters:
    #   value-column tile (_BLOCK_SIZE_0), dhead range (_RDIM_SIZE_4),
    #   time steps per loop iteration (_BLOCK_SIZE_3).
    # src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
    # src[gdn_fwd_h.py:N]:     [batch, nheads, dstate], block_size=[1, 1, block_v]
    # src[gdn_fwd_h.py:N]: ):
    block_v = 32
    rdim_dhead = 64
    # src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
    block_t = 256

    # One program per (batch, head, value tile); the three extents are literals.
    grid = (8 * 80 * triton.cdiv(128, block_v),)
    _launcher(
        _helion_helion_gdn_fwd_h,
        grid,
        h,
        w,
        u,
        g,
        k,
        block_v,
        rdim_dhead,
        block_t,
        num_warps=4,
        num_stages=1,
    )

    # src[gdn_fwd_h.py:N]: return h
    return h
1834+
17111835--- assertExpectedJournal(TestExamples.test_geglu)
17121836from __future__ import annotations
17131837
0 commit comments