diff --git a/cuda_core/cuda/core/system/_system.pyx b/cuda_core/cuda/core/system/_system.pyx index e6163b94fd..adef2d8afc 100644 --- a/cuda_core/cuda/core/system/_system.pyx +++ b/cuda_core/cuda/core/system/_system.pyx @@ -82,6 +82,16 @@ def get_nvml_version() -> tuple[int, ...]: return tuple(int(v) for v in nvml.system_get_nvml_version().split(".")) +def get_driver_branch() -> str: + """ + Retrieves the driver branch of the NVIDIA driver installed on the system. + """ + if not CUDA_BINDINGS_NVML_IS_COMPATIBLE: + raise RuntimeError("NVML library is not available") + initialize() + return nvml.system_get_driver_branch() + + def get_num_devices() -> int: """ Return the number of devices in the system. @@ -112,6 +122,7 @@ def get_process_name(pid: int) -> str: __all__ = [ + "get_driver_branch", "get_driver_version", "get_driver_version_full", "get_nvml_version", diff --git a/cuda_core/tests/system/test_system_system.py b/cuda_core/tests/system/test_system_system.py index 582c471b8c..ebc8af12bb 100644 --- a/cuda_core/tests/system/test_system_system.py +++ b/cuda_core/tests/system/test_system_system.py @@ -95,3 +95,11 @@ def test_device_count(): device_count = system.get_num_devices() assert isinstance(device_count, int) assert device_count >= 0 + + +@skip_if_nvml_unsupported +def test_get_driver_branch(): + driver_branch = system.get_driver_branch() + assert isinstance(driver_branch, str) + assert len(driver_branch) > 0 + assert driver_branch[0] == "r"