diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py
index cb2faf1df..db6f30c00 100644
--- a/torchrec/distributed/test_utils/multi_process.py
+++ b/torchrec/distributed/test_utils/multi_process.py
@@ -107,9 +107,16 @@ def __init__(
     ) -> None:
         super().__init__(methodName)

+        # In CUDA 12.8 we're seeing hangs from using forkserver, so we're
+        # switching to spawn.
         # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
         # Therefore we use spawn for HIP runtime until AMD fixes the issue
-        self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
+        if (
+            torch.version.cuda is not None and torch.version.cuda >= "12.8"
+        ) or torch.version.hip is not None:
+            self._mp_init_mode: str = "spawn"
+        else:
+            self._mp_init_mode: str = mp_init_mode
         logging.info(f"Using {self._mp_init_mode} for multiprocessing")

     @seed_and_log
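
Note on the version check: `torch.version.cuda >= "12.8"` is a lexicographic string comparison. It behaves correctly for current CUDA 12.x releases, but would misclassify a hypothetical two-digit minor such as "12.10". A minimal sketch of a numeric, tuple-based comparison is below; the `_cuda_at_least` helper name is illustrative and not part of the patch.

    # Sketch only, not part of the patch: compare CUDA versions numerically
    # rather than as strings. Helper name is illustrative.
    import torch

    def _cuda_at_least(major: int, minor: int) -> bool:
        """True if the build's CUDA runtime version is >= major.minor."""
        cuda = torch.version.cuda  # e.g. "12.8"; None on CPU-only or ROCm builds
        if cuda is None:
            return False
        ver = tuple(int(p) for p in cuda.split(".")[:2])
        return ver >= (major, minor)

    # The patch's condition, expressed with the helper:
    use_spawn = _cuda_at_least(12, 8) or torch.version.hip is not None
    mp_init_mode = "spawn" if use_spawn else "forkserver"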