Skip to content

Commit ae2b1b3

Browse files
committed
add test for gdn
1 parent d969da4 commit ae2b1b3

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

test/test_examples.expected

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,130 @@ def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float,
17081708
# src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum()
17091709
return (loss / student_logits.shape[0]).sum()
17101710

1711+
--- assertExpectedJournal(TestExamples.test_gdn_fwd_h)
1712+
from __future__ import annotations
1713+
1714+
import torch
1715+
import triton
1716+
import triton.language as tl
1717+
from helion.runtime import default_launcher as _default_launcher
1718+
1719+
@triton.jit
1720+
def _helion_helion_gdn_fwd_h(h, w, u, g, k, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
1721+
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
1722+
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
1723+
# src[gdn_fwd_h.py:N]: ):
1724+
num_blocks_0 = 8
1725+
num_blocks_1 = 80
1726+
pid_0 = tl.program_id(0) % num_blocks_0
1727+
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
1728+
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
1729+
offset_1 = pid_0
1730+
offset_2 = pid_1
1731+
offset_0 = pid_2 * _BLOCK_SIZE_0
1732+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1733+
indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
1734+
# src[gdn_fwd_h.py:N]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
1735+
b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
1736+
# src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
1737+
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
1738+
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
1739+
# src[gdn_fwd_h.py:N-N]: ...
1740+
for offset_4 in tl.range(0, 4096, _BLOCK_SIZE_3):
1741+
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1742+
b_h_copy = b_h
1743+
b_h_copy_0 = b_h_copy
1744+
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
1745+
v_0 = tl.cast(b_h_copy_0, tl.bfloat16)
1746+
tile_id = offset_4 // _BLOCK_SIZE_3
1747+
tl.store(h + (offset_1 * 10485760 + tile_id * 655360 + offset_2 * 8192 + indices_5[:, None] * 128 + indices_0[None, :] * 1), v_0, None)
1748+
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
1749+
b_w = tl.load(w + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
1750+
# src[gdn_fwd_h.py:N]: c_h = b_h.to(dtype)
1751+
v_1 = tl.cast(b_h_copy_0, tl.bfloat16)
1752+
# src[gdn_fwd_h.py:N]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
1753+
b_v = tl.dot(tl.cast(b_w, tl.bfloat16), tl.cast(v_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
1754+
# src[gdn_fwd_h.py:N]: p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
1755+
load_1 = tl.load(u + (offset_1 * 41943040 + indices_4[:, None] * 10240 + offset_2 * 128 + indices_0[None, :] * 1), None)
1756+
v_2 = tl.cast(load_1, tl.float32)
1757+
# src[gdn_fwd_h.py:N]: b_v = p_v - b_v
1758+
v_3 = v_2 - b_v
1759+
# src[gdn_fwd_h.py:N]: m_t = t_i.index < seqlen
1760+
v_4 = tl.full([], 4096, tl.int32)
1761+
v_5 = indices_4 < v_4
1762+
# src[gdn_fwd_h.py:N]: t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
1763+
sub_1 = -1 + (4096 * (4096 <= 256 + offset_4) + (256 + offset_4) * (256 + offset_4 < 4096))
1764+
# src[gdn_fwd_h.py:N]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
1765+
b_g_last = tl.load(g + (offset_1 * 327680 + sub_1 * 80 + offset_2 * 1), None)
1766+
# src[gdn_fwd_h.py:N]: b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
1767+
b_g = tl.load(g + (offset_1 * 327680 + indices_4 * 80 + offset_2 * 1), None)
1768+
# src[gdn_fwd_h.py:N]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
1769+
v_6 = b_g_last[None]
1770+
v_7 = v_6 - b_g
1771+
v_8 = libdevice.exp(v_7)
1772+
v_9 = 0.0
1773+
v_10 = v_9[None]
1774+
v_11 = tl.where(v_5, v_8, v_10)
1775+
subscript = v_11[:, None]
1776+
v_12 = v_3 * subscript
1777+
# src[gdn_fwd_h.py:N]: b_g_last = torch.exp(b_g_last)
1778+
v_13 = libdevice.exp(b_g_last)
1779+
# src[gdn_fwd_h.py:N]: b_h *= b_g_last
1780+
v_14 = v_13[None, None]
1781+
v_15 = b_h_copy_0 * v_14
1782+
# src[gdn_fwd_h.py:N]: b_v = b_v.to(dtype)
1783+
v_16 = tl.cast(v_12, tl.bfloat16)
1784+
# src[gdn_fwd_h.py:N]: p_k = k[tile_b.begin, t_i, tile_h.begin, :]
1785+
p_k = tl.load(k + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
1786+
# src[gdn_fwd_h.py:N]: b_h = hl.dot(p_k.T, b_v, acc=b_h)
1787+
permute = tl.permute(p_k, [1, 0])
1788+
b_h = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(v_16, tl.bfloat16), acc=v_15, input_precision='tf32', out_dtype=tl.float32)
1789+
1790+
def helion_gdn_fwd_h(k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
1791+
"""
1792+
Argument:
1793+
k: (batch, seqlen, nheads, dhead)
1794+
w: (batch, seqlen, nheads, dhead)
1795+
u: (batch, seqlen, nheads, expand_v*dhead)
1796+
g: (batch, seqlen, nheads)
1797+
chunk_size: int
1798+
Return:
1799+
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
1800+
"""
1801+
# src[gdn_fwd_h.py:N]: batch, seqlen, nheads, dhead = k.shape
1802+
batch, seqlen, nheads, dhead = k.shape
1803+
# src[gdn_fwd_h.py:N]: dhead = hl.specialize(dhead)
1804+
dhead = 64
1805+
# src[gdn_fwd_h.py:N]: chunk_size = hl.specialize(chunk_size)
1806+
chunk_size = 256
1807+
# src[gdn_fwd_h.py:N]: dstate = u.shape[-1]
1808+
dstate = u.shape[-1]
1809+
# src[gdn_fwd_h.py:N]: acc_dtype = torch.float32
1810+
acc_dtype = torch.float32
1811+
# src[gdn_fwd_h.py:N]: dtype = k.dtype
1812+
dtype = k.dtype
1813+
# src[gdn_fwd_h.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
1814+
nchunks = (seqlen + chunk_size - 1) // chunk_size
1815+
# src[gdn_fwd_h.py:N]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
1816+
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
1817+
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
1818+
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
1819+
# src[gdn_fwd_h.py:N]: ):
1820+
_BLOCK_SIZE_0 = 32
1821+
_RDIM_SIZE_4 = 64
1822+
# src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
1823+
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
1824+
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
1825+
# src[gdn_fwd_h.py:N-N]: ...
1826+
_BLOCK_SIZE_3 = 256
1827+
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
1828+
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
1829+
# src[gdn_fwd_h.py:N]: ):
1830+
# src[gdn_fwd_h.py:N-N]: ...
1831+
_launcher(_helion_helion_gdn_fwd_h, (8 * 80 * triton.cdiv(128, _BLOCK_SIZE_0),), h, w, u, g, k, _BLOCK_SIZE_0, _RDIM_SIZE_4, _BLOCK_SIZE_3, num_warps=4, num_stages=1)
1832+
# src[gdn_fwd_h.py:N]: return h
1833+
return h
1834+
17111835
--- assertExpectedJournal(TestExamples.test_geglu)
17121836
from __future__ import annotations
17131837

test/test_examples.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,6 +1822,63 @@ def test_grpo_loss_bwd(self):
18221822
)
18231823
)
18241824

1825+
def test_gdn_fwd_h(self):
1826+
"""Test gated delta net forward h kernel."""
1827+
import math
1828+
1829+
batch = 8
1830+
nheads = 80
1831+
seqlen = 4096
1832+
chunk_size = 256
1833+
dhead = 64
1834+
dstate = 128
1835+
1836+
k = torch.randn(
1837+
batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE
1838+
)
1839+
k = torch.nn.functional.rms_norm(k, (dhead,))
1840+
w = torch.randn(
1841+
batch,
1842+
seqlen // chunk_size,
1843+
chunk_size,
1844+
nheads,
1845+
dhead,
1846+
dtype=torch.float32,
1847+
device=DEVICE,
1848+
)
1849+
wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
1850+
w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
1851+
w = (
1852+
w.permute(0, 1, 3, 2, 4)
1853+
.reshape(batch, seqlen, nheads, dhead)
1854+
.to(torch.bfloat16)
1855+
)
1856+
u = torch.randn(
1857+
batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE
1858+
)
1859+
u = torch.nn.functional.rms_norm(u, (dstate,))
1860+
g = torch.cumsum(
1861+
0.5
1862+
* math.log(1 / dhead)
1863+
* torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
1864+
dim=1,
1865+
)
1866+
1867+
args = (k, w, u, g, chunk_size)
1868+
1869+
# Import and use the reference implementation
1870+
mod = import_path(EXAMPLES_DIR / "gdn_fwd_h.py")
1871+
expected = mod.ref_gdn_fwd_h(*args)
1872+
1873+
self.assertExpectedJournal(
1874+
check_example(
1875+
"gdn_fwd_h",
1876+
args,
1877+
expected,
1878+
fn_name="helion_gdn_fwd_h",
1879+
)
1880+
)
1881+
18251882

18261883
if __name__ == "__main__":
18271884
unittest.main()

0 commit comments

Comments
 (0)