diff --git a/problems/nvidia/eval.py b/problems/nvidia/eval.py new file mode 100644 index 0000000..6286f7f --- /dev/null +++ b/problems/nvidia/eval.py @@ -0,0 +1,489 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. 
+ """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. 
+ """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + # try: + # a, b, c = data + # compile_kernel(a, b, c) + # torch.cuda.synchronize() + # except OpError as E: + # return f"Compilation failed: {E}" + # except Exception as E: + # return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. 
+ """ + # Step 1: Compile kernel once (outside of timing) + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + # filename = None + + # with tempfile.NamedTemporaryFile(delete=False) as tmp: + + # def build_test_string(tests: list[dict]): + # as_str = "" + # for test in tests: + # kvs = [] + # for k, v in test.items(): + # kvs.append(f"{k}: {v}") + # as_str += "; ".join(kvs) + "\n" + # return as_str + + # import yaml + # print(sys.argv[2]) + # print(open(sys.argv[2], "r").read()) + + # yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + # if mode == "test": + # tests_str = build_test_string(yaml_content.get("tests", [])) + # elif mode in ("benchmark", "leaderboard", "profile"): + # tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + # tmp.write(tests_str.encode("utf-8")) + # tmp.flush() + # filename = tmp.name + + + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + # compile_success, compile_error = pool.apply(_compile_kernel_once) + # if not compile_success: + # return 112 + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + 
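+                # Note: unlike the plain "benchmark" mode above, the leaderboard loop
+                # below passes recheck=True, so every timed repetition regenerates its
+                # input with a fresh seed and re-validates correctness; the first
+                # failing case stops the remaining benchmarks and marks the run "fail".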
logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py new file mode 100644 index 0000000..4901af6 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -0,0 +1,194 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation, + C = silu(A @ B1) * (A @ B2). + """ + a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data + + # Get dimensions from MxNxL layout + m, n, l = c_ref.shape + + # Call torch._scaled_mm to compute the GEMV result + ref1 = torch.empty( + (l, m, n), + dtype=torch.float32, + device="cuda", + ).permute(1, 2, 0) + ref2 = torch.empty( + (l, m, n), + dtype=torch.float32, + device="cuda", + ).permute(1, 2, 0) + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx]) + scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res1 = torch._scaled_mm( + a_ref[:, :, l_idx], + b1_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b1.cuda(), + bias=None, + out_dtype=torch.float32, + ) + ref1[:, :, l_idx] = res1 + + res2 = torch._scaled_mm( + a_ref[:, :, l_idx], + b2_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b2.cuda(), + bias=None, + out_dtype=torch.float32, + ) + ref2[:, :, l_idx] = res2 + # Do silu on the first GEMM result and multiply with the second GEMM result + c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16) + return c_ref + + +def generate_input( + m: int, + n: int, + k: int, + l: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation, + C = silu(A @ B1) * (A @ B2). 
+ + Args: + m: Number of rows in matrix A + n: Number of columns in matrix B1 and B2 + k: Number of columns in A and rows of B1 and B2 + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b1: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b2: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type + """ + torch.manual_seed(seed) + + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b1_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b2_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b1_ref = b1_ref.view(torch.float4_e2m1fn_x2) + b2_ref = b2_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. The customized data layout can be found in: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. 
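+        # Two copies of the same scale factors are produced below: the (mn, sf_k, l)
+        # reference copy is returned on the CPU and later converted by to_blocked()
+        # for torch._scaled_mm, while the permuted (32, 4, rest_mn, 4, rest_k, l)
+        # copy matches the block-scaling-factor layout expected by the CuTe kernel.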
+ ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k) + sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k) + + return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py new file mode 100644 index 0000000..f733212 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/submission.py @@ -0,0 +1,957 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters +# Tile sizes for M, N, K dimensions +mma_tiler_mnk= (128, 128, 256) +# Shape of the K dimension for the MMA instruction +mma_inst_shape_k = 64 +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale 
factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# GPU device kernel +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b1: cute.CopyAtom, + mB_nkl1: cute.Tensor, + tma_atom_b2: cute.CopyAtom, + mB_nkl2: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb1: cute.CopyAtom, + mSFB_nkl1: cute.Tensor, + tma_atom_sfb2: cute.CopyAtom, + mSFB_nkl2: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + num_tma_load_bytes: cutlass.Constexpr[int], + epilogue_op: cutlass.Constexpr = lambda x: x + * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), +): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx = cute.arch.thread_idx() + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Define shared storage for kernel + # + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB1 = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB2 = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB1 = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB2 = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # + # Initialize mainloop ab_pipeline, acc_pipeline and their states + # + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + 
).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl1 = cute.local_tile( + mB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl2 = cute.local_tile( + mB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + gSFB_nkl1 = cute.local_tile( + mSFB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl2 = cute.local_tile( + mSFB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, RestK) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB1 = thr_mma.partition_B(gB_nkl1) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB2 = thr_mma.partition_B(gB_nkl2) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB1 = thr_mma.partition_B(gSFB_nkl1) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB2 = thr_mma.partition_B(gSFB_nkl2) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA Partition_S/D for A + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA Partition_S/D for B1 + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB1, tBgB1 = cpasync.tma_partition( + tma_atom_b1, + 0, + cute.make_layout(1), + cute.group_modes(sB1, 0, 3), + cute.group_modes(tCgB1, 0, 3), + ) + # TMA Partition_S/D for B2 + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB2, tBgB2 = cpasync.tma_partition( + tma_atom_b2, + 0, + cute.make_layout(1), + cute.group_modes(sB2, 0, 3), + cute.group_modes(tCgB2, 0, 3), + ) + # TMA Partition_S/D for SFA + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA Partition_S/D for SFB1 + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB1, tBgSFB1 = cpasync.tma_partition( + tma_atom_sfb1, + 0, + cute.make_layout(1), + cute.group_modes(sSFB1, 0, 3), + cute.group_modes(tCgSFB1, 0, 3), + ) + tBsSFB1 = 
cute.filter_zeros(tBsSFB1) + tBgSFB1 = cute.filter_zeros(tBgSFB1) + # TMA Partition_S/D for SFB2 + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB2, tBgSFB2 = cpasync.tma_partition( + tma_atom_sfb2, + 0, + cute.make_layout(1), + cute.group_modes(sSFB2, 0, 3), + cute.group_modes(tCgSFB2, 0, 3), + ) + tBsSFB2 = cute.filter_zeros(tBsSFB2) + tBgSFB2 = cute.filter_zeros(tBgSFB2) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB1 = tiled_mma.make_fragment_B(sB1) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB2 = tiled_mma.make_fragment_B(sB2) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Alloc tensor memory buffer + # Make ACC1 and ACC2 tmem tensor + # ACC1 += A @ B1 + # ACC2 += A @ B2 + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc1 = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + acc_tmem_ptr1 = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc1), + dtype=cutlass.Float32, + ) + tCtAcc2 = cute.make_tensor(acc_tmem_ptr1, tCtAcc_fake.layout) + + # + # Make SFA/SFB1/SFB2 tmem tensor + # + # SFA tmem layout: (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2), + dtype=sf_dtype, + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # SFB1, SFB2 tmem layout: (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + # Get SFB1 tmem ptr + sfb_tmem_ptr1 = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + tCtSFB1 = cute.make_tensor(sfb_tmem_ptr1, tCtSFB_layout) + # Get SFB2 tmem ptr + sfb_tmem_ptr2 = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) + + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + tcgen05.find_tmem_tensor_col_offset(tCtSFB1), + dtype=sf_dtype, + ) + tCtSFB2 = cute.make_tensor(sfb_tmem_ptr2, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB1/SFB2 + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + # (MMA, MMA_MN, MMA_K) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + 
tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact = cute.filter_zeros(sSFB1) + # (MMA, MMA_MN, MMA_K) + tCtSFB1_compact = cute.filter_zeros(tCtSFB1) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB1_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB1_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB1_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB1_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB1_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB1_compact) + + # SFB2 S2T copy and partition + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact = cute.filter_zeros(sSFB2) + # (MMA, MMA_MN, MMA_K) + tCtSFB2_compact = cute.filter_zeros(tCtSFB2) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB2_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB2_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB2_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB2_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB2_compact) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB1 = tBgB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB2 = tBgB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB1 = tBgSFB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB2 = tBgSFB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Execute Data copy and Math computation in the k_tile loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMA load A/B1/B2/SFA/SFB1/SFB2 to shared memory + cute.copy( + tma_atom_a, + tAgA[(None, ab_empty.count)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b1, + tBgB1[(None, ab_empty.count)], + tBsB1[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b2, + tBgB2[(None, ab_empty.count)], + tBsB2[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, ab_empty.count)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb1, + tBgSFB1[(None, ab_empty.count)], + tBsSFB1[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb2, + tBgSFB2[(None, ab_empty.count)], + tBsSFB2[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy 
SFA/SFB1/SFB2 to tmem + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB1_compact_s2t_staged = tCsSFB1_compact_s2t[s2t_stage_coord] + tCsSFB2_compact_s2t_staged = tCsSFB2_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB1_compact_s2t_staged, + tCtSFB1_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB2_compact_s2t_staged, + tCtSFB2_compact_s2t, + ) + + # tCtAcc1 += tCrA * tCrSFA * tCrB1 * tCrSFB1 + # tCtAcc2 += tCrA * tCrSFA * tCrB2 * tCrSFB2 + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB1[sf_kblock_coord].iterator, + ) + cute.gemm( + tiled_mma, + tCtAcc1, + tCrA[kblock_coord], + tCrB1[kblock_coord], + tCtAcc1, + ) + + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB2[sf_kblock_coord].iterator, + ) + cute.gemm( + tiled_mma, + tCtAcc2, + tCrA[kblock_coord], + tCrB2[kblock_coord], + tCtAcc2, + ) + + # Enable accumulate on tCtAcc1/tCtAcc2 after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc1) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc1 = thr_copy_t2r.partition_S(tCtAcc1) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc2 = thr_copy_t2r.partition_S(tCtAcc2) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc1 = cute.make_rmem_tensor( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc2 = cute.make_rmem_tensor( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype + ) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] + + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc1, tTR_rAcc1) + cute.copy(tiled_copy_t2r, tTR_tAcc2, tTR_rAcc2) + + # Silu activation on acc1 and multiply with acc2 + acc_vec1 = epilogue_op(tTR_rAcc1.load()) + acc_vec2 = tTR_rAcc2.load() + acc_vec = acc_vec1 * acc_vec2 + + tTR_rC.store(acc_vec.to(c_dtype)) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + # Deallocate TMEM + cute.arch.barrier() + tmem.free(acc_tmem_ptr) + return + + +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b1_ptr: cute.Pointer, + b2_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb1_ptr: cute.Pointer, + sfb2_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, + epilogue_op: cutlass.Constexpr = lambda x: x + * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), +): 
+ """ + Host-side JIT function to prepare tensors and launch GPU kernel. + """ + m, n, k, l = problem_size + + # Setup attributes that depend on gemm inputs + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + b_tensor1 = cute.make_tensor( + b1_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + b_tensor2 = cute.make_tensor( + b2_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) + ) + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor1.shape, sf_vec_size + ) + sfb_tensor1 = cute.make_tensor(sfb1_ptr, sfb_layout) + sfb_tensor2 = cute.make_tensor(sfb2_ptr, sfb_layout) + + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + # B1 and B2 have the same size thus share the same smem layout + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + # SFB1 and SFB2 have the same size thus share the same smem layout + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + a_tensor, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # Setup TMA for B1 + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b1, tma_tensor_b1 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor1, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # Setup TMA for B2 + tma_atom_b2, tma_tensor_b2 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor2, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + ) + # Setup TMA for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged , (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfa_tensor, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + # Setup TMA for SFB1 + 
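+    # SFB1 and SFB2 share the same staged smem layout, so the slice computed below
+    # is reused when building the TMA atom for SFB2 as well.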
sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged , (None, None, None, 0) + ) + tma_atom_sfb1, tma_tensor_sfb1 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor1, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + # Setup TMA for SFB2 + tma_atom_sfb2, tma_tensor_sfb2 = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor2, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk .shape, + internal_type=cutlass.Int16, + ) + + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size * 2 + sfa_copy_size + sfb_copy_size * 2 + ) * atom_thr_size + + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), + cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), + c_tensor.shape[2], + ) + + # Launch the kernel. + kernel( + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for shared input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) - shared by both GEMMs + + # TMA atoms and tensors for first B matrix (B1) + tma_atom_b1, # TMA copy atom defining how to load B1 from global memory + tma_tensor_b1, # Tensor descriptor for B1 matrix (n, k, l) - first GEMM + + # TMA atoms and tensors for second B matrix (B2) + tma_atom_b2, # TMA copy atom defining how to load B2 from global memory + tma_tensor_b2, # Tensor descriptor for B2 matrix (n, k, l) - second GEMM + + # TMA atoms and tensors for scale factor A (shared) + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - shared + + # TMA atoms and tensors for scale factor B1 + tma_atom_sfb1, # TMA copy atom for loading scale factors for B1 + tma_tensor_sfb1, # Tensor descriptor for SFB1 (block scale factors for B1) + + # TMA atoms and tensors for scale factor B2 + tma_atom_sfb2, # TMA copy atom for loading scale factors for B2 + tma_tensor_sfb2, # Tensor descriptor for SFB2 (block scale factors for B2) + + # Output tensor C (stores both C1 and C2 results) + c_tensor, # Output tensor where both GEMM results will be stored (m, n, l) + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B1/B2 (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB1/SFB2 (includes stage dimension) + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) + + # Epilogue operation + epilogue_op, # Epilogue operation to apply to output (e.g., element-wise ops) + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Global cache for compiled kernel +_compiled_kernel_cache = None +# This function is used to compile the kernel 
once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b1_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b2_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb1_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb2_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled dual GEMM kernel with silu activation, + C = silu(A @ B1) * (A @ B2). + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b1: [n, k, l] - Input matrix in float4e2m1fn + b2: [n, k, l] - Input matrix in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb1_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb2_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + c: [m, n, l] - Output vector in float16 + + Returns: + Output tensor c with computed results + """ + a, b1, b2, _, _, _, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data + + # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. 
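+    # compile_kernel() traces my_kernel against placeholder (null) pointers and a
+    # dummy (0, 0, 0, 0) problem size, so the cached artifact is reused for every
+    # problem size; only the pointers and (m, n, k, l) passed below change per call.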
+ compiled_func = compile_kernel() + + # Get dimensions from MxKxL layout + _, k, _ = a.shape + m, n, l = c.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b1_ptr = make_ptr( + ab_dtype, b1.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b2_ptr = make_ptr( + ab_dtype, b2.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb1_ptr = make_ptr( + sf_dtype, sfb1_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb2_ptr = make_ptr( + sf_dtype, sfb2_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (m, n, k, l)) + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/task.py b/problems/nvidia/nvfp4_dual_gemm/task.py new file mode 100644 index 0000000..8facfb0 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/task.py @@ -0,0 +1,11 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + n: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/task.yml b/problems/nvidia/nvfp4_dual_gemm/task.yml new file mode 100644 index 0000000..2bac435 --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/task.yml @@ -0,0 +1,63 @@ +# name: nvfp4-dual-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a block scaled dual matrix-matrix multiplication kernel with silu activation optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b1, b2, sfa, sfb1, sfb2, c) + ``` + where: + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b1` is N x K x L in K-major order in nvfp4(e2m1) + * `b2` is N x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb1` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb2` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x N x L in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. 
+ ``` + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + M N K L time[us] + 128 4096 7168 1 4.505 + 512 4096 7168 1 8.714 + 128 3072 4096 1 1.984 + 512 3072 7168 1 6.535 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + - {"m": 128, "n": 3072, "k": 1536, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 256, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 2048, "l": 1, "seed": 1111} + - {"m": 2304, "n": 4608, "k": 7168, "l": 1, "seed": 1111} + - {"m": 384, "n": 7168, "k": 2304, "l": 1, "seed": 1111} + - {"m": 512, "n": 512, "k": 7168, "l": 1, "seed": 1111} + - {"m": 512, "n": 4096, "k": 512, "l": 1, "seed": 1111} + - {"m": 512, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + +benchmarks: + - {"m": 7168, "n": 128, "k": 16384, "l": 1, "seed": 1111} + - {"m": 4096, "n": 128, "k": 7168, "l": 1, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "l": 1, "seed": 1111} + +ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/template.py b/problems/nvidia/nvfp4_dual_gemm/template.py new file mode 100644 index 0000000..d8985df --- /dev/null +++ b/problems/nvidia/nvfp4_dual_gemm/template.py @@ -0,0 +1,28 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp4 dual gemm with silu activation + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b1: torch.Tensor[float4e2m1fn] of shape [n, k, l], + b2: torch.Tensor[float4e2m1fn] of shape [n, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation + sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation + sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb1_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], + sfb2_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], + c: torch.Tensor[float16] of shape [m, n, l] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [m, n, l] + """ + # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. 
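+    # Semantically, each batch slice must compute c = silu(a @ b1^T) * (a @ b2^T),
+    # where a, b1, b2 hold nvfp4 values with one fp8 scale per 16 elements along k.
+    # A rough dense-math sketch, assuming hypothetical dequantized fp32 copies
+    # A, B1, B2 were available:
+    #     c[:, :, i] = (torch.nn.functional.silu(A @ B1.T) * (A @ B2.T)).half()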
+ a, b1, b2, sfa, sfb1, sfb2, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data + + # Your implementation here + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/eval.py b/problems/nvidia/nvfp4_gemm/eval.py new file mode 100644 index 0000000..072b176 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/eval.py @@ -0,0 +1,437 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. 
+ """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. 
+ + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? 
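+                        # Leaderboard runs stop at the first failing benchmark.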
+ break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py new file mode 100644 index 0000000..51dd750 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/reference.py @@ -0,0 +1,161 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled GEMM. + """ + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data + + # Get dimensions from MxNxL layout + _, _, l = c_ref.shape + + # Call torch._scaled_mm to compute the GEMM result + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx], + b_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, :, l_idx] = res + return c_ref + + +def generate_input( + m: int, + n: int, + k: int, + l: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled GEMM. + + Args: + m: Number of rows in matrix A + n: Number of columns in matrix B + k: Number of columns in A and rows of B + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type + """ + torch.manual_seed(seed) + + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. 
The customized data layout can be found in: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. + ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu, sfb_ref_permuted = create_scale_factor_tensors(l, n, sf_k) + + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_ref_permuted, sfb_ref_permuted, c_ref) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py new file mode 100644 index 0000000..c2f37d9 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/submission.py @@ -0,0 +1,761 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters +# Tile sizes for M, N, K dimensions +mma_tiler_mnk = (128, 128, 256) +# Shape of the K dimension for the MMA instruction +mma_inst_shape_k = 64 +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors 
+sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# The CuTe reference implementation for NVFP4 block-scaled GEMM +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + num_tma_load_bytes: cutlass.Constexpr[int], +): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx = cute.arch.thread_idx() + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Define shared storage for kernel + # + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # + # Initialize mainloop ab_pipeline, acc_pipeline and their states + # + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, 
(None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, RestK) + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA Partition_S/D for A + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA Partition_S/D for B + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + # TMA Partition_S/D for SFA + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA Partition_S/D for SFB + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1), + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Alloc tensor memory buffer + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Make SFA/SFB tmem tensor + # + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), + dtype=sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = 
cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Get SFB tmem ptr + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + # (MMA, MMA_MN, MMA_K) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB_compact = cute.filter_zeros(sSFB) + # (MMA, MMA_MN, MMA_K) + tCtSFB_compact = cute.filter_zeros(tCtSFB) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Execute Data copy and Math computation in the k_tile loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + # Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMA load A/B/SFA/SFB to shared memory + cute.copy( + tma_atom_a, + tAgA[(None, k_tile)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b, + tBgB[(None, k_tile)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, k_tile)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_sfb, + tBgSFB[(None, k_tile)], + tBsSFB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy SFA/SFB from shared memory to 
TMEM + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype + ) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] + + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc) + acc_vec = tTR_rAcc.load().to(c_dtype) + tTR_rC.store(acc_vec) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + + # Deallocate TMEM + cute.arch.barrier() + tmem.free(acc_tmem_ptr) + + return + + +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. 
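+
+    Args:
+        a_ptr, b_ptr: device pointers to the packed fp4 (e2m1) A and B data
+        sfa_ptr, sfb_ptr: device pointers to the fp8 (e4m3fn) scale-factor data
+        c_ptr: device pointer to the fp16 output tensor
+        problem_size: (m, n, k, l) GEMM problem shape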
+ """ + m, n, k, l = problem_size + + # Setup attributes that depend on gemm inputs + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) + ) + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + a_tensor, + a_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # Setup TMA for B + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + b_tensor, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + # Setup TMA for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfa_tensor, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + # Setup TMA for SFB + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + sfb_tensor, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size + 
sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Compute grid size + grid = ( + cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), + cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), + c_tensor.shape[2], + ) + + # Launch the kernel + kernel( + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) + + # TMA atoms and tensors for input matrix B + tma_atom_b, # TMA copy atom defining how to load B from global memory + tma_tensor_b, # Tensor descriptor for B matrix (n, k, l) + + # TMA atoms and tensors for scale factor A + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) + + # TMA atoms and tensors for scale factor B + tma_atom_sfb, # TMA copy atom for loading scale factors for B + tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) + + # Output tensor C + c_tensor, # Output tensor C where result will be stored (m, n, l) + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Global cache for compiled kernel +_compiled_kernel_cache = None +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled GEMM kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. 
+ + Args: + data: Tuple of (a, b, sfa_ref, sfb_ref, sfa_permuted, sfb_permuted, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b: [n, k, l] - Input vector in float4e2m1fn + sfa_ref: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfb_ref: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + c: [m, n, l] - Output vector in float16 + + Returns: + Output tensor c with computed results + """ + a, b, _, _, sfa_permuted, sfb_permuted, c = data + + # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. + compiled_func = compile_kernel() + + # Get dimensions from MxKxL layout + m, k, l = a.shape + n, _, _ = b.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.py b/problems/nvidia/nvfp4_gemm/task.py new file mode 100644 index 0000000..66db735 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/task.py @@ -0,0 +1,11 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + n: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml new file mode 100644 index 0000000..06388bb --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/task.yml @@ -0,0 +1,60 @@ +# name: nvfp4-block-scaled-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a block scaled matrix-matrix multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b, sfa, sfb, c) + ``` + where: + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b` is N x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x N x L in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. 
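+  Concretely, the score is the n-th root of the product of the per-benchmark mean
+  runtimes (smaller is better). A small sketch with placeholder numbers, not part of
+  the harness:
+  ```
+  import math
+
+  times_us = [9.0, 2.4, 1.4]   # illustrative per-benchmark mean runtimes, in microseconds
+  score = math.prod(times_us) ** (1 / len(times_us))   # geometric mean; lower is better
+  ```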
+ For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + M N K L time[us] + 128 7168 16384 1 8.994 + 128 4096 7168 1 2.354 + 128 7168 2048 1 1.333 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + - {"m": 128, "n": 3072, "k": 1536, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 256, "l": 1, "seed": 1111} + - {"m": 256, "n": 7168, "k": 2048, "l": 1, "seed": 1111} + - {"m": 2304, "n": 4608, "k": 7168, "l": 1, "seed": 1111} + - {"m": 384, "n": 7168, "k": 2304, "l": 1, "seed": 1111} + - {"m": 512, "n": 512, "k": 7168, "l": 1, "seed": 1111} + - {"m": 512, "n": 4096, "k": 512, "l": 1, "seed": 1111} + - {"m": 512, "n": 1536, "k": 7168, "l": 1, "seed": 1111} + +benchmarks: + - {"m": 7168, "n": 128, "k": 16384, "l": 1, "seed": 1111} + - {"m": 4096, "n": 128, "k": 7168, "l": 1, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "l": 1, "seed": 1111} + +ranking_by: "geom" \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py new file mode 100644 index 0000000..3855d69 --- /dev/null +++ b/problems/nvidia/nvfp4_gemm/template.py @@ -0,0 +1,25 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp4 gemm + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b: torch.Tensor[float4e2m1fn] of shape [n, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], + sfb: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], + c: torch.Tensor[float16] of shape [m, n, l] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [m, n, l] + """ + # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. + a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data + + # Your implementation here + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py new file mode 100644 index 0000000..b01ef18 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -0,0 +1,166 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled GEMV. 
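+
+    The b operand's N dimension is padded to 128 so torch._scaled_mm can be used;
+    only column 0 of each [m, 128] product is written back into c_ref.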
+ """ + a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data + + # Get dimensions from MxNxL layout + _, _, l = c_ref.shape + + # Call torch._scaled_mm to compute the GEMV result + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) + scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx], + b_ref[:, :, l_idx].transpose(0, 1), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, 0, l_idx] = res[:, 0] + return c_ref + + +def generate_input( + m: int, + k: int, + l: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled GEMV. + + Args: + m: Number of rows in matrix A + k: Number of columns in A (and length of vector b) + l: Batch size + seed: Random seed for reproducibility + + Returns: + Tuple of (a, b, scale_a, scale_b, c) where: + a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [1, k, l] - Input vector in torch.float4e2m1fn_x2 data type + scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b: [1, k, l] - Input scale factors in torch.float8e4m3fn data type + scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, 1, l] - Output vector in torch.float16 data type + """ + torch.manual_seed(seed) + + # GEMV N dimension is always 1 + n = 1 + # Scaling factor needs to pad the N size to 128 + n_padded_128 = 128 + + # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + # Pad b tensor's N dimension to 128 to call torch._scaled_mm for nvfp4 dot product computation + b_ref = torch.randint( + 0, 2, (l, n_padded_128, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + # Create float16 output tensor + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + # Helper function to prepare the scale factor tensors for both reference + # kernel and customize kernel. The customized data layout can be found in: + # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + def create_scale_factor_tensors(l, mn, sf_k): + # Create the reference scale factor tensor (mn, sf_k, l) on CPU. 
+ ref_shape = (l, mn, sf_k) + ref_permute_order = (1, 2, 0) + # Init with uint8 tensor, then convert to float8_e4m3fn + ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) + # permute to match ref_permute_order + ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout + # Which is needed by the CuTe customized kernel + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] + + return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu, sfa_permuted = create_scale_factor_tensors(l, m, sf_k) + sfb_ref_cpu, sfb_permuted = create_scale_factor_tensors(l, n_padded_128, sf_k) + + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_permuted, sfb_permuted, c_ref) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py new file mode 100644 index 0000000..9cf2394 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -0,0 +1,266 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_ptr +import cutlass.utils.blockscaled_layout as blockscaled_utils + +# Kernel configuration parameters +mma_tiler_mnk = (128, 1, 64) # Tile sizes for M, N, K dimensions +ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B +sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors +c_dtype = cutlass.Float16 # FP16 output type +sf_vec_size = 16 # Scale factor block size (16 elements share one scale) +threads_per_cta = 128 # Number of threads per CUDA thread block + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# The CuTe reference implementation for NVFP4 block-scaled GEMV +@cute.kernel +def kernel( + mA_mkl: cute.Tensor, + mB_nkl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, 
+): + # Get CUDA block and thread indices + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L]) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # Extract the local tile for scale factor tensor for A (same shape as gA_mkl) + # Here, block_M = (32, 4); block_K = (16, 4) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L]) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # Extract the local tile for scale factor tensor for B (same shape as gB_nkl) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L]) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + + # Select output element corresponding to this thread and block indices + tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] + tCgC = cute.make_tensor(tCgC.iterator, 1) + res = cute.zeros_like(tCgC, cutlass.Float32) + + # Get the number of k tiles (depth dimension) for the reduction loop + k_tile_cnt = gA_mkl.layout[3].shape + for k_tile in range(k_tile_cnt): + tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] + tBgB = gB_nkl[0, None, bidy, k_tile, bidz] + tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] + tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz] + + tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float32) + tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float32) + tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32) + tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32) + + # Load NVFP4 or FP8 values from global memory + a_val_nvfp4 = tAgA.load() + b_val_nvfp4 = tBgB.load() + sfa_val_fp8 = tAgSFA.load() + sfb_val_fp8 = tBgSFB.load() + + # Convert loaded values to float32 for computation (FFMA) + a_val = a_val_nvfp4.to(cutlass.Float32) + b_val = b_val_nvfp4.to(cutlass.Float32) + sfa_val = sfa_val_fp8.to(cutlass.Float32) + sfb_val = sfb_val_fp8.to(cutlass.Float32) + + # Store the converted values to RMEM CuTe tensors + tArA.store(a_val) + tBrB.store(b_val) + tArSFA.store(sfa_val) + tBrSFB.store(sfb_val) + + # Iterate over SF vector tiles and compute the scale&matmul accumulation + for i in cutlass.range_constexpr(mma_tiler_mnk[2]): + res += tArA[i] * tArSFA[i] * tBrB[i] * tBrSFB[i] + + # Store the final float16 result back to global memory + tCgC.store(res.to(cutlass.Float16)) + return + + +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. + """ + m, _, k, l = problem_size + # Create CuTe Tensor via pointer and problem size. + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + # We use n=128 to create the torch tensor to do fp4 computation via torch._scaled_mm + # then copy torch tensor to cute tensor for cute customize kernel computation + # therefore we need to ensure b_tensor has the right stride with this 128 padded size on n. 
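+    # Only index 0 of that padded N dimension is ever read by the device kernel; the
+    # extra rows exist purely so the torch._scaled_mm reference path accepts b.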
+ n_padded_128 = 128 + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n_padded_128, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n_padded_128 * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), 1, l), stride=(1, 1, m)) + ) + # Convert scale factor tensors to MMA layout + # The layout matches Tensor Core requirements: (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + # Compute grid dimensions + # Grid is (M_blocks, 1, L) where: + # - M_blocks = ceil(M / 128) to cover all output rows + # - L = batch size + grid = ( + cute.ceil_div(c_tensor.shape[0], 128), + 1, + c_tensor.shape[2], + ) + + # Launch the CUDA kernel + kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Global cache for compiled kernel +_compiled_kernel_cache = None +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled GEMV kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b: [1, k, l] - Input vector in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn + sfb_cpu: [1, k, l] - Scale factors in float8_e4m3fn + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + c: [m, 1, l] - Output vector in float16 + + Returns: + Output tensor c with computed GEMV results + """ + a, b, _, _, sfa_permuted, sfb_permuted, c = data + + # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. 
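+    # compile_kernel() compiles my_kernel with placeholder pointers (address 0) and a
+    # (0, 0, 0, 0) problem size; the real device pointers and (m, n, k, l) are passed
+    # to the compiled function on every call below.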
+ compiled_func = compile_kernel() + + # Get dimensions from MxKxL layout + m, k, l = a.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + # GEMV N dimension is always 1 + n = 1 + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/task.py b/problems/nvidia/nvfp4_gemv/task.py new file mode 100644 index 0000000..c1a06e3 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/task.py @@ -0,0 +1,10 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + k: int + l: int + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml new file mode 100644 index 0000000..756fe80 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -0,0 +1,59 @@ +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a batched matrix-vector multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b, sfa, sfb, c) + ``` + where: + * `a` is M x K x L in K-major order in nvfp4(e2m1) + * `b` is 1 x K x L in K-major order in nvfp4(e2m1) + * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `sfb` is 1 x (K // 16) x L in K-major order in fp8(e4m3fnuz) + * `c` is M x 1 x L in fp16 + + Matrix sizes `M` is divisible by mma_tiler_mn[0] defined in the kernel, `K` is divisible by 64. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. 
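+  Numerically, every group of 16 consecutive fp4 values along `K` shares one fp8
+  scale factor, and the output is the scaled dot product reduced over `K`, written
+  out in fp16. A rough sketch of that math on float32 copies of the inputs (the
+  function and `*_f32` names are illustrative, not part of the harness):
+  ```
+  import torch
+
+  def naive_block_scaled_gemv(a_f32, b_f32, sfa_f32, sfb_f32):
+      # a_f32: [m, k, l], b_f32: [1, k, l], sfa_f32: [m, k // 16, l], sfb_f32: [1, k // 16, l]
+      sa = sfa_f32.repeat_interleave(16, dim=1)      # broadcast each scale over its 16 values
+      sb = sfb_f32.repeat_interleave(16, dim=1)
+      acc = (a_f32 * sa * b_f32 * sb).sum(dim=1)     # reduce over K -> [m, l]
+      return acc.to(torch.float16).unsqueeze(1)      # [m, 1, l]
+  ```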
+ ``` + The speed of light analysis based on the max(FFMA math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + M K L time[us] + 7168 16384 1 8.622 + 4096 7168 8 17.275 + 7168 2048 4 4.317 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "k": 256, "l": 1, "seed": 1111} + - {"m": 128, "k": 1536, "l": 1, "seed": 1111} + - {"m": 128, "k": 3072, "l": 1, "seed": 1111} + - {"m": 256, "k": 7168, "l": 1, "seed": 1111} + - {"m": 256, "k": 7168, "l": 1, "seed": 1111} + - {"m": 2432, "k": 4608, "l": 2, "seed": 1111} + - {"m": 384, "k": 7168, "l": 2, "seed": 1111} + - {"m": 512, "k": 512, "l": 2, "seed": 1111} + - {"m": 512, "k": 4096, "l": 2, "seed": 1111} + - {"m": 512, "k": 1536, "l": 2, "seed": 1111} + + +benchmarks: + - {"m": 7168, "k": 16384, "l":1, "seed": 1111} + - {"m": 4096, "k": 7168, "l":8, "seed": 1111} + - {"m": 7168, "k": 2048, "l":4, "seed": 1111} + +ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_gemv/template.py b/problems/nvidia/nvfp4_gemv/template.py new file mode 100644 index 0000000..acb8228 --- /dev/null +++ b/problems/nvidia/nvfp4_gemv/template.py @@ -0,0 +1,25 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp8 gemv + Args: + data: Tuple that expands to: + a: torch.Tensor[float4e2m1fn] of shape [m, k, l], + b: torch.Tensor[float4e2m1fn] of shape [1, k, l], + sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation + sfb: torch.Tensor[float8_e4m3fnuz] of shape [1, k // 16, l], used by reference implementation + sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], + sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], + c: torch.Tensor[float16] of shape [m, 1, l] + Returns: + Tensor containing output in float16 + c: torch.Tensor[float16] of shape [m, 1, l] + """ + # c: [l, m, 1] is pre-allocated memory to avoid timing allocation overhead. + a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data + + # Your implementation here + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py new file mode 100644 index 0000000..f71da00 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/reference.py @@ -0,0 +1,186 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + +# Scaling factor vector size +sf_vec_size = 16 + +# Helper function for ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Helper function to convert scale factor tensor to blocked format +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +def ref_kernel( + data: input_t, +) -> output_t: + """ + PyTorch reference implementation of NVFP4 block-scaled group GEMM. 
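+
+    For each group and each batch slice, the scale factors are converted to the
+    blocked layout expected by torch._scaled_mm, which applies them during the
+    fp4 matmul and writes the fp16 result into that group's pre-allocated c.
+
+    Args:
+        data: Tuple of (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors,
+            problem_sizes); the reordered scale factors are not used here.
+
+    Returns:
+        List containing one c tensor of shape [m, n, l] (float16) per group.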
+ """ + abc_tensors, sfasfb_tensors, _, problem_sizes = data + + result_tensors = [] + for i, ( + (a_ref, b_ref, c_ref), + (sfa_ref, sfb_ref), + (m, n, k, l), + ) in enumerate( + zip( + abc_tensors, + sfasfb_tensors, + problem_sizes, + ) + ): + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa_ref[:, :, l_idx]) + scale_b = to_blocked(sfb_ref[:, :, l_idx]) + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_ref[:, :, l_idx].view(torch.float4_e2m1fn_x2), + b_ref[:, :, l_idx].transpose(0, 1).view(torch.float4_e2m1fn_x2), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=torch.float16, + ) + c_ref[:, :, l_idx] = res + result_tensors.append((c_ref)) + return result_tensors + + +# Helper function to prepare the scale factor tensors for both reference +# kernel and customize kernel. The customized data layout can be found in: +# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout +def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on GPU. + mma_permute_order = (3, 4, 1, 5, 2, 0) + # Generate a random int8 tensor, then convert to float8_e4m3fn + rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda') + reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) + # Permute according to mma_permute_order + reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) + + # Move ref_f8_tensor to GPU if not already there + if ref_f8_tensor.device.type == 'cpu': + ref_f8_tensor = ref_f8_tensor.cuda() + + # GPU-side vectorized reordering (replaces slow CPU nested loops) + # Create index grids for all dimensions + i_idx = torch.arange(mn, device='cuda') + j_idx = torch.arange(sf_k, device='cuda') + b_idx = torch.arange(l, device='cuda') + + # Create meshgrid for all combinations of (i, j, b) + i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') + + # Calculate target indices in vectorized manner + mm = i_grid // (atom_m[0] * atom_m[1]) + mm32 = i_grid % atom_m[0] + mm4 = (i_grid % 128) // atom_m[0] + kk = j_grid // atom_k + kk4 = j_grid % atom_k + + # Perform the reordering with advanced indexing (all on GPU) + reordered_f8_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_tensor[i_grid, j_grid, b_grid] + + return reordered_f8_tensor + + +def generate_input( + m: int, + n: int, + k: int, + g: int, + seed: int, +): + """ + Generate input tensors for NVFP4 block-scaled group GEMM. + Each group can have different m, n, k, l. + + Args: + problem_sizes: List of tuples (m, n, k, l) for each problem + m: Number of rows in matrix A + n: Number of columns in matrix B + k: Number of columns in A and rows of B + l: Batch size, always is 1 + groups: Number of groups + seed: Random seed for reproducibility + + Returns: + Tuple of (list(tuple(a, b, c)), list(tuple(sfa, sfb)), list(tuple(sfa_reordered, sfb_reordered)), list(tuple(m, n, k, l))) where each group has its own a, b, c, sfa, sfb. 
+ a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type + sfa: [m, k // 16, l] - Input scale factors in torch.float8e4m3fn data type + sfb: [n, k // 16, l] - Input scale factors in torch.float8e4m3fn data type + sfa_reordered: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + sfb_reordered: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type + c: [m, n, l] - Output matrix in torch.float16 data type + """ + torch.manual_seed(seed) + + abc_tensors = [] + sfasfb_tensors = [] + sfasfb_reordered_tensors = [] + problem_sizes = [] + l = 1 + # Generate a, b, c, sfa, sfb tensors for all groups + for group_idx in range(g): + a_ref = torch.randint( + 0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + b_ref = torch.randint( + 0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + a_ref = a_ref.view(torch.float4_e2m1fn_x2) + b_ref = b_ref.view(torch.float4_e2m1fn_x2) + + c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( + 1, 2, 0 + ) + + sf_k = ceil_div(k, sf_vec_size) + sfa_ref_cpu = torch.randint( + 1, 3, (l, m, sf_k), dtype=torch.int8 + ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) + sfb_ref_cpu = torch.randint( + 1, 3, (l, n, sf_k), dtype=torch.int8 + ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) + + sfa_reordered = create_reordered_scale_factor_tensor(l, m, k, sfa_ref_cpu) + sfb_reordered = create_reordered_scale_factor_tensor(l, n, k, sfb_ref_cpu) + + abc_tensors.append((a_ref, b_ref, c_ref)) + sfasfb_tensors.append((sfa_ref_cpu, sfb_ref_cpu)) + sfasfb_reordered_tensors.append((sfa_reordered, sfb_reordered)) + problem_sizes.append((m, n, k, l)) + + return (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_group_gemm/submission.py b/problems/nvidia/nvfp4_group_gemm/submission.py new file mode 100644 index 0000000..c9ab35f --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/submission.py @@ -0,0 +1,1074 @@ +from torch._higher_order_ops.torchbind import call_torchbind_fake +import cuda.bindings.driver as cuda +import functools +from typing import Tuple, List + +import torch +from task import input_t, output_t + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +# Kernel configuration parameters +# Size of tma descriptor in bytes +bytes_per_tensormap = 128 +# Number of tensormaps: a, b, sfa, sfb +num_tensormaps = 4 +# Tile sizes for M, N, K dimensions +mma_tiler_mnk = (128, 128, 256) +# Shape of the K dimension for the MMA instruction +mma_inst_shape_k = 64 +# FP4 data type for A and B +ab_dtype = cutlass.Float4E2M1FN +# FP8 data type for scale factors +sf_dtype = cutlass.Float8E4M3FN +# FP16 output type +c_dtype = cutlass.Float16 +# Scale factor block size (16 elements share one scale) +sf_vec_size = 16 +# Number of threads per CUDA thread block +threads_per_cta = 128 +# Stage numbers of shared memory and tmem +num_acc_stage = 1 +num_ab_stage = 1 +# Total number of columns in tmem +num_tmem_alloc_cols = 512 + + +# Helper function for 
ceiling division +def ceil_div(a, b): + return (a + b - 1) // b + + +# The CuTe reference implementation for NVFP4 block-scaled GEMM +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tensor_of_abc_ptrs: cute.Tensor, + tensor_of_sfasfb_ptrs: cute.Tensor, + tensormaps: cute.Tensor, + tensor_of_problem_sizes: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + cta_mn_list: List[Tuple[int, int]], + num_tma_load_bytes: cutlass.Constexpr[int], +): + """ + GPU device kernel performing the Group GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx, _, _ = cute.arch.thread_idx() + + # + # Delinearize bidz to coord_x, coord_y and group_idx for each CTA + # + bidx, bidy, bidz = cute.arch.block_idx() + group_idx = 0 + find = False + coord_x = 0 + coord_y = 0 + cta_rest = bidz + for _, (cta_m, cta_n) in enumerate(cta_mn_list): + if cta_rest >= (cta_m * cta_n): + group_idx += 1 + cta_rest -= cta_m * cta_n + else: + if not find: + coord_y = cta_rest // cta_m + coord_x = cta_rest % cta_m + cta_rest -= cta_m * cta_n + find = True + + # + # Construct C Tensor for each CTA + # + mC_mnl_iter = cute.make_ptr( + c_dtype, tensor_of_abc_ptrs[group_idx, 2], cute.AddressSpace.gmem + ).align(32) + m = tensor_of_problem_sizes[group_idx, 0] + n = tensor_of_problem_sizes[group_idx, 1] + k = tensor_of_problem_sizes[group_idx, 2] + l = tensor_of_problem_sizes[group_idx, 3] + + mC_mnl_layout = cute.make_layout( + ( + m, + n, + l, + ), + stride=( + cute.assume(n, 32), + 1, + m * n, + ), + ) + mC_mnl = cute.make_tensor(mC_mnl_iter, mC_mnl_layout) + # Local partition for global C Tensor + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, 0) + ) + + # + # Define shared storage for kernel + # + size_tensormap_in_i64 = ( + num_tensormaps * bytes_per_tensormap // 8 + ) + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, size_tensormap_in_i64 + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + + bytes_per_tensormap // 8 + ) + tensormap_sfa_smem_ptr = ( + tensormap_b_smem_ptr + + bytes_per_tensormap // 8 + ) + tensormap_sfb_smem_ptr = ( + tensormap_sfa_smem_ptr + + bytes_per_tensormap // 8 + ) + # Setup smem tensor for A, B, SFA, SFB + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=ab_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=ab_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + # (MMA, 
MMA_N, MMA_K, STAGE) + sSFB = smem.allocate_tensor( + element_type=sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + # Initialize mainloop ab_pipeline, acc_pipeline and their states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=num_tma_load_bytes, + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + threads_per_cta, + ), + ).make_participants() + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # Update tma descriptor with the correct shapes and strides + tensormap_manager = utils.TensorMapManager( + utils.TensorMapUpdateMode.SMEM, + 128, + ) + tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 0, None)].iterator + ) + tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 1, None)].iterator + ) + tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 2, None)].iterator + ) + tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(bidz, 3, None)].iterator + ) + + mA_mkl_iter = cute.make_ptr( + ab_dtype, tensor_of_abc_ptrs[group_idx, 0], cute.AddressSpace.gmem + ).align(32) + mB_nkl_iter = cute.make_ptr( + ab_dtype, tensor_of_abc_ptrs[group_idx, 1], cute.AddressSpace.gmem + ).align(32) + sfa_mkl_iter = cute.make_ptr( + sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 0], cute.AddressSpace.gmem + ).align(32) + sfb_nkl_iter = cute.make_ptr( + sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 1], cute.AddressSpace.gmem + ).align(32) + mA_mkl_layout = cute.make_layout( + (m, k, l), + stride=( + cute.assume(k, 32), + 1, + cute.assume(m * k, 32), + ), + ) + mB_nkl_layout = cute.make_layout( + (n, k, l), + stride=( + cute.assume(k, 32), + 1, + cute.assume(n * k, 32), + ), + ) + # SFA, SFB follows specialized layout defined in the following link: + # 
https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout + atom_shape = ((32, 4), (sf_vec_size, 4)) + atom_stride = ((16, 4), (0, 1)) + sfa_layout = cute.tile_to_shape( + cute.make_layout(atom_shape, stride=atom_stride), + mA_mkl_layout.shape, + (2, 1, 3), + ) + sfb_layout = cute.tile_to_shape( + cute.make_layout(atom_shape, stride=atom_stride), + mB_nkl_layout.shape, + (2, 1, 3), + ) + real_tensor_a = cute.make_tensor(mA_mkl_iter, mA_mkl_layout) + real_tensor_b = cute.make_tensor(mB_nkl_iter, mB_nkl_layout) + real_tensor_sfa = cute.make_tensor(sfa_mkl_iter, sfa_layout) + real_tensor_sfb = cute.make_tensor(sfb_nkl_iter, sfb_layout) + + # Let warp 0 initialize tensormap + if warp_idx == 0: + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfa, tensormap_sfa_smem_ptr, 0 + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfb, tensormap_sfb_smem_ptr, 0 + ) + tensormap_manager.update_tensormap( + ( + real_tensor_a, + real_tensor_b, + real_tensor_sfa, + real_tensor_sfb, + ), + (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), + ( + tensormap_a_gmem_ptr, + tensormap_b_gmem_ptr, + tensormap_sfa_gmem_ptr, + tensormap_sfb_gmem_ptr, + ), + 0, # tma warp id + ( + tensormap_a_smem_ptr, + tensormap_b_smem_ptr, + tensormap_sfa_smem_ptr, + tensormap_sfb_smem_ptr, + ), + ) + + tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) + + cute.arch.barrier() + + # + # Partition global/shared tensor for TMA load A/B/SFA/SFB + # + # TMA load A partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMALDG_SFA partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1), + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMALDG_SFB partition_S/D + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1), + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + # + # Alloc tensor memory buffer + # + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = 
utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + tmem.allocate(num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Make SFA/SFB tmem tensor + # + # Get SFA tmem ptr + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), + dtype=sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Get SFB tmem ptr + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # + # Partition for S2T copy of SFA/SFB + # + # Make S2T CopyAtom + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), + sf_dtype, + ) + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFA_compact = cute.filter_zeros(sSFA) + tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSFB_compact = cute.filter_zeros(sSFB) + # (MMA, MMA_MN, MMA_K) + tCtSFB_compact = cute.filter_zeros(tCtSFB) + tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) + thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) + + # Number of K loops + k_tile_cnt = cute.ceil_div(real_tensor_a.shape[1], mma_tiler_mnk[2]) + + # + # Slice to per mma tile index + # + mma_tile_coord_mnl = (coord_x, coord_y, 0) + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # + # Main loop + # + if warp_idx == 0: + # Wait for accumulator buffer empty + acc_empty = acc_producer.acquire_and_advance() + # Set ACCUMULATE field to False for the first k_tile iteration + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Execute k_tile loop + for k_tile in range(k_tile_cnt): + 
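            # Each iteration: acquire an empty A/B smem stage, TMA-load the A, B,
+            # SFA and SFB tiles for this k_tile, wait for the stage to fill, copy
+            # the scale factors from smem into tmem, then issue the block-scaled
+            # MMAs over every k-block of the tile before releasing the stage.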
# Wait for AB buffer empty + ab_empty = ab_producer.acquire_and_advance() + + # TMALDG A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA[(None, k_tile)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB[(None, k_tile)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfa, + tAgSFA[(None, k_tile)], + tAsSFA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfa_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfb, + tBgSFB[(None, k_tile)], + tBsSFB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfb_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + + # Wait for AB buffer full + ab_full = ab_consumer.wait_and_advance() + + # Copy SFA/SFB to tmem + s2t_stage_coord = (None, None, None, None, ab_full.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_full.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_full.release() + acc_empty.commit() + + # + # Epilogue + # Partition for epilogue + # + op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) + copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[None, None, None, None, 0, 0].shape, cutlass.Float32 + ) + # (T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rC = cute.make_fragment(tTR_gC[None, None, None, None, 0, 0].shape, c_dtype) + # STG Atom + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) + tTR_gC = tTR_gC[(None, None, None, None, coord_x, coord_y)] + + # Release TMEM allocation lock + tmem.relinquish_alloc_permit() + + # Wait for accumulator buffer full + acc_full = acc_consumer.wait_and_advance() + + # Copy accumulator to register + cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc) + acc_vec = tTR_rAcc.load() + tTR_rC.store(acc_vec.to(c_dtype)) + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC) + + acc_full.release() + + # Deallocate TMEM + cute.arch.barrier() + 
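    # The barrier above guarantees that every thread has finished reading the
+    # accumulator out of tensor memory before it is deallocated below.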
tmem.free(acc_tmem_ptr) + pass + + +# Host-side JIT function to prepare tensors and launch GPU kernel. +@cute.jit +def my_kernel( + initial_abc_ptrs: Tuple[cute.Pointer, cute.Pointer, cute.Pointer], + initial_sfasfb_ptrs: Tuple[cute.Pointer, cute.Pointer], + initial_idx: Tuple[int, int, int], + ptr_of_tensor_of_problem_sizes: cute.Pointer, + ptr_of_tensor_of_abc_ptrs: cute.Pointer, + ptr_of_tensor_of_sfasfb_ptrs: cute.Pointer, + total_num_clusters: cutlass.Constexpr[int], + problem_sizes: List[ + Tuple[int, int, int, int] + ], # Problem sizes for each group + tensor_of_tensormap, + num_groups: cutlass.Constexpr[int], +): + + tensor_of_abc_ptrs = cute.make_tensor( + ptr_of_tensor_of_abc_ptrs, cute.make_layout((num_groups, 3), stride=(3, 1)) + ) + tensor_of_sfasfb_ptrs = cute.make_tensor( + ptr_of_tensor_of_sfasfb_ptrs, cute.make_layout((num_groups, 2), stride=(2, 1)) + ) + tensor_of_problem_sizes = cute.make_tensor( + ptr_of_tensor_of_problem_sizes, cute.make_layout((num_groups, 4), stride=(4, 1)) + ) + + a_ptr, b_ptr, _ = initial_abc_ptrs + sfa_ptr, sfb_ptr = initial_sfasfb_ptrs + min_a_idx, min_b_idx, _ = initial_idx + min_a_shape = problem_sizes[0] + min_b_shape = problem_sizes[0] + for group_idx, shape in enumerate(problem_sizes): + if group_idx == min_a_idx: + min_a_shape = shape + if group_idx == min_b_idx: + min_b_shape = shape + + initial_a = cute.make_tensor( + a_ptr, + cute.make_layout( + (min_a_shape[0], cute.assume(min_a_shape[2], 32), min_a_shape[3]), + stride=( + cute.assume(min_a_shape[2], 32), + 1, + cute.assume(min_a_shape[0] * min_a_shape[2], 32), + ), + ), + ) + min_b_shape = problem_sizes[0] + initial_b = cute.make_tensor( + b_ptr, + cute.make_layout( + (min_b_shape[1], cute.assume(min_b_shape[2], 32), min_b_shape[3]), + stride=( + cute.assume(min_b_shape[2], 32), + 1, + cute.assume(min_b_shape[1] * min_b_shape[2], 32), + ), + ), + ) + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_a.shape, sf_vec_size + ) + initial_sfa = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_b.shape, sf_vec_size + ) + initial_sfb = cute.make_tensor(sfb_ptr, sfb_layout) + + # Select MMA operation + mma_op = tcgen05.MmaMXF4NVF4Op( + sf_dtype, + (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute A/B/SFA/SFB/C shared memory layout + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + ab_dtype, + num_ab_stage, + ) + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + num_ab_stage, + ) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # TMA load for A + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_a, + a_smem_layout, + mma_tiler_mnk, 
+ tiled_mma, + cluster_layout_vmnk.shape, + ) + # TMA load for B + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_b, + b_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + ) + + # TMA load for SFA + sfa_smem_layout = cute.slice_( + sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_sfa, + sfa_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # TMA load for SFB + sfb_smem_layout = cute.slice_( + sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), + initial_sfb, + sfb_smem_layout, + mma_tiler_mnk, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + # Compute TMA load bytes + a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) + num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Store CTA shape information for each Group in a List + cta_mn_list = [] + for group_idx, (m, n, k, l) in enumerate(problem_sizes): + x, y = cute.ceil_div(problem_sizes[group_idx][:2], mma_tiler_mnk[0:2]) + cta_mn_list.append((x, y)) + + # Compute grid size + grid = (1, 1, total_num_clusters) + + # Launch the kernel + kernel( + # MMA (Matrix Multiply-Accumulate) configuration + tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern + + # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A + tma_atom_a, # TMA copy atom defining how to load A from global memory + tma_tensor_a, # Tensor descriptor for A (created from smallest A tensor) + + # TMA atoms and tensors for input matrix B + tma_atom_b, # TMA copy atom defining how to load B from global memory + tma_tensor_b, # Tensor descriptor for B (created from smallest B tensor) + + # TMA atoms and tensors for scale factor A + tma_atom_sfa, # TMA copy atom for loading scale factors for A + tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) + + # TMA atoms and tensors for scale factor B + tma_atom_sfb, # TMA copy atom for loading scale factors for B + tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) + + # Runtime tensor metadata for dynamic group access + tensor_of_abc_ptrs, # Device tensor containing pointers to A, B, C for all groups + tensor_of_sfasfb_ptrs, # Device tensor containing pointers to SFA, SFB for all groups + tensor_of_tensormap, # Pre-allocated buffer for tensormap descriptors per CTA + tensor_of_problem_sizes, # Device tensor containing (m, n, k, l) for each group + + # Shared memory layouts with staging for pipelined execution + a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) + b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) + sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) + sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) + + # CTA grid configuration per group + cta_mn_list, 
# List of (M_tiles, N_tiles) for each group + + # Pipeline synchronization parameter + num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) + ).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +_compiled_kernel_cache = None + +def compile_kernel(): + pass + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the block-scaled group GEMM kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (abc_tensors, sfasfb_tensors, problem_sizes) where: + abc_tensors: list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] + b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] + c is torch.Tensor[float16] of shape [m, n, l] + sfasfb_tensors: list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] + problem_sizes: list of tuples (m, n, k, l) + each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes + l should always be 1 for each group. + list size is the number of groups. + + Returns: + list of c tensors where c is torch.Tensor[float16] of shape [m, n, l] for each group + """ + abc_tensors, _, sfasfb_reordered_tensors, problem_sizes = data + + # Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes), key=key_size_c) + + # Extract raw data pointers from all input tensors for each group + # These will be passed to the GPU kernel to access the actual tensor data + abc_ptrs = [] + sfasfb_ptrs = [] + for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): + # Store pointers to A, B, and C matrices for this group + abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr())) + # Store pointers to scale factor tensors for this group + sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr())) + + # Create initial CuTe pointers from the smallest tensors for tensormap initialization + # Using smallest tensors helps with efficient TMA (Tensor Memory Accelerator) setup + # These will be used as templates to create tensormaps for all other tensors + initial_cute_abc_ptrs = ( + # Pointer to the smallest A matrix (FP4 type) + make_ptr( + ab_dtype, + abc_tensors[min_a_idx][0].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + # Pointer to the smallest B matrix (FP4 type) + make_ptr( + ab_dtype, + abc_tensors[min_b_idx][1].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + # Pointer to the smallest C matrix (FP16 type, output) + make_ptr( + c_dtype, + abc_tensors[min_c_idx][2].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + ) + initial_cute_sfasfb_ptrs = ( + # Pointer to the smallest scale factor A tensor (FP8 type) + make_ptr( + sf_dtype, + sfasfb_reordered_tensors[min_a_idx][0].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + # Pointer to the smallest scale factor B 
tensor (FP8 type) + make_ptr( + sf_dtype, + sfasfb_reordered_tensors[min_b_idx][1].data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ), + ) + + # Create torch tensor to store problem sizes for all groups + # Shape: (num_groups, 4) where each row contains (m, n, k, l) for that group + # Layout: (num_groups, 4):(4, 1) means row-major storage + tensor_of_problem_sizes = torch.tensor( + problem_sizes, dtype=torch.int32, device="cuda" + ) + + # Create torch tensors to store data pointers for all groups + # These allow the GPU kernel to dynamically access different tensors per group + # tensor_of_abc_ptrs: Shape (num_groups, 3) containing (a_ptr, b_ptr, c_ptr) per group + # tensor_of_sfasfb_ptrs: Shape (num_groups, 2) containing (sfa_ptr, sfb_ptr) per group + tensor_of_abc_ptrs = torch.tensor(abc_ptrs, dtype=torch.int64, device="cuda") + tensor_of_sfasfb_ptrs = torch.tensor(sfasfb_ptrs, dtype=torch.int64, device="cuda") + + # Compute the tile shape for each CUDA Thread Block (CTA) + # cta_tile_shape_mn: [M_tile, N_tile] = [128, 128] for this kernel + cta_tile_shape_mn = [128, mma_tiler_mnk[1]] + # cluster_tile_shape_mn: Total tile shape per cluster (same as CTA since cluster is 1x1) + cluster_tile_shape_mn = tuple( + x * y for x, y in zip(cta_tile_shape_mn, (1, 1)) + ) + + # Compute total number of cluster tiles needed across all groups + # Each group's (m, n) dimensions are divided into tiles of size cluster_tile_shape_mn + # This determines the total grid size (bidz dimension) for kernel launch + total_num_clusters = 0 + num_groups = len(problem_sizes) + for m, n, _, _ in problem_sizes: + # Calculate number of tiles needed in M and N dimensions for this group + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + # Multiply M_tiles * N_tiles to get total tiles for this group + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + + # Allocate device memory for tensormap descriptors + # Each cluster needs its own set of tensormaps (one for A, B, SFA, SFB) + # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16) + # Tensormaps are hardware descriptors used by TMA for efficient memory transfers + tensormap_shape = ( + total_num_clusters, + num_tensormaps, + bytes_per_tensormap // 8, + ) + tensor_of_tensormap = torch.empty(tensormap_shape, dtype=torch.int64, device="cuda") + + # Create CuTe pointers to the metadata tensors that will be passed to the kernel + # These allow the GPU kernel to read problem sizes and tensor pointers + cute_ptr_of_tensor_of_abc_ptrs = make_ptr( + cutlass.Int64, + tensor_of_abc_ptrs.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr( + cutlass.Int64, + tensor_of_sfasfb_ptrs.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + cute_ptr_of_tensor_of_problem_sizes = make_ptr( + cutlass.Int32, + tensor_of_problem_sizes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + + # Launch the JIT-compiled GPU kernel with all prepared data + # The kernel will perform block-scaled group GEMM: C = A * SFA * B * SFB for all groups + my_kernel( + initial_cute_abc_ptrs, # Template pointers for tensormap initialization + initial_cute_sfasfb_ptrs, # Template scale factor pointers + (min_a_idx, min_b_idx, min_c_idx), # Indices of smallest tensors + cute_ptr_of_tensor_of_problem_sizes, # Pointer to problem sizes array + cute_ptr_of_tensor_of_abc_ptrs, # Pointer to ABC tensor pointers array + 
cute_ptr_of_tensor_of_sfasfb_ptrs, # Pointer to scale factor pointers array + total_num_clusters, # Total number of CTAs to launch + problem_sizes, # Problem sizes list (for host-side processing) + tensor_of_tensormap, # Pre-allocated tensormap buffer + num_groups, # Number of groups in this batch + ) + + res = [] + for i in range(num_groups): + res.append(abc_tensors[i][2]) + return res \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.py b/problems/nvidia/nvfp4_group_gemm/task.py new file mode 100644 index 0000000..94c1143 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/task.py @@ -0,0 +1,8 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[int, int, int, int]]]) +output_t = TypeVar("output_t", bound=list[torch.Tensor]) +class TestSpec(TypedDict): + problem_sizes: list[tuple[int, int, int, int]] + seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.yml b/problems/nvidia/nvfp4_group_gemm/task.yml new file mode 100644 index 0000000..6390ae9 --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/task.yml @@ -0,0 +1,65 @@ +# name: nvfp4-block-scaled-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a block scaled group matrix-matrix multiplication kernel optimized for NVIDIA B200. + To be explicit, you will be given a tuple of tensors: + ``` + (abc_tensors, sfasfb_tensors, problem_sizes) + ``` + where: + * `abc_tensors` is list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [M, K // 2, L] + b is torch.Tensor[float4e2m1fn_x2] of shape [N, K // 2, L] + c is torch.Tensor[float16] of shape [M, N, L] + * `sfasfb_tensors` is list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [M, K // 16, L] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [N, K // 16, L] + * `problem_sizes` is list of tuples (M, N, K, L) + + Each group's matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. + The ranking criteria is the geometric mean of the benchmark results. + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. 
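+
+  For intuition, each group is an independent block-scaled GEMM over its own (M, N, K, L).
+  A rough per-group sketch along the lines of the provided reference.py, reusing its
+  `to_blocked` helper to convert the scale factors into the layout `torch._scaled_mm`
+  expects, looks like this:
+  ```
+  for (a, b, c), (sfa, sfb), (m, n, k, l) in zip(abc_tensors, sfasfb_tensors, problem_sizes):
+      for l_idx in range(l):
+          c[:, :, l_idx] = torch._scaled_mm(
+              a[:, :, l_idx].view(torch.float4_e2m1fn_x2),
+              b[:, :, l_idx].transpose(0, 1).view(torch.float4_e2m1fn_x2),
+              to_blocked(sfa[:, :, l_idx]).cuda(),
+              to_blocked(sfb[:, :, l_idx]).cuda(),
+              bias=None,
+              out_dtype=torch.float16,
+          )
+  ```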
+ ``` + The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: + G M N K L time[us] + 8 128 4096 7168 1 18.833 + 8 128 7168 2048 1 10.667 + 2 256 3072 4096 1 2.406 + 2 256 4096 1536 1 1.525 + ``` +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 512, "g": 8, "seed": 1111} + - {"m": 128, "n": 256, "k": 512, "g": 2, "seed": 1111} + - {"m": 128, "n": 384, "k": 640, "g": 3, "seed": 1111} + - {"m": 256, "n": 384, "k": 640, "g": 4, "seed": 1111} + - {"m": 256, "n": 512, "k": 384, "g": 2, "seed": 1111} + - {"m": 384, "n": 512, "k": 384, "g": 2, "seed": 1111} + - {"m": 384, "n": 640, "k": 512, "g": 2, "seed": 1111} + - {"m": 256, "n": 640, "k": 128, "g": 8, "seed": 1111} + - {"m": 512, "n": 768, "k": 256, "g": 5, "seed": 1111} + - {"m": 512, "n": 768, "k": 768, "g": 3, "seed": 1111} + +benchmarks: + - {"m": 4096, "n": 128, "k": 7168, "g": 8, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "g": 8, "seed": 1111} + - {"m": 3072, "n": 256, "k": 4096, "g": 2, "seed": 1111} + - {"m": 4096, "n": 256, "k": 1536, "g": 2, "seed": 1111} + +ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_group_gemm/template.py b/problems/nvidia/nvfp4_group_gemm/template.py new file mode 100644 index 0000000..b6005fa --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/template.py @@ -0,0 +1,31 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp4 group gemm + Args: + data: list of tuples (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) where: + abc_tensors: list of tuples (a, b, c) where + a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] + b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] + c is torch.Tensor[float16] of shape [m, n, l] + sfasfb_tensors: list of tuples (sfa, sfb) where + sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] + sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] + sfasfb_reordered_tensors: list of tuples (sfa_reordered, sfb_reordered) where + sfa_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l] + sfb_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l] + problem_sizes: list of tuples (m, n, k, l) + each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes + l should always be 1 for each group. 
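+            The length of each list equals the number of groups.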
+ Returns: + list of tuples (c) where c is torch.Tensor[float16] of shape [m, n, l] + """ + abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes = data + result_tensors = [] + for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): + # add you implementation here + result_tensors.append(c) + + return result_tensors \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/utils.py b/problems/nvidia/nvfp4_group_gemm/utils.py new file mode 100644 index 0000000..486116b --- /dev/null +++ b/problems/nvidia/nvfp4_group_gemm/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... 
and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + for i, (output_i, expected_i) in enumerate(zip(output, expected)): + reasons = verbose_allclose(output_i, expected_i, rtol=rtol, atol=atol) + if len(reasons) > 0: + return False, f"mismatch found! custom implementation doesn't match reference: {i} {reasons}" + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file diff --git a/problems/nvidia/utils.py b/problems/nvidia/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or 
CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. 
+ + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file