@@ -7,6 +7,7 @@
 import importlib.util
 import inspect
 import platform
+import sys
 from pathlib import Path
 from typing import Any, Callable
 
@@ -60,6 +61,9 @@
     npu_device_count = torch.npu.device_count()
 
 
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 def test_vmap_compile():
     # Since we monkey patch vmap we need to make sure compile is happy with it
     def func(x, y):
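Note on the gate introduced above (and repeated at each compile-using test class below): the strict comparison `sys.version_info > (3, 14)` still skips on Python 3.14.0 itself, because `sys.version_info` is a 5-tuple and tuple comparison treats a longer tuple with an equal prefix as greater. A standalone sketch, not part of the commit:

```python
# Standalone sketch (not part of this commit): why the strict ">"
# comparison also catches Python 3.14.0 itself.
import sys

# sys.version_info behaves like a 5-tuple, e.g. (3, 14, 0, "final", 0).
# Tuples compare elementwise; a longer tuple with an equal prefix is
# greater, so (3, 14, 0, ...) > (3, 14) holds on any 3.14.x interpreter.
print(sys.version_info > (3, 14))  # True on 3.14+, False on 3.13 and below
print((3, 14, 0) > (3, 14))        # True: equal prefix, longer tuple wins
```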
@@ -76,6 +80,9 @@ def func(x, y):
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTD:
     def test_tensor_output(self, mode):
@@ -271,7 +278,7 @@ def make_td_with_names(data):
         make_td_with_names_c(data_dict)
 
     @pytest.mark.skipif(
-        not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu required to test device casting"
+        not torch.cuda.is_available(), reason="cuda required to test device casting"
     )
     @pytest.mark.parametrize("has_device", [True, False])
     def test_to(self, has_device, mode):
@@ -366,6 +373,9 @@ class MyClass:
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTC:
     def test_tc_tensor_output(self, mode):
@@ -558,7 +568,7 @@ def clone(td: TensorDict):
         assert clone_c(data).a.b is data.a.b
 
     @pytest.mark.skipif(
-        not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu required to test device casting"
+        not torch.cuda.is_available(), reason="cuda required to test device casting"
     )
     @pytest.mark.parametrize("has_device", [True, False])
     def test_tc_to(self, has_device, mode):
@@ -630,6 +640,9 @@ def func_c_mytd():
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestNN:
     def test_func(self, mode):
@@ -734,6 +747,9 @@ def test_prob_module_with_kwargs(self, mode):
 @pytest.mark.skipif(
     TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestFunctional:
     def test_functional_error(self, mode):
@@ -1032,6 +1048,9 @@ def to_numpy(tensor):
     (TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
     reason="requires torch>=2.7 on OSX",
 )
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
+)
 @pytest.mark.parametrize("compiled", [False, True])
 class TestCudaGraphs:
     @pytest.fixture(scope="class", autouse=True)
@@ -1251,7 +1270,7 @@ def test_state_dict(self, compiled):
         torch.testing.assert_close(y1, y2)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu is not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
 class TestCompileNontensor:
     # Same issue with the decorator @tensorclass version
     @pytest.fixture(scope="class")
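For context on how the stacked decorators behave: pytest skips a test when any `skipif` condition evaluates true, so the new Python-version gate composes with the existing torch-version gates as an OR. A self-contained sketch of the pattern (the `TORCH_VERSION` derivation and the test body are assumptions, not code copied from this file):

```python
# Self-contained sketch of the stacked-skipif pattern used throughout this
# diff. TORCH_VERSION's derivation and the test body are assumptions, not
# code from the test file.
import sys

import pytest
import torch
from packaging import version

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
def test_compile_gate():
    # Runs only on torch>=2.4 under Python < 3.14; if either gate matches,
    # the test is skipped with that gate's reason.
    fn = torch.compile(lambda x: x + 1)
    assert fn(torch.ones(())).item() == 2.0
```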