
[XPU] The bias value is not added to the matmul result #795

@EikanWang


Describe the bug
For the expression torch.relu(args[0] @ args[1] + bias), Helion produces the Triton kernel below, but the bias has no effect on the final result. The generated kernel itself looks correct: the bias is loaded and added to acc:

    load_2 = tl.load(tl.make_block_ptr(epilogue_closure_0, [1, 1024], [1024, 1], [0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), boundary_check=[1], padding_option='zero')
    v_0 = tl.cast(load_2, tl.float32)
    v_1 = acc + v_0

At runtime, however, v_1 does not reflect the value of v_0.
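
As a quick sanity check (a back-of-the-envelope sketch using the fill values from the reproduction below: x = 0.1, y = 0.2, bias = 0.7, k = 1024), a dropped bias is easy to spot numerically:

    import torch

    # Every element of x @ y is 1024 * 0.1 * 0.2 = 20.48, so the expected
    # output is relu(20.48 + 0.7) ≈ 21.18 (up to fp16 rounding); if the bias
    # is dropped, every element is 20.48 instead.
    expected = torch.relu(torch.tensor(1024 * 0.1 * 0.2 + 0.7))  # ≈ 21.18
    no_bias = torch.relu(torch.tensor(1024 * 0.1 * 0.2))  # = 20.48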

To Reproduce

from __future__ import annotations

from typing import Callable

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

import __main__ as _global_source0

DEVICE = 'xpu'

@triton.jit
def _helion_matmul(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
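    # Grouped program-id mapping: the 1D program id is swizzled into 2D tile
    # coordinates (pid_0, pid_1) in groups of 64 tile-rows, the usual trick
    # for improving L2 reuse.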
    num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
    num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
    inner_2d_pid = tl.program_id(0)
    num_pid_in_group = 64 * num_pid_n
    group_id = inner_2d_pid // num_pid_in_group
    first_pid_m = group_id * 64
    group_size_m = min(num_pid_m - first_pid_m, 64)
    pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
    pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
    offset_0 = pid_0 * _BLOCK_SIZE_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
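    # K-loop: accumulate the fp16 input tiles into a float32 accumulator.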
    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
        acc_copy = acc
        acc_copy_0 = acc_copy
        load = tl.load(tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
        load_1 = tl.load(tl.make_block_ptr(y, [1024, 1024], [1024, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
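    # Epilogue: load one row of the bias, add it to the accumulator, apply
    # relu, cast back to fp16, and store. This is the path in question:
    # v_1 = acc + v_0 is emitted, yet the stored result shows no bias.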
    load_2 = tl.load(tl.make_block_ptr(epilogue_closure_0, [1, 1024], [1024, 1], [0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), boundary_check=[1], padding_option='zero')
    v_0 = tl.cast(load_2, tl.float32)
    v_1 = acc + v_0
    v_2 = tl.full([], 0, tl.int32)
    v_3 = triton_helpers.maximum(v_2, v_1)
    v_4 = tl.cast(v_3, tl.float16)
    tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])

def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
    """
    Performs matrix multiplication of x and y with an optional epilogue function.
    Args:
        x (Tensor): Left matrix of shape [m, k].
        y (Tensor): Right matrix of shape [k, n].
        epilogue (Callable, optional): Function applied to the accumulator and tile indices
            after the matmul. Defaults to identity (no change).
    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f'size mismatch {k} != {k2}'
    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
    _BLOCK_SIZE_0 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_2 = 16
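    # The generated wrapper extracts the tensor captured by the epilogue's
    # closure (the bias) and passes it to the kernel as epilogue_closure_0.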
    _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
    return out

bias = torch.ones([1, 1024], device=DEVICE, dtype=torch.float16)
args = (
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    lambda acc, tile: torch.relu(acc + bias[tile]),
)

bias.fill_(0.7)
args[0].fill_(0.1)
args[1].fill_(0.2)
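
# Note: the lambda in args[2] is unused; the epilogue actually passed to
# matmul below is the one built by make_epilogue.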

def make_epilogue(bias):
    def epilogue(acc, tile):
        return acc + bias[tile[0], tile[1]]
    return epilogue
epilogue = make_epilogue(bias)

out = matmul(args[0], args[1], epilogue)
torch.xpu.synchronize()
torch.testing.assert_close(out, torch.relu(args[0] @ args[1] + bias))

Expected behavior

The assertion at the end of the reproducer should pass:

torch.testing.assert_close(out, torch.relu(args[0] @ args[1] + bias))

Versions
PyTorch: 2.9
Triton: 3.5
Helion: main

Additional context
N/A
