
Commit fb5995c

zhuhan0 authored and facebook-github-bot committed
Use spawn for multiprocessing start method
Summary: CUDA context initialization is not fork-safe. If a CUDA context is created in a parent process and the process is then forked (via `os.fork()`), the child process may hit errors or undefined behavior when using CUDA, because the CUDA driver and runtime are not designed to be safely duplicated across `fork()`. The recommended start methods are `spawn` or `forkserver`. Of the two, `forkserver` needs to be used carefully: specifically, `multiprocessing.set_start_method('forkserver')` should be called at the very start of the program, and the parent process must also avoid initializing the CUDA context.

While upgrading APS to CUDA 12.8 we encountered a test failure; the test apparently initializes the CUDA context before starting two child processes, and I suspect that is what caused it to hang - [post](https://fb.workplace.com/groups/319878845696681/posts/1494595861558301). It's hard to avoid initializing the CUDA context early in this test, because it checks the GPU count in the test method's decorator - [code](https://fburl.com/code/27naz2eg).

Between `spawn` and `forkserver`, `spawn` is less efficient but the most robust. Let's switch to it to avoid any potential undefined behavior with CUDA 12.8 and multiprocessing.

Differential Revision: D80305233
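For context, here is a minimal sketch of the pattern this change standardizes on (not the torchrec test code itself; the worker function and world size are illustrative): workers are launched from a `spawn` context, so each child starts a fresh interpreter and builds its own CUDA context instead of inheriting a forked copy of the parent's.

```python
# Minimal sketch, not the torchrec code: launch CUDA-using workers with the
# "spawn" start method so no child inherits a forked copy of the parent's
# CUDA context. The worker body and world_size are illustrative.
import multiprocessing

import torch


def _worker(rank: int, world_size: int) -> None:
    # A spawned child is a fresh interpreter; it initializes its own CUDA
    # context here rather than duplicating the parent's via fork().
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(rank % torch.cuda.device_count())
        device = f"cuda:{torch.cuda.current_device()}"
    print(f"rank {rank}/{world_size} running on {device}")


if __name__ == "__main__":
    # Safe even if the parent has already touched CUDA (e.g. a GPU-count check
    # in a test decorator), because spawn does not fork the parent's state.
    world_size = 2
    ctx = multiprocessing.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(target=_worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```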
1 parent fdd8534 commit fb5995c

File tree

1 file changed: +3, -10 lines changed

torchrec/distributed/test_utils/multi_process.py

Lines changed: 3 additions & 10 deletions
@@ -102,16 +102,9 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
 
 
 class MultiProcessTestBase(unittest.TestCase):
-    def __init__(
-        self, methodName: str = "runTest", mp_init_mode: str = "forkserver"
-    ) -> None:
+    def __init__(self, methodName: str = "runTest") -> None:
         super().__init__(methodName)
 
-        # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
-        # Therefore we use spawn for HIP runtime until AMD fixes the issue
-        self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
-        logging.info(f"Using {self._mp_init_mode} for multiprocessing")
-
     @seed_and_log
     def setUp(self) -> None:
         os.environ["MASTER_ADDR"] = str("localhost")
@@ -144,7 +137,7 @@ def _run_multi_process_test(
         # pyre-ignore
         **kwargs,
     ) -> None:
-        ctx = multiprocessing.get_context(self._mp_init_mode)
+        ctx = multiprocessing.get_context(method="spawn")
         processes = []
         for rank in range(world_size):
             kwargs["rank"] = rank
@@ -170,7 +163,7 @@ def _run_multi_process_test_per_rank(
         world_size: int,
         kwargs_per_rank: List[Dict[str, Any]],
     ) -> None:
-        ctx = multiprocessing.get_context(self._mp_init_mode)
+        ctx = multiprocessing.get_context(method="spawn")
         processes = []
         for rank in range(world_size):
             kwargs = {}
