Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions helion/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,27 @@ def min_dot_size(
def _min_dot_size(
device: torch.device, lhs: torch.dtype, rhs: torch.dtype
) -> tuple[int, int, int]:
if device.type != "cuda":
# TODO(jansel): support non-cuda properly
if device.type not in ["cuda", "xpu"]:
# TODO(jansel): support other hardware backends properly besides CUDA and XPU
return (16, 16, 16)

if torch.xpu.is_available():
from triton.backends.intel.compiler import ( # pyright: ignore[reportMissingImports]
min_dot_size as min_dot_size_xpu,
)

device_properties = torch.xpu.get_device_properties()
gpu_target_info = {
k: getattr(device_properties, k)
for k in device_properties.__dir__()
if not k.startswith("_")
}

dot_size_val = min_dot_size_xpu(gpu_target_info)(
torch_dtype_to_tl(lhs), torch_dtype_to_tl(rhs)
)
return tuple(int(v) for v in dot_size_val) # pyright: ignore[reportReturnType]

from triton.backends.nvidia.compiler import min_dot_size as min_dot_size_cuda

props = DeviceProperties.create(device)
Expand Down
4 changes: 3 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_low_mem_dropout(self):
size = 8192
seed = 123
seed2 = 456
x = torch.randn(size=(size,)).cuda()
x = torch.randn(size=(size,)).to(device=DEVICE)

_, out_fwd = code_and_output(
low_mem_dropout,
Expand Down Expand Up @@ -503,6 +503,7 @@ def test_attention_pointer(self):
)
)

@skipIfXPU("failure on XPU")
def test_attention_block_pointer(self):
args = (
torch.randn(2, 32, 1024, 64, dtype=torch.float16, device=DEVICE),
Expand Down Expand Up @@ -697,6 +698,7 @@ def test_segment_reduction(self):
)
)

@skipIfXPU("failure on XPU")
def test_attention_persistent_interleaved_l2_grouping(self):
"""Test attention with persistent interleaved execution and L2 grouping for optimal performance."""
args = (
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensor_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def jsd_forward_kernel(
log_p = torch.randn(batch, vocab, device=DEVICE).log_softmax(dim=-1)

code, (loss, _) = code_and_output(jsd_forward_kernel, (log_q, log_p))
torch.cuda.synchronize()
torch.accelerator.synchronize()

from examples.jsd import TorchJSDBaseline

Expand Down
Loading