Skip to content

Don't raise internal assertions for multi-grid loop helion kernels #762

@bremerm31

Description

@bremerm31

Describe the bug
I'm encountering a confusing error message when compiling the following Helion kernel:

@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def _matmul_strassen_1_setup(
    m: int,
    n: int,
    k: int,
    A0: torch.Tensor,
    A1: torch.Tensor,
    A2: torch.Tensor,
    A3: torch.Tensor,
    B0: torch.Tensor,
    B1: torch.Tensor,
    B2: torch.Tensor,
    B3: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the seven left/right operand tiles for one level of Strassen matmul.

    Slot ``i`` of ``_A_intermediate`` / ``_B_intermediate`` holds the left /
    right operand of the i-th Strassen product. The sums and differences below
    match the standard M1..M7 formulation if A0..A3 and B0..B3 are the
    (11, 12, 21, 22) quadrants of A and B. That mapping is inferred from the
    formulas; TODO confirm with the caller.

    Args:
        m, n, k: Sizes used as ``hl.tile`` bounds; halved (``//2``) to size
            the intermediate buffers.
        A0..A3: Tensors read per ``(tile_m, tile_k)`` tile.
        B0..B3: Tensors read per ``(tile_k, tile_n)`` tile.

    Returns:
        ``(_A_intermediate, _B_intermediate)`` with shapes
        ``[7, m//2, k//2]`` and ``[7, k//2, n//2]``, using the dtype and
        device of ``A0`` and ``B0`` respectively.
    """
    _A_intermediate = torch.empty(
        [7, m//2, k//2], device=A0.device, dtype=A0.dtype
    )
    _B_intermediate = torch.empty(
        [7, k//2, n//2], device=B0.device, dtype = B0.dtype
    )

    # First of two top-level hl.tile loops. The "pid already set" assertion in
    # the traceback of this report is raised while generating the grid for
    # this loop (set_pid is called when a pid has already been assigned).
    # NOTE(review): this loop tiles over the full [m, k] range, but
    # _A_intermediate is only m//2 x k//2 in its last two dims. Unless m and k
    # are already the half-sizes, tiles past m//2 / k//2 index out of bounds.
    # Confirm whether hl.tile([m//2, k//2]) (or A0.size()) was intended.
    for tile_m, tile_k in hl.tile([m, k]):
        _A_intermediate[0,tile_m,tile_k] = A0[tile_m,tile_k] + A3[tile_m,tile_k]
        _A_intermediate[1,tile_m,tile_k] = A2[tile_m,tile_k] + A3[tile_m,tile_k]
        _A_intermediate[2,tile_m,tile_k] = A0[tile_m,tile_k]
        _A_intermediate[3,tile_m,tile_k] = A3[tile_m,tile_k]
        _A_intermediate[4,tile_m,tile_k] = A0[tile_m,tile_k] + A1[tile_m,tile_k]
        _A_intermediate[5,tile_m,tile_k] = A2[tile_m,tile_k] - A0[tile_m,tile_k]
        _A_intermediate[6,tile_m,tile_k] = A1[tile_m,tile_k] - A3[tile_m,tile_k]

    # Second top-level hl.tile loop over an independent [k, n] iteration space.
    # NOTE(review): same sizing concern as above. _B_intermediate is
    # k//2 x n//2 while the loop covers [k, n]; confirm the intended bounds.
    for tile_k, tile_n in hl.tile([k, n]):
        _B_intermediate[0,tile_k, tile_n] = B0[tile_k, tile_n] + B3[tile_k, tile_n]
        _B_intermediate[1,tile_k, tile_n] = B0[tile_k, tile_n]
        _B_intermediate[2,tile_k, tile_n] = B1[tile_k, tile_n] - B3[tile_k, tile_n]
        _B_intermediate[3,tile_k, tile_n] = B2[tile_k, tile_n] - B0[tile_k, tile_n]
        _B_intermediate[4,tile_k, tile_n] = B3[tile_k, tile_n]
        _B_intermediate[5,tile_k, tile_n] = B0[tile_k, tile_n] + B1[tile_k, tile_n]
        _B_intermediate[6,tile_k, tile_n] = B2[tile_k, tile_n] + B3[tile_k, tile_n]

    return _A_intermediate, _B_intermediate

Compiling it produces the following error:

WARNING:helion.runtime.kernel:Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[1, 8, 64, 1], flatten_loops=[True, False], indexing='pointer', l2_groupings=[2, 2], loop_orders=[[1, 0], [0, 1]], num_stages=4, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=True)
Traceback (most recent call last):
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/ast_extension.py", line 272, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/generate_ast.py", line 232, in visit_For
    fn._codegen(state)
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/language/loops.py", line 458, in _
    return _codegen_loop_helper(state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/language/loops.py", line 568, in _codegen_loop_helper
    state.tile_strategy.codegen_grid(state, block_ids)
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/tile_dispatch.py", line 107, in codegen_grid
    grid_state = strategy.codegen_grid(state)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/tile_strategy.py", line 428, in codegen_grid
    state.device_function.set_pid(TmpPid())
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/device_function.py", line 330, in set_pid
    assert self.pid is None, "pid already set"
           ^^^^^^^^^^^^^^^^
AssertionError: pid already set

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 436, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 411, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/generate_ast.py", line 407, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/ast_extension.py", line 276, in visit
    raise exc.InternalError(e) from e
helion.exc.InternalError: AssertionError: pid already set
While processing:
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/examples/matmul_strassen1.py", line 53, in _matmul_strassen_1_setup
    for tile_m, tile_k in hl.tile([m, k]):

  0%|          | 0/10 [00:00<?, ?it/s]
WARNING:tritonbench.utils.triton_op:Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/ast_extension.py", line 272, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/generate_ast.py", line 232, in visit_For
    fn._codegen(state)
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/language/loops.py", line 458, in _
    return _codegen_loop_helper(state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/language/loops.py", line 568, in _codegen_loop_helper
    state.tile_strategy.codegen_grid(state, block_ids)
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/tile_dispatch.py", line 107, in codegen_grid
    grid_state = strategy.codegen_grid(state)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/tile_strategy.py", line 428, in codegen_grid
    state.device_function.set_pid(TmpPid())
  File "/data/users/mbremer/fbsource/buck-out/v2/gen/fbcode/b60fd1c0177358a2/helion/benchmarks/__run__/run-inplace#link-tree/helion/_compiler/device_function.py", line 330, in set_pid
    assert self.pid is None, "pid already set"
           ^^^^^^^^^^^^^^^^
AssertionError: pid already set

To Reproduce
See the Meta-internal task T240062719 (not publicly accessible).

Expected behavior
I'm unsure what the actual error is. It would be helpful to have a more descriptive error message that helps the user debug their kernel.

Versions
The commit is listed in the internal task T240062719.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions