Skip to content

Commit d969da4

Browse files
committed
seqlen support
1 parent 852b2e1 commit d969da4

File tree

1 file changed

+21
-48
lines changed

1 file changed

+21
-48
lines changed

examples/gdn_fwd_h.py

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,83 +25,56 @@
2525
# Helion Kernel Implementation
2626
# ----------------------------
2727
@helion.kernel()
28-
def helion_gdn_fwd_h_kernel(
29-
k_c: torch.Tensor, w_c: torch.Tensor, u_c: torch.Tensor, g_c: torch.Tensor
28+
def helion_gdn_fwd_h(
29+
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
3030
) -> torch.Tensor:
3131
"""
3232
Argument:
33-
k_c: (batch, nchunks, chunk_size, nheads, dhead)
34-
w_c: (batch, nchunks, chunk_size, nheads, dhead)
35-
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
36-
g_c: (batch, nchunks, chunk_size, nheads)
33+
k: (batch, seqlen, nheads, dhead)
34+
w: (batch, seqlen, nheads, dhead)
35+
u: (batch, seqlen, nheads, expand_v*dhead)
36+
g: (batch, seqlen, nheads)
37+
chunk_size: int
3738
Return:
3839
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
3940
"""
4041

41-
batch, nchunks, chunk_size, nheads, dhead = k_c.shape
42+
batch, seqlen, nheads, dhead = k.shape
4243
dhead = hl.specialize(dhead)
4344
chunk_size = hl.specialize(chunk_size)
44-
dstate = u_c.shape[-1]
45+
dstate = u.shape[-1]
4546

4647
acc_dtype = torch.float32
47-
dtype = k_c.dtype
48+
dtype = k.dtype
4849

49-
h = torch.empty(
50-
batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device
51-
)
50+
nchunks = (seqlen + chunk_size - 1) // chunk_size
51+
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
5252
block_v = hl.register_block_size(dstate)
53-
seqlen = chunk_size * nchunks
5453

5554
for tile_b, tile_h, tile_v in hl.tile(
5655
[batch, nheads, dstate], block_size=[1, 1, block_v]
5756
):
5857
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
59-
for i_t in range(nchunks):
60-
h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
61-
b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
58+
for t_i in hl.tile(seqlen, block_size=chunk_size):
59+
h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
60+
b_w = w[tile_b.begin, t_i, tile_h.begin, :]
6261
c_h = b_h.to(dtype)
6362
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
64-
p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
63+
p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
6564
b_v = p_v - b_v
66-
m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
67-
b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(
68-
acc_dtype
69-
)
70-
b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
65+
m_t = t_i.index < seqlen
66+
t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
67+
b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
68+
b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
7169
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
7270
b_g_last = torch.exp(b_g_last)
7371
b_h *= b_g_last
7472
b_v = b_v.to(dtype)
75-
p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
73+
p_k = k[tile_b.begin, t_i, tile_h.begin, :]
7674
b_h = hl.dot(p_k.T, b_v, acc=b_h)
7775
return h
7876

7977

80-
def helion_gdn_fwd_h(
81-
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
82-
) -> torch.Tensor:
83-
"""
84-
Argument:
85-
k: (batch, seqlen, nheads, dhead)
86-
w: (batch, seqlen, nheads, dhead)
87-
u: (batch, seqlen, nheads, expand_v*dhead)
88-
g: (batch, seqlen, nheads)
89-
chunk_size: int
90-
Return:
91-
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
92-
"""
93-
94-
batch, seqlen, nheads, dhead = k.shape
95-
dstate = u.shape[-1]
96-
nchunks = (seqlen + chunk_size - 1) // chunk_size
97-
98-
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
99-
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
100-
u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
101-
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
102-
return helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
103-
104-
10578
def helion_gdn_fwd_h_tb(
10679
tb_obj: object,
10780
k: torch.Tensor,

0 commit comments

Comments
 (0)