Commit 9231d4f

Support mx_tensor and enable its test on Intel GPU
1 parent 03c2d28 commit 9231d4f

3 files changed: +171 / -75 lines changed


test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 27 additions & 8 deletions
@@ -36,6 +36,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -63,16 +71,22 @@ def cuda_kernel_profiler(kernel_pattern):
         result["found"] = any(kernel_pattern in name for name in kernel_names)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize(
+    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
+)
 @pytest.mark.parametrize("use_inference_mode", [True, False])
 @pytest.mark.parametrize("x_rank", [2, 3])
+@pytest.mark.parametrize("device", devices)
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
@@ -84,25 +98,31 @@ def test_inference_workflow_mx(
     emulate: bool,
     use_inference_mode: bool,
     x_rank: int,
+    device,
 ):
     """
     Smoke test for inference compile
     """
     # TODO(future): figure out why these CUDA capability conditions are not properly
     # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    ) and torch.cuda.is_available():
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
         elif not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
+    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
         elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")

-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
+        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
     m_mx = copy.deepcopy(m)

     if emulate:
@@ -120,10 +140,9 @@ def test_inference_workflow_mx(
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)

-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
     if x_rank == 3:
         x = x.unsqueeze(0)
-
     y_ref = m(x)
     if use_inference_mode:
         with torch.inference_mode():

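For reference, the device-collection and parametrization approach this commit applies can be reduced to the minimal, self-contained sketch below. The test name and body are illustrative placeholders (a plain bf16 matmul rather than the MX inference workflow); only the `devices` list, the combined availability skipif, and `pytest.mark.parametrize("device", devices)` mirror the mechanism in the diff above.

# Minimal sketch of the device-parametrization pattern used in the diff above.
# The test body is a placeholder (plain bf16 matmul), not the MX workflow itself.
import pytest
import torch

# Collect every accelerator backend present at collection time so the same test
# file runs unchanged on NVIDIA (cuda) and Intel (xpu) machines.
devices = []
if torch.cuda.is_available():
    devices.append("cuda")
if torch.xpu.is_available():
    devices.append("xpu")


@pytest.mark.skipif(not devices, reason="CUDA or XPU not available")
@pytest.mark.parametrize("device", devices)
@torch.no_grad()
def test_device_parametrization_sketch(device):
    # Each parametrized case allocates its tensors on the selected backend.
    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
    w = torch.randn(64, 32, device=device, dtype=torch.bfloat16)
    y = x @ w.t()
    assert y.shape == (128, 64)
    assert y.device.type == device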