    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

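+# collect the available accelerator backends; the "device" parametrization below runs over these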
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
# source: https://stackoverflow.com/a/22638709
@pytest.fixture(autouse=True)
def run_around_tests():
@@ -63,16 +71,22 @@ def cuda_kernel_profiler(kernel_pattern):
        result["found"] = any(kernel_pattern in name for name in kernel_names)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
@pytest.mark.skipif(
    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
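+# when an XPU is present, only the emulated path is exercised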
+@pytest.mark.parametrize(
+    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
+)
@pytest.mark.parametrize("use_inference_mode", [True, False])
@pytest.mark.parametrize("x_rank", [2, 3])
+@pytest.mark.parametrize("device", devices)
@torch.no_grad()
@skip_if_rocm(
    "ROCm float4 gemm require gfx950"
@@ -84,25 +98,31 @@ def test_inference_workflow_mx(
    emulate: bool,
    use_inference_mode: bool,
    x_rank: int,
+    device,
):
    """
    Smoke test for inference compile
    """
    # TODO(future): figure out why these CUDA capability conditions are not properly
    # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    ) and torch.cuda.is_available():
        if not is_sm_at_least_89():
            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
        elif not is_sm_at_least_100() and not emulate:
            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
+    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
        if not is_sm_at_least_100() and not emulate:
            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
        elif compile:
            # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")

-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
+        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
    m_mx = copy.deepcopy(m)

    if emulate:
@@ -120,10 +140,9 @@ def test_inference_workflow_mx(
    if compile:
        m_mx = torch.compile(m_mx, fullgraph=True)

-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
    if x_rank == 3:
        x = x.unsqueeze(0)
-
    y_ref = m(x)
    if use_inference_mode:
        with torch.inference_mode():