|
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """
    Chunked forward state scan: walk the sequence in chunks of ``chunk_size``,
    carrying a per-(batch, head) state of shape (dhead, dstate), and record the
    state *entering* each chunk into ``h``.

    Argument:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """

    batch, seqlen, nheads, dhead = k.shape
    # Specialize so a kernel variant is compiled per concrete (dhead, chunk_size).
    dhead = hl.specialize(dhead)
    chunk_size = hl.specialize(chunk_size)
    dstate = u.shape[-1]  # expand_v * dhead (per the docstring above)

    acc_dtype = torch.float32  # accumulate in fp32; outputs are cast back to input dtype
    dtype = k.dtype

    # Ceil division — the last chunk may be ragged when seqlen % chunk_size != 0.
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
    block_v = hl.register_block_size(dstate)

    # Parallel over (batch, head, value-block); the chunk scan below is sequential.
    for tile_b, tile_h, tile_v in hl.tile(
        [batch, nheads, dstate], block_size=[1, 1, block_v]
    ):
        # Running recurrent state for this (batch, head, value-block) slice.
        b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
        # One iteration per chunk of the sequence.
        # NOTE(review): assumes t_i.id is the chunk ordinal and t_i.index the
        # global sequence positions of this tile — confirm against Helion docs.
        for t_i in hl.tile(seqlen, block_size=chunk_size):
            # Store the state as it stands *before* processing chunk t_i.
            h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
            b_w = w[tile_b.begin, t_i, tile_h.begin, :]
            c_h = b_h.to(dtype)
            b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
            p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
            # Correction term: u - (w @ h_in).
            b_v = p_v - b_v
            # Mask positions past seqlen in a ragged last chunk.
            m_t = t_i.index < seqlen
            # Last valid position of this chunk, clamped to the sequence end.
            t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
            b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
            b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
            # Scale each row by exp(g_last - g_t); zero rows beyond the sequence.
            b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
            # Decay the carried state by the chunk's end gate before accumulating.
            b_g_last = torch.exp(b_g_last)
            b_h *= b_g_last
            b_v = b_v.to(dtype)
            p_k = k[tile_b.begin, t_i, tile_h.begin, :]
            # State update: b_h += k^T @ v (fp32 accumulator via acc=).
            b_h = hl.dot(p_k.T, b_v, acc=b_h)
    return h
78 | 76 |
|
79 | 77 |
|
80 | | -def helion_gdn_fwd_h( |
81 | | - k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int |
82 | | -) -> torch.Tensor: |
83 | | - """ |
84 | | - Argument: |
85 | | - k: (batch, seqlen, nheads, dhead) |
86 | | - w: (batch, seqlen, nheads, dhead) |
87 | | - u: (batch, seqlen, nheads, expand_v*dhead) |
88 | | - g: (batch, seqlen, nheads) |
89 | | - chunk_size: int |
90 | | - Return: |
91 | | - h: (batch, nchunks, nheads, dhead, expand_v*dhead) |
92 | | - """ |
93 | | - |
94 | | - batch, seqlen, nheads, dhead = k.shape |
95 | | - dstate = u.shape[-1] |
96 | | - nchunks = (seqlen + chunk_size - 1) // chunk_size |
97 | | - |
98 | | - k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead) |
99 | | - w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead) |
100 | | - u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate) |
101 | | - g_c = g.reshape(batch, nchunks, chunk_size, nheads) |
102 | | - return helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c) |
103 | | - |
104 | | - |
105 | 78 | def helion_gdn_fwd_h_tb( |
106 | 79 | tb_obj: object, |
107 | 80 | k: torch.Tensor, |
|
0 commit comments