#!POPCORN leaderboard amd-fp8-mm
#!POPCORN gpu MI300

from task import input_t, output_t
import torch
import triton
import triton.language as tl

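# Number of compute units on the current device (e.g. 304 on an MI300X),
# used by the launch heuristic in get_config below.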
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count


@triton.jit
def kernel(
    A_ptr,
    B_ptr,
    A_scale_ptr,
    B_scale_ptr,
    C_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_Q: tl.constexpr = 128,
    TRANSPOSE: tl.constexpr = False,
):
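    # Computes C = A @ B.T with blockwise dequantization:
    # C[m, n] = sum over 128-wide K chunks q of
    #     (sum_{k in chunk q} A[m, k] * B[n, k]) * A_scale[m, q] * B_scale[n // 128, q]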
    program_id = tl.program_id(0)
    num_pid_across_n = tl.cdiv(N, BLOCK_N)

    program_id_m = program_id // num_pid_across_n
    program_id_n = program_id % num_pid_across_n

    if not TRANSPOSE:
        # A (M, K) and B (N, K) are column-major
        A_stride_m, A_stride_k = 1, M
        B_stride_n, B_stride_k = 1, N
    else:
        # A (M, K) and B (N, K) are row-major
        A_stride_m, A_stride_k = K, 1
        B_stride_n, B_stride_k = K, 1
    C_stride_m, C_stride_n = N, 1
    # Scale matrices are stored column-major; A carries one scale per 1x128
    # chunk and B one scale per 128x128 chunk (BLOCK_Q == 128)
    A_scale_stride_m, A_scale_stride_k = 1, M
    B_scale_stride_n, B_scale_stride_k = 1, tl.cdiv(N, BLOCK_Q)
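    # e.g. with M = 1024 and K = 7168, A_scale is (M, K // 128) = (1024, 56)
    # column-major, so scale (m, q) sits at offset m + q * M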

    # Row and column indices in the output matrix for the current pid
    offset_m = program_id_m * BLOCK_M
    offset_n = program_id_n * BLOCK_N

    # Ranges used to build the row and column pointers
    block_offsets_m = offset_m + tl.arange(0, BLOCK_M)
    block_offsets_n = offset_n + tl.arange(0, BLOCK_N)
    block_offsets_k = tl.arange(0, BLOCK_K)

    # ptrs for BLOCK_M rows of A and BLOCK_N columns of B
    A_block_ptrs = A_ptr + (
        block_offsets_m[:, None] * A_stride_m + block_offsets_k[None, :] * A_stride_k
    )
    B_block_ptrs = B_ptr + (
        block_offsets_k[:, None] * B_stride_k + block_offsets_n[None, :] * B_stride_n
    )
    # since a_scales are 1x128, a_scale_ptrs needs to be of shape (BLOCK_M, 1)
    # since BLOCK_N <= BLOCK_Q, each tile sits inside a single 128-wide scale
    # chunk, so b_scale_ptrs is always a scalar ptr
    A_scale_block_ptrs = A_scale_ptr + (block_offsets_m[:, None] * A_scale_stride_m)
    B_scale_block_ptrs = B_scale_ptr + (offset_n // BLOCK_Q) * B_scale_stride_n

    # Initialize accumulator for the current pid (responsible for BLOCK_M * BLOCK_N elements)
    master_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # In each iteration we load BLOCK_Q elements from the K dimension for BLOCK_M rows, resp. BLOCK_N columns
    # We choose this so that only one scale is needed per iteration
    num_k_iters = K // BLOCK_Q
    for _ in range(0, num_k_iters):
        # Initialize accumulator for the current k iteration
        inner_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # In each iteration we load BLOCK_K elements from the K dimension for BLOCK_M rows, resp. BLOCK_N columns
        # We choose this to use small `tl.dot`s for the inner accumulator
        for _ in tl.range(0, BLOCK_Q // BLOCK_K):
            # Loads are unmasked: tiles are assumed to be fully in bounds
            A_block = tl.load(A_block_ptrs)  # (BLOCK_M, BLOCK_K)
            B_block = tl.load(B_block_ptrs)  # (BLOCK_K, BLOCK_N)
            inner_accumulator = tl.dot(
                A_block, B_block, inner_accumulator
            )  # (BLOCK_M, BLOCK_N)

            # Move along the K dimension of A, B
            A_block_ptrs += BLOCK_K * A_stride_k
            B_block_ptrs += BLOCK_K * B_stride_k

        A_scales = tl.load(A_scale_block_ptrs)  # (BLOCK_M, 1)
        B_scales = tl.load(B_scale_block_ptrs)  # ()
        master_accumulator += inner_accumulator * (A_scales * B_scales)

        # Move along the K dimension of the A, B scales
        A_scale_block_ptrs += A_scale_stride_k
        B_scale_block_ptrs += B_scale_stride_k

    # Store the result for the current pid
    block_offsets_m = (
        program_id_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    )  # (BLOCK_M, 1)
    block_offsets_n = (
        program_id_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    )  # (1, BLOCK_N)
    mask = (block_offsets_m < M) & (block_offsets_n < N)  # (BLOCK_M, BLOCK_N)
    C_block_ptrs = C_ptr + (block_offsets_m * C_stride_m + block_offsets_n * C_stride_n)
    tl.store(C_block_ptrs, master_accumulator, mask=mask)


# Compiled helper that returns a contiguous copy (not called in this file)
@torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs")
def contiguous(x):
    return x.contiguous()


def get_config(M, N, K):
    num_blocks_ref = (M // 128) * (N // 128)
    TRANSPOSE = False
    matrix_instr_nonkdim = 16
    BLOCK_M, BLOCK_N, BLOCK_K = (128, 128, 64)
    if num_blocks_ref * 8 < NUM_SMS:  # hits the smallest benchmark shapes (cases 2 and 7)
        BLOCK_M, BLOCK_N, BLOCK_K = (32, 64, 128)
        matrix_instr_nonkdim = 16
    elif num_blocks_ref < NUM_SMS:
        BLOCK_M, BLOCK_N, BLOCK_K = (64, 64, 64)

    config = dict(
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        waves_per_eu=2,
        matrix_instr_nonkdim=matrix_instr_nonkdim,
        num_warps=4,
        num_stages=2,
        TRANSPOSE=TRANSPOSE,
    )
    return config
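
# Worked example of the heuristic above, assuming NUM_SMS == 304 (the CU count
# of an MI300X): for M = N = 1024, num_blocks_ref = 8 * 8 = 64. Then
# 64 * 8 = 512 is not < 304, but 64 < 304, so the (64, 64, 64) tiling is used.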


def custom_kernel(data: input_t) -> output_t:
    A_tensor, B_tensor, A_scale_tensor, B_scale_tensor, C_tensor = data

    M, K = A_tensor.shape
    N, _ = B_tensor.shape

    # Pick block sizes via the shape heuristic above
    config = get_config(M, N, K)

    # One program per output tile, on a 1D grid
    num_blocks = triton.cdiv(M, config["BLOCK_M"]) * triton.cdiv(N, config["BLOCK_N"])
    kernel[(num_blocks,)](
        A_tensor, B_tensor, A_scale_tensor, B_scale_tensor, C_tensor, M, N, K, **config
    )

    return C_tensor
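

# Minimal smoke test: a sketch only, assuming the MI300 fp8 dtype
# (torch.float8_e4m3fnuz) and the column-major layouts the kernel indexes with
# above; in the real harness these tensors arrive prebuilt inside `data`.
if __name__ == "__main__":
    M, N, K = 256, 256, 256
    # A (M, K) and B (N, K), fp8, column-major
    A = torch.randn(K, M, device="cuda").to(torch.float8_e4m3fnuz).T
    B = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fnuz).T
    # fp32 scales, column-major: one per 1x128 chunk of A, one per 128x128 chunk of B
    A_scale = torch.rand(K // 128, M, device="cuda").T
    B_scale = torch.rand(K // 128, N // 128, device="cuda").T
    C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
    print(custom_kernel((A, B, A_scale, B_scale, C)))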