Skip to content

Commit db054ba

Browse files
committed
Enable min dot size for xpu
1 parent faf1cfe commit db054ba

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

helion/_compat.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,24 @@ def min_dot_size(
6565
def _min_dot_size(
6666
device: torch.device, lhs: torch.dtype, rhs: torch.dtype
6767
) -> tuple[int, int, int]:
68-
if device.type != "cuda":
69-
# TODO(jansel): support non-cuda properly
68+
if device.type not in ["cuda", "xpu"]:
69+
# TODO(jansel): support other hardware backends properly besides CUDA and XPU
7070
return (16, 16, 16)
7171

72+
if torch.xpu.is_available():
73+
from triton.backends.intel.compiler import min_dot_size as min_dot_size_xpu
74+
75+
device_properties = torch.xpu.get_device_properties()
76+
gpu_target_info = {
77+
k: getattr(device_properties, k)
78+
for k in device_properties.__dir__()
79+
if not k.startswith("_")
80+
}
81+
82+
return min_dot_size_xpu(gpu_target_info)(
83+
torch_dtype_to_tl(lhs), torch_dtype_to_tl(rhs)
84+
)
85+
7286
from triton.backends.nvidia.compiler import min_dot_size as min_dot_size_cuda
7387

7488
props = DeviceProperties.create(device)

0 commit comments

Comments
 (0)