Skip to content

Commit b6663b9

Browse files
Add SM architecture version check (#8199)
Fixes #8198 NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. Review the [TensorRT Support Matrix](https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html) for which GPUs are supported by this release. Add SM architecture version check to skip trt test before 7.0. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0bb20a8 commit b6663b9

File tree

9 files changed

+99
-8
lines changed

9 files changed

+99
-8
lines changed

monai/bundle/scripts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,8 @@ def trt_export(
15891589
"""
15901590
Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
15911591
Currently, this API only supports converting models whose inputs are all tensors.
1592+
Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
1593+
Review the TensorRT Support Matrix for which GPUs are supported.
15921594
15931595
There are two ways to export a model:
15941596
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.

monai/networks/trt_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,9 @@ def trt_compile(
505505
) -> torch.nn.Module:
506506
"""
507507
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
508-
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
508+
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
509+
NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
510+
Review the TensorRT Support Matrix for which GPUs are supported.
509511
Args:
510512
model: module to patch with TrtCompiler object.
511513
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
InvalidPyTorchVersionError,
108108
OptionalImportError,
109109
allow_missing_reference,
110+
compute_capabilities_after,
110111
damerau_levenshtein_distance,
111112
exact_version,
112113
get_full_type_name,

monai/utils/module.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
634634
if is_prerelease:
635635
return False
636636
return True
637+
638+
639+
@functools.lru_cache(None)
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
    """
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system.
    The compared version is a string in the form of "major.minor".

    Results are memoized per argument combination by ``functools.lru_cache``, so the GPU is
    queried at most once for a given ``(major, minor)`` pair when ``current_ver_string`` is None.

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: if None, the current system GPU CUDA compute capability will be used.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
        When ``current_ver_string`` is None: returns True if ``pynvml`` is not installed (the capability
        cannot be queried, so a recent GPU is assumed), and False if ``pynvml`` is installed but CUDA
        is not available.
    """
    if current_ver_string is None:
        cuda_available = torch.cuda.is_available()
        pynvml, has_pynvml = optional_import("pynvml")
        if not has_pynvml:  # assuming that the user has Ampere and later GPU
            return True
        if not cuda_available:
            return False
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU
            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        finally:
            # always release NVML, even if the handle/capability query raises
            pynvml.nvmlShutdown()
        current_ver_string = f"{major_c}.{minor_c}"

    ver, has_ver = optional_import("packaging.version", name="parse")
    if has_ver:
        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
    # Fallback when `packaging` is unavailable: compare (major, minor) integer tuples.
    # Drop any local-version suffix ("+...") and pad a bare major such as "8" to "8.0".
    parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
    while len(parts) < 2:
        parts += ["0"]
    c_major, c_minor = parts[:2]
    c_mn = int(c_major), int(c_minor)
    mn = int(major), int(minor)
    # ">=" (not ">") so that an exact match counts as satisfied, matching the `packaging` branch
    return c_mn >= mn

tests/test_bundle_trt_export.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
from monai.data import load_net_with_metadata
2323
from monai.networks import save_state
2424
from monai.utils import optional_import
25-
from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows
25+
from tests.utils import (
26+
SkipIfBeforeComputeCapabilityVersion,
27+
command_line_tests,
28+
skip_if_no_cuda,
29+
skip_if_quick,
30+
skip_if_windows,
31+
)
2632

2733
_, has_torchtrt = optional_import(
2834
"torch_tensorrt",
@@ -47,6 +53,7 @@
4753
@skip_if_windows
4854
@skip_if_no_cuda
4955
@skip_if_quick
56+
@SkipIfBeforeComputeCapabilityVersion((7, 0))
5057
class TestTRTExport(unittest.TestCase):
5158

5259
def setUp(self):

tests/test_convert_to_trt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from monai.networks import convert_to_trt
2121
from monai.networks.nets import UNet
2222
from monai.utils import optional_import
23-
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows
23+
from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
2424

2525
_, has_torchtrt = optional_import(
2626
"torch_tensorrt",
@@ -38,6 +38,7 @@
3838
@skip_if_windows
3939
@skip_if_no_cuda
4040
@skip_if_quick
41+
@SkipIfBeforeComputeCapabilityVersion((7, 0))
4142
class TestConvertToTRT(unittest.TestCase):
4243

4344
def setUp(self):

tests/test_trt_compile.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from monai.networks import trt_compile
2222
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
2323
from monai.utils import min_version, optional_import
24-
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
24+
from tests.utils import (
25+
SkipIfAtLeastPyTorchVersion,
26+
SkipIfBeforeComputeCapabilityVersion,
27+
skip_if_no_cuda,
28+
skip_if_quick,
29+
skip_if_windows,
30+
)
2531

2632
trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
2733
polygraphy, polygraphy_imported = optional_import("polygraphy")
@@ -36,6 +42,7 @@
3642
@skip_if_quick
3743
@unittest.skipUnless(trt_imported, "tensorrt is required")
3844
@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
45+
@SkipIfBeforeComputeCapabilityVersion((7, 0))
3946
class TestTRTCompile(unittest.TestCase):
4047

4148
def setUp(self):

tests/test_pytorch_version_after.py renamed to tests/test_version_after.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from parameterized import parameterized
1717

18-
from monai.utils import pytorch_after
18+
from monai.utils import compute_capabilities_after, pytorch_after
1919

20-
TEST_CASES = (
20+
TEST_CASES_PT = (
2121
(1, 5, 9, "1.6.0"),
2222
(1, 6, 0, "1.6.0"),
2323
(1, 6, 1, "1.6.0", False),
@@ -36,14 +36,30 @@
3636
(1, 6, 1, "1.6.0+cpu", False),
3737
)
3838

39+
TEST_CASES_SM = [
40+
# (major, minor, sm, expected)
41+
(6, 1, "6.1", True),
42+
(6, 1, "6.0", False),
43+
(6, 0, "8.6", True),
44+
(7, 0, "8", True),
45+
(8, 6, "8", False),
46+
]
47+
3948

4049
class TestPytorchVersionCompare(unittest.TestCase):
4150

42-
@parameterized.expand(TEST_CASES)
51+
@parameterized.expand(TEST_CASES_PT)
4352
def test_compare(self, a, b, p, current, expected=True):
4453
"""Test pytorch_after with a and b"""
4554
self.assertEqual(pytorch_after(a, b, p, current), expected)
4655

4756

57+
class TestComputeCapabilitiesAfter(unittest.TestCase):
    """Checks ``compute_capabilities_after`` against explicit compute capability strings."""

    @parameterized.expand(TEST_CASES_SM)
    def test_compute_capabilities_after(self, major, minor, sm, expected):
        """Compare the requested (major, minor) with the given ``sm`` string and verify the outcome."""
        outcome = compute_capabilities_after(major, minor, sm)
        self.assertEqual(outcome, expected)
62+
63+
4864
if __name__ == "__main__":
4965
unittest.main()

tests/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from monai.networks import convert_to_onnx, convert_to_torchscript
4848
from monai.utils import optional_import
4949
from monai.utils.misc import MONAIEnvVars
50-
from monai.utils.module import pytorch_after
50+
from monai.utils.module import compute_capabilities_after, pytorch_after
5151
from monai.utils.tf32 import detect_default_tf32
5252
from monai.utils.type_conversion import convert_data_type
5353

@@ -286,6 +286,20 @@ def __call__(self, obj):
286286
)(obj)
287287

288288

289+
class SkipIfBeforeComputeCapabilityVersion:
    """Decorator that skips the decorated test when the result of
    ``compute_capabilities_after`` for the given (major, minor) tuple is False."""

    def __init__(self, compute_capability_tuple):
        # keep the requested minimum around for the skip message
        self.min_version = compute_capability_tuple
        # evaluated once, at decoration time
        meets_minimum = compute_capabilities_after(*compute_capability_tuple)
        self.version_too_old = not meets_minimum

    def __call__(self, obj):
        reason = f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
        skip_decorator = unittest.skipIf(self.version_too_old, reason)
        return skip_decorator(obj)
301+
302+
289303
def is_main_test_process():
290304
ps = torch.multiprocessing.current_process()
291305
if not ps or not hasattr(ps, "name"):

0 commit comments

Comments
 (0)