File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -65,10 +65,24 @@ def min_dot_size(
6565def _min_dot_size (
6666 device : torch .device , lhs : torch .dtype , rhs : torch .dtype
6767) -> tuple [int , int , int ]:
68- if device .type != "cuda" :
69- # TODO(jansel): support non-cuda properly
68+ if device .type not in [ "cuda" , "xpu" ] :
69+ # TODO(jansel): support other hardware backends properly besides CUDA and XPU
7070 return (16 , 16 , 16 )
7171
72+ if torch .xpu .is_available ():
73+ from triton .backends .intel .compiler import min_dot_size as min_dot_size_xpu
74+
75+ device_properties = torch .xpu .get_device_properties ()
76+ gpu_target_info = {
77+ k : getattr (device_properties , k )
78+ for k in device_properties .__dir__ ()
79+ if not k .startswith ("_" )
80+ }
81+
82+ return min_dot_size_xpu (gpu_target_info )(
83+ torch_dtype_to_tl (lhs ), torch_dtype_to_tl (rhs )
84+ )
85+
7286 from triton .backends .nvidia .compiler import min_dot_size as min_dot_size_cuda
7387
7488 props = DeviceProperties .create (device )
You can’t perform that action at this time.
0 commit comments