diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py new file mode 100644 index 000000000000..a64ccc7ed420 --- /dev/null +++ b/tests/v1/engine/test_client_guard.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import queue +import threading +import time +from unittest.mock import AsyncMock + +import pytest +import zmq + +from vllm.utils.collection_utils import ThreadSafeDict +from vllm.v1.engine.core_client import ClientSentinel +from vllm.v1.engine.utils import FaultHandler, FaultInfo + +FAULT_RECEIVER_ADDR = "tcp://127.0.0.1:8844" +CMD_ADDR = "tcp://127.0.0.1:8845" +FAULT_PUB_ADDR = "tcp://127.0.0.1:8846" +FAULT_PUB_TOPIC = "vllm_fault" + + +def create_test_thread_safe_dict(initial_data=None): + if initial_data is None: + initial_data = {1: "Healthy"} + + tsd = ThreadSafeDict() + if initial_data: + for k, v in initial_data.items(): + tsd[k] = v + return tsd + + +def create_client_sentinel( + engine_exception_q: queue.Queue, engine_status_dict: ThreadSafeDict[int, str] +): + return ClientSentinel( + fault_receiver_addr=FAULT_RECEIVER_ADDR, + cmd_addr=CMD_ADDR, + engine_registry={0: b"engine_identity"}, + engine_exception_q=engine_exception_q, + fault_pub_addr=FAULT_PUB_ADDR, + engine_status_dict=engine_status_dict, + ) + + +def test_client_sentinel_initialization(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + assert sentinel.engine_registry[0] == b"engine_identity" + assert not sentinel.client_sentinel_dead + assert isinstance(sentinel.fault_handler, FaultHandler) + assert sentinel.engine_exception_q is engine_exception_q + + assert sentinel.fault_receiver_socket.type == zmq.ROUTER + assert sentinel.cmd_socket.type == zmq.ROUTER + assert sentinel.fault_pub_socket.type == zmq.PUB + + sentinel.shutdown_sentinel() + + +@pytest.mark.asyncio +async def test_handle_fault(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + engine_exception_q.put_nowait( + FaultInfo(engine_id="1", message="test exception", type="test") + ) + + sentinel.fault_handler.handle_fault = AsyncMock(return_value=True) + + result = await sentinel.handle_fault("pause", 5) + assert result is True + sentinel.fault_handler.handle_fault.assert_awaited_once_with("pause", 5) + + sentinel.shutdown_sentinel() + + +def test_fault_receiver(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + def send_test_message(): + ctx = zmq.Context() + socket = ctx.socket(zmq.DEALER) + socket.setsockopt(zmq.IDENTITY, b"test_sender") + socket.connect(FAULT_RECEIVER_ADDR) + + test_fault = FaultInfo(engine_id="1", type="dead", message="test error") + socket.send_multipart([b"", test_fault.serialize().encode("utf-8")]) + socket.close() + ctx.term() + + sender_thread = threading.Thread(target=send_test_message, daemon=True) + sender_thread.start() + + def check_published_message(): + ctx = zmq.Context() + sub_socket = ctx.socket(zmq.SUB) + sub_socket.connect(FAULT_PUB_ADDR) + sub_socket.setsockopt_string(zmq.SUBSCRIBE, FAULT_PUB_TOPIC) + + 
message = sub_socket.recv_string() + sub_socket.close() + ctx.term() + + prefix, data = message.split("|", 1) + assert prefix == FAULT_PUB_TOPIC + assert json.loads(data) == {"1": "Dead"} + + check_thread = threading.Thread(target=check_published_message, daemon=True) + check_thread.start() + + time.sleep(0.1) + + assert not engine_exception_q.empty() + received_fault = engine_exception_q.get_nowait() + assert received_fault.engine_id == "1" + assert received_fault.type == "dead" + + assert engine_status_dict[1] == "Dead" + + sentinel.shutdown_sentinel() + + +def test_fault_receiver_unhealthy(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + def send_unhealthy_message(): + ctx = zmq.Context() + socket = ctx.socket(zmq.DEALER) + socket.setsockopt(zmq.IDENTITY, b"engine_identity") + socket.connect(FAULT_RECEIVER_ADDR) + + test_fault = FaultInfo(engine_id="1", type="error", message="test error") + socket.send_multipart([b"", test_fault.serialize().encode()]) + socket.close() + ctx.term() + + threading.Thread(target=send_unhealthy_message, daemon=True).start() + time.sleep(0.1) + + assert engine_status_dict[1] == "Unhealthy" + + sentinel.shutdown_sentinel() + + +def test_shutdown_sentinel(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + original_fault_sock = sentinel.fault_receiver_socket + original_cmd_sock = sentinel.cmd_socket + original_pub_sock = sentinel.fault_pub_socket + original_ctx = sentinel.zmq_ctx + + sentinel.shutdown_sentinel() + + assert sentinel.client_sentinel_dead is True + + with pytest.raises(zmq.ZMQError): + original_fault_sock.recv() + + with pytest.raises(zmq.ZMQError): + original_cmd_sock.recv() + + with pytest.raises(zmq.ZMQError): + original_pub_sock.send(b"test") + + assert original_ctx.closed + + +@pytest.mark.asyncio +async def test_handle_fault_async(): + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"}) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) + + time.sleep(0.1) + ctx = zmq.Context().instance() + cmd_socket = ctx.socket(zmq.DEALER) + cmd_socket.setsockopt(zmq.IDENTITY, b"engine_identity") + cmd_socket.connect(CMD_ADDR) + time.sleep(0.1) + + uuid = None + + def receive_cmd(cmd_socket): + nonlocal uuid + time.sleep(0.1) + + identity, msg = cmd_socket.recv_multipart() + cmd_dict = json.loads(msg.decode("utf-8")) + assert cmd_dict["method"] == "retry" + assert cmd_dict["timeout"] == 3 + uuid = cmd_dict["method_uuid"] + + def response_cmd(cmd_socket): + nonlocal uuid + while uuid is None: + time.sleep(0.1) + execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid} + cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")]) + + threading.Thread(target=receive_cmd, args=(cmd_socket,), daemon=True).start() + threading.Thread(target=response_cmd, args=(cmd_socket,), daemon=True).start() + + result = await sentinel.handle_fault("retry", 3) + + assert result is True + assert engine_status_dict[0] == "Healthy" + + cmd_socket.close() + ctx.term() + sentinel.shutdown_sentinel() diff --git a/tests/v1/engine/test_engine_core_guard.py b/tests/v1/engine/test_engine_core_guard.py new file mode 100644 index 
000000000000..f55a903db770
--- /dev/null
+++ b/tests/v1/engine/test_engine_core_guard.py
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+import logging
+import queue
+import threading
+import time
+
+import pytest
+import zmq
+
+from vllm.utils.network_utils import make_zmq_socket
+from vllm.v1.engine.core import (
+    EngineCoreSentinel,
+    EngineLoopPausedError,
+)
+from vllm.v1.serial_utils import serialize_method_call
+
+CLIENT_CMD_ADDR = "tcp://127.0.0.1:8844"
+WORKER_CMD_ADDR = "tcp://127.0.0.1:8845"
+FAULT_REPORT_ADDR = "tcp://127.0.0.1:8846"
+SENTINEL_IDENTITY = b"engine_sentinel_0"
+
+
+def create_engine_core_sentinel(
+    fault_signal_q: queue.Queue, busy_loop_active: threading.Event
+):
+    return EngineCoreSentinel(
+        engine_index=0,
+        fault_signal_q=fault_signal_q,
+        cmd_q=queue.Queue(),
+        busy_loop_active=busy_loop_active,
+        engine_input_q=queue.Queue(),
+        client_cmd_addr=CLIENT_CMD_ADDR,
+        worker_cmd_addr=WORKER_CMD_ADDR,
+        fault_report_addr=FAULT_REPORT_ADDR,
+        sentinel_identity=SENTINEL_IDENTITY,
+        tp_size=1,
+        pp_size=1,
+        dp_size=1,
+    )
+
+
+def test_engine_core_sentinel_initialization():
+    fault_signal_q: queue.Queue = queue.Queue()
+    busy_loop_active = threading.Event()
+
+    sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active)
+
+    assert sentinel.engine_index == 0
+    assert sentinel.tp_size == 1
+    assert sentinel.pp_size == 1
+    assert not sentinel.communicator_aborted
+    assert sentinel.engine_running is True
+    assert sentinel.daemon is True
+
+    assert sentinel.fault_report_socket.type == zmq.DEALER
+    assert sentinel.client_cmd_socket.type == zmq.DEALER
+    assert sentinel.worker_cmd_socket.type == zmq.ROUTER
+
+    sentinel.shutdown()
+
+
+@pytest.mark.parametrize("instruction", ["pause", "retry"])
+def test_run_handle_instruction(instruction):
+    fault_signal_q: queue.Queue = queue.Queue()
+    busy_loop_active = threading.Event()
+
+    client_socket = make_zmq_socket(
+        ctx=zmq.Context(), path=CLIENT_CMD_ADDR, socket_type=zmq.ROUTER, bind=True
+    )
+
+    time.sleep(0.1)
+
+    sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active)
+    time.sleep(0.1)
+
+    ctx = zmq.Context()
+    worker_cmd_socket = ctx.socket(zmq.DEALER)
+    worker_cmd_socket.setsockopt(zmq.IDENTITY, b"0_0")
+    worker_cmd_socket.connect(WORKER_CMD_ADDR)
+
+    def mock_worker_receiver(cmd_socket):
+        time.sleep(0.1)
+        logging.info("start worker")
+        identity, msg = cmd_socket.recv_multipart()
+        logging.info(identity)
+        cmd_dict = json.loads(msg.decode("utf-8"))
+        assert (
+            cmd_dict["method"] == "pause_by_signal"
+            if instruction == "pause"
+            else cmd_dict["method"] == "retry"
+        )
+        response_dict = {"success": True, "method_uuid": cmd_dict["method_uuid"]}
+        logging.info(identity)
+        cmd_socket.send_multipart([b"", json.dumps(response_dict).encode("utf-8")])
+
+    threading.Thread(target=sentinel.run, daemon=True).start()
+    time.sleep(0.1)
+
+    param = {"timeout": 3}
+    if instruction == "pause":
+        param["soft_pause"] = True
+    elif instruction == "retry":
+        param["new_stateless_dp_group_port"] = 23456
+    serial_instruction = serialize_method_call(instruction, **param)
+    client_socket.send_multipart(
+        [SENTINEL_IDENTITY, b"", serial_instruction.encode("utf-8")]
+    )
+    if instruction == "pause":
+        fault_signal_q.put(EngineLoopPausedError(Exception("test error")))
+    elif instruction == "retry":
+        busy_loop_active.set()
+
+    threading.Thread(
+        target=mock_worker_receiver, args=(worker_cmd_socket,), daemon=True
+    ).start()
+
+    time.sleep(0.1)
+    identity, _, msg = client_socket.recv_multipart()
+    result_dict = json.loads(msg.decode("utf-8"))
+    assert result_dict["engine_index"] == 0
+    assert result_dict["success"]
+
+    time.sleep(0.1)
+
+    client_socket.close()
+    worker_cmd_socket.close()
+    sentinel.shutdown()
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index dd76a722106e..eb85862bf240 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -10,6 +10,7 @@
 )
 from vllm.config.device import DeviceConfig
 from vllm.config.ec_transfer import ECTransferConfig
+from vllm.config.fault_tolerance import FaultToleranceConfig
 from vllm.config.kv_events import KVEventsConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.config.load import LoadConfig
@@ -86,6 +87,8 @@
     "SpeechToTextConfig",
     # From vllm.config.structured_outputs
     "StructuredOutputsConfig",
+    # From vllm.config.fault_tolerance
+    "FaultToleranceConfig",
     # From vllm.config.utils
     "ConfigType",
     "SupportsMetricsInfo",
diff --git a/vllm/config/fault_tolerance.py b/vllm/config/fault_tolerance.py
new file mode 100644
index 000000000000..24fbd6f1f259
--- /dev/null
+++ b/vllm/config/fault_tolerance.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from typing import Any
+
+from pydantic.dataclasses import dataclass
+
+from vllm.config.utils import config
+
+
+@config
+@dataclass
+class FaultToleranceConfig:
+    """Configuration for fault tolerance."""
+
+    enable_fault_tolerance: bool = False
+    """Enable fault tolerance for detailed error recovery,
+    such as scaling down a faulty DPEngineCore.
+    """
+
+    engine_recovery_timeout: int = 60
+    """Timeout (in seconds) to wait for error handling instructions
+    before raising an exception. If the EngineCore encounters an
+    error, it waits up to this many seconds for instructions on how
+    to handle the error. If no instructions are received within this
+    time, the original error is raised.
+    """
+
+    internal_fault_report_port: int = 22866
+    """
+    The port to use for internal fault reporting.
+    """
+
+    external_fault_notify_port: int = 22867
+    """
+    The port to use for external fault notification.
+    """
+
+    engine_core_cmd_addr: str = ""
+    """
+    The ZMQ address between engine_core_sentinel and worker_sentinel.
+    It is initialized and assigned in EngineCore, then passed to the
+    Worker via vllm_config; the Worker needs this address to spin up
+    the WorkerSentinel.
+    """
+
+    gloo_comm_timeout: int = 30
+    """
+    The timeout (in seconds) for gloo communication.
+    """
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        # no factors to consider.
+        # this config will not affect the computation graph.
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + pass diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9a6326d62e82..634b92101afb 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -337,7 +337,11 @@ def get_next_dp_init_port(self) -> int: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def stateless_init_dp_group( + self, + gloo_comm_timeout: int | None = None, + dp_init_port: int | None = None, + ) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first @@ -353,21 +357,25 @@ def stateless_init_dp_group(self) -> ProcessGroup: max_retries = 5 last_exc: Exception | None = None + if dp_init_port is None: + dp_init_port = self.get_next_dp_init_port() for _ in range(max_retries): try: # use gloo since the engine process might not have cuda device return stateless_init_torch_distributed_process_group( self.data_parallel_master_ip, - self.get_next_dp_init_port(), + dp_init_port, self.data_parallel_rank, self.data_parallel_size, backend=current_platform.dist_backend, + gloo_comm_timeout=gloo_comm_timeout, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. if "EADDRINUSE" in str(e): logger.warning("Address already in use. Retrying with a new port.") last_exc = e + dp_init_port = self.get_next_dp_init_port() continue # try again with a new port raise e diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 672b004c4aa5..9005d309f3a0 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -10,7 +10,7 @@ import threading import time from contextlib import contextmanager -from dataclasses import replace +from dataclasses import field, replace from datetime import datetime from functools import lru_cache from pathlib import Path @@ -30,6 +30,7 @@ from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig from .ec_transfer import ECTransferConfig +from .fault_tolerance import FaultToleranceConfig from .kv_events import KVEventsConfig from .kv_transfer import KVTransferConfig from .load import LoadConfig @@ -107,6 +108,10 @@ class VllmConfig: """The configurations for event publishing.""" ec_transfer_config: ECTransferConfig | None = None """The configurations for distributed EC cache transfer.""" + fault_tolerance_config: FaultToleranceConfig = field( + default_factory=FaultToleranceConfig + ) + """The configurations for fault tolerance.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
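
The `FaultToleranceConfig` introduced above is deliberately inert with respect to the compilation cache: `compute_hash()` hashes an empty factor list, so changing fault tolerance settings never invalidates compiled graphs. Below is a minimal sketch (not part of this diff) of constructing the config directly; the field names and the export through `vllm.config` come from the hunks above, while the concrete values are illustrative assumptions.

```python
from vllm.config import FaultToleranceConfig

ft = FaultToleranceConfig(
    enable_fault_tolerance=True,
    engine_recovery_timeout=120,  # wait up to 2 minutes for recovery instructions
    gloo_comm_timeout=30,         # seconds; forwarded to the gloo process groups
)

# compute_hash() ignores every field because fault tolerance does not change
# the computation graph, so any two instances hash identically.
assert ft.compute_hash() == FaultToleranceConfig().compute_hash()
```

The same fields are surfaced as CLI flags (`--enable-fault-tolerance`, `--engine-recovery-timeout`, and so on) by the `arg_utils.py` changes later in this diff.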
@@ -1049,7 +1054,8 @@ def __str__(self): f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}" + f"compilation_config={self.compilation_config!r}," + f"fault_tolerance_config={self.fault_tolerance_config!r}, " ) @model_validator(mode="after") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2fc35e80f591..846a0f6c21cc 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -236,7 +236,7 @@ def all_gatherv( cudaStream_t(stream.cuda_stream), ) split_offset += split_size - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) def reduce_scatter( self, @@ -301,7 +301,7 @@ def reduce_scatterv( cudaStream_t(stream.cuda_stream), ) split_offset += split_size - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: @@ -369,7 +369,10 @@ def group_start(self): self.nccl.ncclGroupStart() def group_end(self): - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) + + def nccl_abort_comm(self): + self.nccl.ncclCommAbort(self.comm) def register_comm_window(self, tensor: torch.Tensor): return self.nccl.ncclCommWindowRegister( diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index b2433d58dc1f..e5c3e981823b 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -136,6 +136,15 @@ class NCCLLibrary: Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t* asyncError) + Function( + "ncclCommGetAsyncError", + ncclResult_t, + [ + ncclComm_t, + ctypes.POINTER(ncclResult_t), + ], + ), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument @@ -274,6 +283,8 @@ class NCCLLibrary: # it is better not to call it at all. # ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclCommAbort(ncclComm_t comm) + Function("ncclCommAbort", ncclResult_t, [ncclComm_t]), # ncclResult_t ncclGroupStart(); Function("ncclGroupStart", ncclResult_t, []), # ncclResult_t ncclGroupEnd(); @@ -360,10 +371,32 @@ def __init__(self, so_file: str | None = None): def ncclGetErrorString(self, result: ncclResult_t) -> str: return self._funcs["ncclGetErrorString"](result).decode("utf-8") - def NCCL_CHECK(self, result: ncclResult_t) -> None: - if result != 0: - error_str = self.ncclGetErrorString(result) - raise RuntimeError(f"NCCL error: {error_str}") + def NCCL_CHECK(self, result: ncclResult_t, comm: ncclComm_t | None = None) -> None: + ncclSuccess = 0 + ncclInProgress = 7 + + if result == ncclSuccess: + return + + # Handle non-blocking communicators + if result == ncclInProgress: + if comm is None: + raise RuntimeError( + "NCCL_CHECK: ncclInProgress returned but no communicator " + "provided (required for non-blocking NCCL checks)." 
+ ) + while True: + result = self.ncclCommGetAsyncError(comm) + result_value = result.value + if result_value == ncclSuccess: + # Operation has completed successfully + return + if result_value != ncclInProgress: + # Now a definite error occurred + break + + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL_CHECK failed: {error_str} (code={result})") def ncclGetRawVersion(self) -> int: version = ctypes.c_int() @@ -384,6 +417,11 @@ def ncclGetUniqueId(self) -> ncclUniqueId: self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id + def ncclCommGetAsyncError(self, comm: ncclComm_t) -> ncclResult_t: + async_error = ncclResult_t() + self._funcs["ncclCommGetAsyncError"](comm, ctypes.byref(async_error)) + return async_error + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: if len(data) != 128: raise ValueError( @@ -400,7 +438,8 @@ def ncclCommInitRank( self.NCCL_CHECK( self._funcs["ncclCommInitRank"]( ctypes.byref(comm), world_size, unique_id, rank - ) + ), + comm, ) return comm @@ -422,7 +461,8 @@ def ncclAllReduce( self.NCCL_CHECK( self._funcs["ncclAllReduce"]( sendbuff, recvbuff, count, datatype, op, comm, stream - ) + ), + comm, ) def ncclReduce( @@ -444,7 +484,8 @@ def ncclReduce( self.NCCL_CHECK( self._funcs["ncclReduce"]( sendbuff, recvbuff, count, datatype, op, root, comm, stream - ) + ), + comm, ) def ncclReduceScatter( @@ -465,7 +506,8 @@ def ncclReduceScatter( self.NCCL_CHECK( self._funcs["ncclReduceScatter"]( sendbuff, recvbuff, count, datatype, op, comm, stream - ) + ), + comm, ) def ncclAllGather( @@ -484,7 +526,8 @@ def ncclAllGather( self.NCCL_CHECK( self._funcs["ncclAllGather"]( sendbuff, recvbuff, count, datatype, comm, stream - ) + ), + comm, ) def ncclSend( @@ -497,7 +540,7 @@ def ncclSend( stream: cudaStream_t, ) -> None: self.NCCL_CHECK( - self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream), comm ) def ncclRecv( @@ -510,7 +553,7 @@ def ncclRecv( stream: cudaStream_t, ) -> None: self.NCCL_CHECK( - self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream), comm ) def ncclBroadcast( @@ -526,17 +569,21 @@ def ncclBroadcast( self.NCCL_CHECK( self._funcs["ncclBroadcast"]( sendbuff, recvbuff, count, datatype, root, comm, stream - ) + ), + comm, ) def ncclCommDestroy(self, comm: ncclComm_t) -> None: - self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm), comm) + + def ncclCommAbort(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommAbort"](comm), comm) def ncclGroupStart(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) - def ncclGroupEnd(self) -> None: - self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + def ncclGroupEnd(self, comm) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"](), comm) def ncclCommWindowRegister( self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int @@ -545,12 +592,13 @@ def ncclCommWindowRegister( self.NCCL_CHECK( self._funcs["ncclCommWindowRegister"]( comm, buff, size, ctypes.byref(window), win_flags - ) + ), + comm, ) return window def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: - self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window), comm) __all__ = [ diff --git 
a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 852c4c644433..b27125b66830 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -25,6 +25,7 @@ import contextlib import gc +import os import pickle import weakref from collections import namedtuple @@ -312,6 +313,8 @@ def __init__( use_device_communicator: bool, # whether to use device communicator use_message_queue_broadcaster: bool = False, group_name: str | None = None, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, ): group_name = group_name or "anonymous" self.unique_name = _get_unique_name(group_name) @@ -323,13 +326,30 @@ def __init__( self_device_group = None self_cpu_group = None + if torch_distributed_backend == "nccl": + options = torch._C._distributed_c10d.ProcessGroupNCCL.Options() + if enable_fault_tolerance: + # need to set communicators as nonblocking to abort safely + options.config.blocking = 0 + os.environ["NCCL_COMM_BLOCKING"] = "0" + for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + if torch_distributed_backend == "nccl": + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend, pg_options=options + ) + else: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if not enable_fault_tolerance: + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + else: + cpu_group = torch.distributed.new_group( + ranks, backend="gloo", timeout=gloo_comm_timeout + ) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -1039,7 +1059,11 @@ def get_inner_dp_world_group() -> GroupCoordinator: def init_world_group( - ranks: list[int], local_rank: int, backend: str + ranks: list[int], + local_rank: int, + backend: str, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, ) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], @@ -1047,6 +1071,8 @@ def init_world_group( torch_distributed_backend=backend, use_device_communicator=False, group_name="world", + enable_fault_tolerance=enable_fault_tolerance, + gloo_comm_timeout=gloo_comm_timeout, ) @@ -1054,6 +1080,8 @@ def init_model_parallel_group( group_ranks: list[list[int]], local_rank: int, backend: str, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, use_message_queue_broadcaster: bool = False, group_name: str | None = None, use_device_communicator: bool = True, @@ -1065,6 +1093,8 @@ def init_model_parallel_group( use_device_communicator=use_device_communicator, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, + enable_fault_tolerance=enable_fault_tolerance, + gloo_comm_timeout=gloo_comm_timeout, ) @@ -1128,6 +1158,31 @@ def get_pipeline_model_parallel_group(): return get_pp_group() +def get_all_model_groups() -> list[GroupCoordinator]: + group_list = [] + global _TP + if _TP: + group_list.append(_TP) + + global _PP + if _PP: + group_list.append(_PP) + + global _DCP + if _DCP: + group_list.append(_DCP) + + global _DP + if _DP: + group_list.append(_DP) + + global _EP + if _EP: + group_list.append(_EP) + + return group_list + + @contextmanager def graph_capture(device: torch.device): """ @@ -1164,6 +1219,7 @@ def 
init_distributed_environment( distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", + enable_fault_tolerance: bool = False, timeout: timedelta | None = None, ): logger.debug( @@ -1244,7 +1300,9 @@ def init_distributed_environment( global _WORLD, _NODE_COUNT, _INNER_DP_WORLD if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) - _WORLD = init_world_group(ranks, local_rank, backend) + _WORLD = init_world_group( + ranks, local_rank, backend, enable_fault_tolerance, timeout + ) if config.parallel_config.nnodes > 1: _NODE_COUNT = config.parallel_config.nnodes else: @@ -1276,6 +1334,8 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1339,6 +1399,8 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, + enable_fault_tolerance, + gloo_comm_timeout, use_message_queue_broadcaster=True, group_name="tp", ) @@ -1356,6 +1418,8 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, + enable_fault_tolerance, + gloo_comm_timeout, use_message_queue_broadcaster=True, group_name="dcp", ) @@ -1368,7 +1432,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="pp" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="pp", ) global _DP @@ -1376,7 +1445,12 @@ def initialize_model_parallel( group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] _DP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dp" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="dp", ) global _EP @@ -1388,7 +1462,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] _EP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="ep" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="ep", ) logger.info_once( @@ -1407,6 +1486,8 @@ def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, decode_context_model_parallel_size: int | None = 1, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, backend: str | None = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1418,6 +1499,8 @@ def ensure_model_parallel_initialized( initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, + enable_fault_tolerance, + gloo_comm_timeout, decode_context_model_parallel_size, backend, ) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index debf69c49b7d..abffa84e0c04 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -458,7 +458,12 @@ def init_gloo_process_group( def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, backend: str + host: str, + port: int, + rank: int, + world_size: int, + backend: str, + gloo_comm_timeout: int | 
None, ) -> ProcessGroup: """ A replacement for `torch.distributed.init_process_group` that does not @@ -493,7 +498,11 @@ def stateless_init_torch_distributed_process_group( """ init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string - timeout = _get_default_timeout(backend) + + if gloo_comm_timeout is None: + timeout = _get_default_timeout(backend) + else: + timeout = timedelta(seconds=gloo_comm_timeout) store, rank, world_size = next( rendezvous(init_method, rank, world_size, timeout=timeout) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ab6e5e594c23..ca072fbca6e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -40,6 +40,7 @@ DeviceConfig, ECTransferConfig, EPLBConfig, + FaultToleranceConfig, KVEventsConfig, KVTransferConfig, LoadConfig, @@ -567,6 +568,13 @@ class EngineArgs: kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill + # fault tolerance fields + enable_fault_tolerance: bool = FaultToleranceConfig.enable_fault_tolerance + engine_recovery_timeout: int = FaultToleranceConfig.engine_recovery_timeout + internal_fault_report_port: int = FaultToleranceConfig.internal_fault_report_port + external_fault_notify_port: int = FaultToleranceConfig.external_fault_notify_port + gloo_comm_timeout: int = FaultToleranceConfig.gloo_comm_timeout + kv_offloading_size: float | None = CacheConfig.kv_offloading_size kv_offloading_backend: KVOffloadingBackend | None = ( CacheConfig.kv_offloading_backend @@ -1155,6 +1163,32 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) + # fault tolerance arguments + fault_tolerance_kwargs = get_kwargs(FaultToleranceConfig) + fault_tolerance_group = parser.add_argument_group( + title="FaultToleranceConfig", + description=FaultToleranceConfig.__doc__, + ) + fault_tolerance_group.add_argument( + "--enable-fault-tolerance", + **fault_tolerance_kwargs["enable_fault_tolerance"], + ) + fault_tolerance_group.add_argument( + "--engine-recovery-timeout", + **fault_tolerance_kwargs["engine_recovery_timeout"], + ) + fault_tolerance_group.add_argument( + "--internal-fault-report-port", + **fault_tolerance_kwargs["internal_fault_report_port"], + ) + fault_tolerance_group.add_argument( + "--external-fault-notify-port", + **fault_tolerance_kwargs["external_fault_notify_port"], + ) + fault_tolerance_group.add_argument( + "--gloo-comm-timeout", + **fault_tolerance_kwargs["gloo_comm_timeout"], + ) # Other arguments parser.add_argument( "--disable-log-stats", @@ -1738,6 +1772,14 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + fault_tolerance_config = FaultToleranceConfig( + enable_fault_tolerance=self.enable_fault_tolerance, + engine_recovery_timeout=self.engine_recovery_timeout, + internal_fault_report_port=self.internal_fault_report_port, + external_fault_notify_port=self.external_fault_notify_port, + gloo_comm_timeout=self.gloo_comm_timeout, + ) + # Compilation config overrides compilation_config = copy.deepcopy(self.compilation_config) if self.cuda_graph_sizes is not None: @@ -1784,6 +1826,7 @@ def create_engine_config( kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, + fault_tolerance_config=fault_tolerance_config, additional_config=self.additional_config, ) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 462d2c4e50e7..06f756e24cb3 100644 --- 
a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -165,6 +165,19 @@ async def collective_rpc(
         """Perform a collective RPC call to the given path."""
         raise NotImplementedError
 
+    async def handle_fault(
+        self, instruction: str, timeout: int = 300, **kwargs
+    ) -> bool:
+        """Send a fault tolerance instruction to the engine."""
+        raise NotImplementedError
+
+    async def get_fault_info(self):
+        """Report fault information collected from the engine core."""
+        raise NotImplementedError
+
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         """Get supported tasks"""
         raise NotImplementedError
+
+    def shutdown(self) -> None:
+        raise NotImplementedError
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 154cdeb42a3e..3a020b58b37b 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -15,7 +15,7 @@
 from collections.abc import AsyncGenerator
 from typing import Any
 
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 import vllm.envs as envs
@@ -56,6 +56,49 @@ async def generate(request: Request) -> Response:
     return await _generate(request_dict, raw_request=request)
 
 
+@app.post("/fault_tolerance/apply")
+async def process_fault_tolerance_instruction(request: Request) -> Response:
+    """Apply fault tolerance instructions to the engine.
+
+    This endpoint handles fault recovery operations such as pausing or retrying the engine.
+
+    The request should be a JSON object with the following fields:
+    - fault_tolerance_instruction: The name of the fault tolerance method.
+    - fault_tolerance_timeout: Timeout in seconds for the operation to complete.
+    - fault_tolerance_params: dict, optional. Additional dynamic parameters for
+      the fault tolerance operation.
+    """
+    request_dict = await request.json()
+
+    fault_tolerance_instruction = request_dict.get("fault_tolerance_instruction")
+    fault_tolerance_timeout = request_dict.get("fault_tolerance_timeout")
+    kwargs = request_dict.get("fault_tolerance_params", {})
+    assert engine is not None
+    success = await engine.handle_fault(
+        fault_tolerance_instruction, fault_tolerance_timeout, **kwargs
+    )
+    if success:
+        return JSONResponse(
+            status_code=200,
+            content={"message": "Instruction executed successfully."},
+        )
+
+    logger.error("Fault tolerance operation failed. 
Shutting down the engine.") + engine.shutdown() + raise HTTPException( + status_code=400, + detail="Instruction execution failed.", + ) + + +@app.get("/fault_tolerance/status") +async def get_fault_info() -> Response: + """Health check.""" + assert engine is not None + engine_status_dict = await engine.get_fault_info() + return Response(json.dumps(engine_status_dict), status_code=200) + + @with_cancellation async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = request_dict.pop("prompt") diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 96608f360e17..2faaa6013141 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -153,7 +153,10 @@ def signal_handler(signum, frame): ) try: - engine_manager.join_first() + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + engine_manager.start_engine_core_monitor() + else: + engine_manager.join_first() finally: logger.info("Shutting down.") engine_manager.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3cf66fcd27e2..3c8882ccec5f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -263,6 +263,20 @@ async def validate_json_request(raw_request: Request): ) +async def check_engine_fault(raw_request: Request): + client = engine_client(raw_request) + assert hasattr(client, "engine_core") + core_client = client.engine_core + if ( + hasattr(core_client, "client_sentinel") + and core_client.client_sentinel.is_faulted.is_set() + ): + raise HTTPException( + status_code=503, + detail="Service is in faulted state, cannot process requests.", + ) + + router = APIRouter() @@ -395,7 +409,7 @@ async def get_server_load_metrics(request: Request): @router.post( "/tokenize", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -430,7 +444,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post( "/detokenize", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -505,7 +519,7 @@ async def _convert_stream_to_sse_events( @router.post( "/v1/responses", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ -598,7 +612,7 @@ async def cancel_responses(response_id: str, raw_request: Request): @router.post( "/v1/messages", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, @@ -654,7 +668,7 @@ def translate_error_response(response: ErrorResponse) -> JSONResponse: @router.post( "/v1/chat/completions", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ 
-695,7 +709,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re @router.post( "/v1/completions", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ -741,7 +755,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post( "/v1/embeddings", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -784,7 +798,7 @@ async def create_embedding( @router.post( "/pooling", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -820,7 +834,10 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/classify", dependencies=[Depends(validate_json_request)]) +@router.post( + "/classify", + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], +) @with_cancellation @load_aware_call async def create_classify(request: ClassificationRequest, raw_request: Request): @@ -849,7 +866,7 @@ async def create_classify(request: ClassificationRequest, raw_request: Request): @router.post( "/score", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -882,7 +899,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): @router.post( "/v1/score", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -979,7 +996,7 @@ async def create_translations( @router.post( "/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1011,7 +1028,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): @router.post( "/v1/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1030,7 +1047,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): @router.post( "/v2/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1224,6 +1241,89 @@ async def is_scaling_elastic_ep(raw_request: Request): (PoolingRequest, (pooling, create_pooling)), ] + +@router.post( + 
"/fault_tolerance/apply", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def process_fault_tolerance_instruction(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e + + client = engine_client(raw_request) + + fault_tolerance_instruction = body.get("fault_tolerance_instruction") + fault_tolerance_timeout = body.get("fault_tolerance_timeout") + dynamic_fault_tolerance_params = body.get("fault_tolerance_params", {}) + + if fault_tolerance_instruction is None or fault_tolerance_timeout is None: + raise HTTPException( + status_code=400, + detail="Both 'fault_tolerance_instruction' and " + "'fault_tolerance_timeout' are required.", + ) + + if not isinstance(fault_tolerance_instruction, str): + raise HTTPException( + status_code=400, detail="'fault_tolerance_instruction' must be a string." + ) + # Supported instructions: ["pause", "retry"]. + # More instruction types may be added in future updates. + elif fault_tolerance_instruction not in ["pause", "retry"]: + raise HTTPException( + status_code=400, detail="Invalid 'fault_tolerance_instruction' value." + ) + + if not isinstance(fault_tolerance_timeout, int) or fault_tolerance_timeout <= 0: + raise HTTPException( + status_code=400, + detail="'fault_tolerance_timeout' must be a positive integer.", + ) + try: + success = await client.handle_fault( + fault_tolerance_instruction, + fault_tolerance_timeout, + **dynamic_fault_tolerance_params, + ) + if success: + return JSONResponse( + { + "message": "Instruction executed successfully.", + } + ) + else: + logger.error("Fault tolerance failed. Shutting down the application.") + client.shutdown() + raise HTTPException( + status_code=400, + detail="Instruction execution failed.", + ) + + except Exception as e: + logger.error("Failed to handle fault: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail="Failed to handle fault.", + ) from e + + +@router.get("/fault_tolerance/status") +async def get_fault_info( + raw_request: Request, +): + client = engine_client(raw_request) + engine_status_dict = await client.get_fault_info() + return JSONResponse(content=engine_status_dict) + + # NOTE: Construct the TypeAdapters only once INVOCATION_VALIDATORS = [ (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py index 57271311828c..5bb7b418b0c5 100644 --- a/vllm/utils/collection_utils.py +++ b/vllm/utils/collection_utils.py @@ -6,6 +6,7 @@ This is similar in concept to the `collections` module. """ +import threading from collections import UserDict, defaultdict from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from typing import Generic, Literal, TypeVar @@ -137,3 +138,211 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: obj[key1] = v2 else: obj.pop(key1, None) + + +# Define type variables for generic key and value types +KT = TypeVar("KT") # Key type variable +VT = TypeVar("VT") # Value type variable + + +class ThreadSafeDict(Generic[KT, VT]): + """ + A thread-safe generic dictionary implementation. 
+ Supports all basic dictionary operations with proper synchronization + using a reentrant lock, and maintains type safety through generics. + """ + + def __init__(self) -> None: + """Initialize an empty thread-safe dictionary with an internal lock.""" + self._storage: dict[KT, VT] = {} # Underlying storage structure + self._lock = threading.RLock() # Reentrant lock for synchronization + + def __setitem__(self, key: KT, value: VT) -> None: + """ + Thread-safe implementation of dictionary item assignment. + Equivalent to dict[key] = value. + + Args: + key: The key to associate with the value + value: The value to store + """ + with self._lock: + self._storage[key] = value + + def __getitem__(self, key: KT) -> VT: + """ + Thread-safe implementation of dictionary item retrieval. + Equivalent to dict[key]. + + Args: + key: The key to look up + + Returns: + The value associated with the key + + Raises: + KeyError: If the key is not found + """ + with self._lock: + return self._storage[key] + + def __delitem__(self, key: KT) -> None: + """ + Thread-safe implementation of dictionary item deletion. + Equivalent to del dict[key]. + + Args: + key: The key to remove + + Raises: + KeyError: If the key is not found + """ + with self._lock: + del self._storage[key] + + def get(self, key: KT, default: VT | None = None) -> VT | None: + """ + Thread-safe implementation of dict.get(). + + Args: + key: The key to look up + default: Value to return if key is not found (default: None) + + Returns: + The value associated with the key, or default if not found + """ + with self._lock: + return self._storage.get(key, default) + + def setdefault(self, key: KT, default: VT) -> VT: + """ + Thread-safe implementation of dict.setdefault(). + Inserts key with default value if key is not present. + + Args: + key: The key to check/insert + default: Value to insert if key is not found + + Returns: + The existing value or the inserted default value + """ + with self._lock: + return self._storage.setdefault(key, default) + + def update(self, items: Iterable[tuple[KT, VT]]) -> None: + """ + Thread-safe implementation of dict.update(). + Updates dictionary with multiple key-value pairs. + + Args: + items: Iterable of (key, value) tuples to add/update + """ + with self._lock: + self._storage.update(items) + + def pop(self, key: KT, default: VT | None = None) -> VT | None: + """ + Thread-safe implementation of dict.pop(). + Removes and returns value associated with key. + + Args: + key: The key to remove + default: Value to return if key is not found (optional) + + Returns: + The removed value or default if key not found + + Raises: + KeyError: If key not found and no default provided + """ + with self._lock: + return self._storage.pop(key, default) + + def __contains__(self, key: KT) -> bool: + """ + Thread-safe implementation of 'key in dict' check. + + Args: + key: The key to check for existence + + Returns: + True if key exists, False otherwise + """ + with self._lock: + return key in self._storage + + def __len__(self) -> int: + """ + Thread-safe implementation of len(dict). + + Returns: + Number of key-value pairs in the dictionary + """ + with self._lock: + return len(self._storage) + + def clear(self) -> None: + """Thread-safe implementation of dict.clear(). Removes all items.""" + with self._lock: + self._storage.clear() + + def keys(self) -> list[KT]: + """ + Thread-safe implementation of dict.keys(). + Returns a copy of all keys to prevent concurrent modification issues. 
+ + Returns: + List of all keys in the dictionary + """ + with self._lock: + return list(self._storage.keys()) + + def values(self) -> list[VT]: + """ + Thread-safe implementation of dict.values(). + Returns a copy of all values to prevent concurrent modification issues. + + Returns: + List of all values in the dictionary + """ + with self._lock: + return list(self._storage.values()) + + def items(self) -> list[tuple[KT, VT]]: + """ + Thread-safe implementation of dict.items(). + Returns a copy of all key-value pairs to prevent concurrent modification issues. + + Returns: + List of (key, value) tuples + """ + with self._lock: + return list(self._storage.items()) + + def __str__(self) -> str: + """ + Thread-safe string representation. + + Returns: + String representation of the dictionary + """ + with self._lock: + return str(self._storage) + + def __repr__(self) -> str: + """ + Thread-safe representation for debugging. + + Returns: + Debug-friendly string representation + """ + with self._lock: + return f"ThreadSafeDict({self._storage!r})" + + # ------------------------------ + # Critical: JSON serialization support + # ------------------------------ + def to_dict(self) -> dict[KT, VT]: + """Convert ThreadSafeDict to a standard Python dict (thread-safe).""" + with self._lock: + return self._storage.copy() # Return a copy of internal data diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py index 0a68e48ba5e7..5bdff432b872 100644 --- a/vllm/utils/network_utils.py +++ b/vllm/utils/network_utils.py @@ -329,3 +329,47 @@ def zmq_socket_ctx( finally: ctx.destroy(linger=linger) + + +def recv_router_dealer_message( + socket: zmq.Socket, + use_poller: bool = False, + poll_timeout: int = 1000, +) -> tuple[bool, None | bytes, None | str]: + """ + Receive multipart ZMQ messages, automatically inferring message format + based on socket type (ROUTER or DEALER). + + Returns: + (success, identity, message) + - identity is only set for ROUTER sockets + - success=False on timeout or error + """ + sock_type = socket.getsockopt(zmq.TYPE) + + # optional non-blocking receive + if use_poller: + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + socks = dict(poller.poll(poll_timeout)) + if socket not in socks: + return (False, None, None) + + parts = socket.recv_multipart() + + if sock_type == zmq.ROUTER: + # ROUTER message: [identity, empty, message] + assert len(parts) == 3, f"expected 3 parts, got {len(parts)}" + identity_bytes, empty_frame, message_bytes = parts + identity = identity_bytes + elif sock_type == zmq.DEALER: + # DEALER message: [empty, message] + assert len(parts) == 2, f"expected 2 parts, got {len(parts)}" + empty_frame, message_bytes = parts + identity = None + else: + raise ValueError(f"Unsupported socket type: {sock_type}") + + assert empty_frame == b"", f"empty frame invalid: {empty_frame}" + message = message_bytes.decode("utf-8") + return (True, identity, message) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 88d99d940282..010d4d56ac11 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -64,6 +64,30 @@ def get_grammar_bitmask( ) -> "GrammarOutput | None": raise NotImplementedError + @abstractmethod + def preempt_request( + self, + scheduled_timestamp: float | None = None, + preempted_req: Optional["Request"] = None, + ) -> "Request": + """ + Preempt a running request and move it back to the waiting queue. 
+ + This method removes the specified request from the running queue (or the + last running request if none is specified), updates its status and statistics, + and moves it back to the waiting queue. Optionally records a preemption event + if logging is enabled. + + Args: + scheduled_timestamp: Optional timestamp for logging the preemption event. + preempted_req: Specific request to preempt. If None, preempt the last + request in the running queue. + + Returns: + The request that was preempted and returned to the waiting queue. + """ + raise NotImplementedError + @abstractmethod def update_from_output( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4323141c435b..658e98eea75f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -37,7 +37,11 @@ ) from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs +from vllm.v1.engine import ( + EngineCoreEventType, + EngineCoreOutput, + EngineCoreOutputs, +) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -712,6 +716,31 @@ def schedule(self) -> SchedulerOutput: self._update_after_schedule(scheduler_output) return scheduler_output + def preempt_request( + self, + scheduled_timestamp: float | None = None, + preempted_req: Request | None = None, + ) -> Request: + # Preempt a running request and move it back to the waiting queue. + if preempted_req is not None: + self.running.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + + return preempted_req + def _update_after_schedule( self, scheduler_output: SchedulerOutput, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 3f621d77c024..882b3a97cede 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -187,6 +187,7 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b"\x03" # Sentinel used within EngineCoreProc. EXECUTOR_FAILED = b"\x04" + PAUSE = b"\x05" class ReconfigureDistributedRequest(msgspec.Struct): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c160c7cbcab4..4b4abc8f23fd 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -779,6 +779,16 @@ async def scale_elastic_ep( custom_stat_loggers=None, ) + async def handle_fault( + self, instruction: str, timeout: int = 300, **kwargs + ) -> bool: + """send fault tolerance instruction to the engine""" + return await self.engine_core.handle_fault(instruction, timeout, **kwargs) + + async def get_fault_info(self): + """report exception in engine core""" + return await self.engine_core.fault_reporter() + @property def is_running(self) -> bool: # Is None before the loop is started. 
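
Before moving on to the EngineCore changes, here is a rough end-to-end picture of the client path added so far: the FastAPI routes forward `/fault_tolerance/*` requests to `AsyncLLM.handle_fault()` / `get_fault_info()`, which delegate to the engine core client. The sketch below is not part of this diff; the host/port are assumptions, and the `soft_pause` parameter mirrors the value used in the sentinel tests above.

```python
import requests

BASE = "http://localhost:8000"  # assumed address of a server started with
                                # --enable-fault-tolerance

# Per-engine status, e.g. {"0": "Healthy", "1": "Unhealthy"}.
print(requests.get(f"{BASE}/fault_tolerance/status").json())

# Ask the serving engine to pause. The server validates that the instruction
# is one of "pause"/"retry" and that the timeout is a positive integer before
# calling AsyncLLM.handle_fault(); extra params are forwarded as **kwargs.
resp = requests.post(
    f"{BASE}/fault_tolerance/apply",
    json={
        "fault_tolerance_instruction": "pause",
        "fault_tolerance_timeout": 30,
        "fault_tolerance_params": {"soft_pause": True},
    },
)
print(resp.status_code, resp.json())
```

On failure the endpoint shuts the engine down and responds with an error status, so callers should treat any non-200 response as fatal for the serving process.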
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d49eb752d56a..8f3c6cef1a14 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json import os import queue import signal import threading import time +import traceback from collections import deque from collections.abc import Callable, Generator from concurrent.futures import Future @@ -31,7 +33,7 @@ maybe_attach_gc_debug_callback, ) from vllm.utils.hashing import get_hash_fn_by_name -from vllm.utils.network_utils import make_zmq_socket +from vllm.utils.network_utils import make_zmq_socket, recv_router_dealer_message from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.core.kv_cache_utils import ( BlockHash, @@ -51,17 +53,26 @@ UtilityOutput, UtilityResult, ) +from vllm.v1.engine.exceptions import EngineLoopPausedError, FaultInfo from vllm.v1.engine.utils import ( EngineHandshakeMetadata, EngineZmqAddresses, + broadcast_instruction, get_device_indices, + wait_for_instruction_result, ) from vllm.v1.executor import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.serial_utils import ( + MsgpackDecoder, + MsgpackEncoder, + deserialize_method_call, + run_method, + serialize_method_call, +) from vllm.v1.structured_output import StructuredOutputManager from vllm.version import __version__ as VLLM_VERSION @@ -73,6 +84,342 @@ _R = TypeVar("_R") # Return type for collective_rpc +class EngineCoreSentinel(threading.Thread): + """ + EngineCoreSentinel monitors a single EngineCore instance, responsible for: + 1. Receiving fault signals (exceptions raised in EngineCore busy loop) + 2. Receiving and executing commands from ClientSentinel + 3. 
Reporting execution results or faults back to the ClientSentinel + """ + + def __init__( + self, + engine_index: int, + fault_signal_q: queue.Queue, + cmd_q: queue.Queue, + busy_loop_active: threading.Event, + engine_input_q: queue.Queue, + client_cmd_addr: str, + worker_cmd_addr: str, + fault_report_addr: str, + sentinel_identity: bytes, + tp_size: int, + pp_size: int, + dp_size: int, + ): + super().__init__(daemon=True) + self.engine_index = engine_index + self.fault_signal_q = fault_signal_q + self.cmd_q = cmd_q + self.busy_loop_active = busy_loop_active + self.engine_input_q = engine_input_q + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dp_size + + self.ctx = zmq.Context() + # Client <-> EngineCoreSentinel sockets + self.fault_report_socket = make_zmq_socket( + self.ctx, + fault_report_addr, + zmq.DEALER, + bind=False, + identity=sentinel_identity, + ) + + self.client_cmd_socket = make_zmq_socket( + self.ctx, + client_cmd_addr, + zmq.DEALER, + bind=False, + identity=sentinel_identity, + ) + # EngineCoreSentinel <-> WorkerSentinel sockets + self.worker_cmd_socket = make_zmq_socket( + self.ctx, worker_cmd_addr, zmq.ROUTER, bind=True + ) + self.poller = zmq.Poller() + self.communicator_aborted = False + self.engine_running = True + self.engine_core_sentinel_dead = False + self.logger = self._make_engine_core_sentinel_logger() + + def _make_engine_core_sentinel_logger(self): + prefix = f"[EngineCoreSentinel_{self.engine_index}] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + def run(self) -> None: + """ + Run the main monitoring loop for EngineCoreSentinel. + """ + poll_timeout_ms = 100 + while not self.engine_core_sentinel_dead: + # Check for engine fault signals + try: + engine_exception = self.fault_signal_q.get_nowait() + if isinstance(engine_exception, EngineLoopPausedError): + # The busy loop stopped due to another critical exception, + # put it back + self.logger("Engine paused", level="info") + else: + self.logger( + "Detected exception %s: %s\n Call Stack:\n%s", + type(engine_exception).__name__, + engine_exception, + "".join(traceback.format_tb(engine_exception.__traceback__)), + level="error", + ) + self._report_client_exception(engine_exception) + self.engine_running = False + except queue.Empty: + pass + try: + has_msg, _, cmd_str = recv_router_dealer_message( + self.client_cmd_socket, + use_poller=True, + poll_timeout=poll_timeout_ms, + ) + except zmq.ZMQError: + self.logger( + "Socket closed, terminating EngineCoreSentinel", level="info" + ) + break + + if has_msg: + self.logger("Received cmd: %s", cmd_str, level="info") + self._execute_cmd(cmd_str) + + def _stop_worker_execution(self, soft_pause: bool, timeout: int = 2) -> bool: + if soft_pause: + pause_method = "pause_by_signal" + else: + pause_method = "pause_by_abort_communicators" + self.communicator_aborted = True + + success = self._execute_worker_method( + pause_method, timeout=timeout, worker_timeout=timeout + ) + return success + + def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> bool: + identities = set() + for tp_rank in range(self.tp_size): + for pp_rank in range(self.pp_size): + identity = f"{pp_rank}_{tp_rank}".encode() + identities.add(identity) + + method_uuid = broadcast_instruction( + self.worker_cmd_socket, identities, method_name, **kwargs + ) + + all_success = True + worker_responses = 
wait_for_instruction_result( + self.worker_cmd_socket, identities, method_name, timeout, method_uuid + ) + for identity in identities: + response = worker_responses.get(identity) + if response is None or not response.get("success", False): + all_success = False + + return all_success + + def _report_client_exception(self, exception: Exception) -> None: + msg = FaultInfo.from_exception(exception, self.engine_index).serialize() + msg_bytes = msg.encode("utf-8") + self.fault_report_socket.send_multipart([b"", msg_bytes]) + + def _execute_cmd(self, cmd_str): + """ + Execute a command received from ClientSentinel. + """ + method, method_uuid, method_params = deserialize_method_call(cmd_str) + self.logger("Executing command: %s", method, level="info") + try: + success = run_method(self, method, args=(), kwargs=method_params) + self.logger("Command (%s) succeeded: %s", method, success, level="info") + reason = None + except Exception as e: + self.logger( + "Error executing method %s: %s, %s", + method, + type(e).__name__, + e, + level="error", + ) + success = False + reason = f"{type(e).__name__}: {e}" + + self._send_execution_result(success, method_uuid, reason) + + def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool: + """ + Pause the busy loop safely. + Args: + timeout:wait for the busy loop to acknowledge the pause signal + soft_pause: if True, perform a soft pause using a flag; otherwise + abort the communicator + """ + self.logger("Start pausing EngineCore", level="info") + start_time = time.monotonic() + if self.engine_running: + # Clear the flag to signal busy loop should pause + self.busy_loop_active.clear() + # Put a sentinel (empty request) to unblock the busy loop + # if it's blocked on input_queue.get() + self.engine_input_q.put((EngineCoreRequestType.PAUSE, None)) + success = self._stop_worker_execution( + soft_pause=soft_pause, + timeout=timeout, + ) + elapsed = time.monotonic() - start_time + if success: + remaining_timeout = max(0, timeout - elapsed) + try: + # Wait for engine to acknowledge the pause via fault_signal_q + exception = self.fault_signal_q.get(timeout=remaining_timeout) + self.fault_signal_q.put(exception) + success = True + self.engine_running = False + except queue.Empty: + # Timeout waiting for pause acknowledgment + success = False + else: + # already paused + success = True + if not soft_pause: + # abort the communicators + success = self._stop_worker_execution(soft_pause=False, timeout=timeout) + return success + + def retry(self, new_stateless_dp_group_port: int, timeout: int = 1): + """ + Handle the retry instruction from the ClientSentinel. + This instruction tells the EngineCore to continue its busy loop + after being suspended due to an exception. + """ + if self.engine_running: + return True + + start_time = time.monotonic() + + success = self._execute_worker_method("restore_worker", timeout=timeout) + if not success: + return success + + if self.dp_size > 1: + # If the Gloo communication times out + # the data parallel group (dp_group) needs to be reinitialized + command = "reinit_dp_group_on_fault_tolerance" + self.cmd_q.put( + serialize_method_call( + command, new_stateless_dp_group_port=new_stateless_dp_group_port + ) + ) + else: + self.cmd_q.put(None) + + # Ensure busy loop has been recovered. 
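+ # The paused busy-loop wrapper consumes the cmd_q entry above, runs the + # requested recovery method (if any), and re-sets busy_loop_active once the + # busy loop restarts, which is what the wait() below observes.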
+ elapsed = time.monotonic() - start_time + remaining_timeout = max(0, timeout - elapsed) + success = self.busy_loop_active.wait(timeout=remaining_timeout) + self.engine_running = success + assert self.cmd_q.empty(), "cmd_q must be empty after execution" + return success + + def _send_execution_result( + self, success: bool, method_uuid: str, reason: str | None + ): + msg = { + "engine_index": self.engine_index, + "success": success, + "method_uuid": method_uuid, + } + if not success and reason is not None: + msg["reason"] = reason + msg_bytes = json.dumps(msg).encode("utf-8") + self.client_cmd_socket.send_multipart([b"", msg_bytes]) + + def shutdown(self): + if self.fault_report_socket is not None: + self.fault_report_socket.close() + if self.client_cmd_socket is not None: + self.client_cmd_socket.close() + if self.worker_cmd_socket is not None: + self.worker_cmd_socket.close() + if self.ctx is not None: + self.ctx.term() + self.engine_core_sentinel_dead = True + + +def busy_loop_wrapper(busy_loop_func): + """ + Wrap the busy loop function to perform fault tolerance. + """ + + def run_with_fault_tolerance(self): + while True: + try: + if self.enable_fault_tolerance: + self.busy_loop_active.set() + busy_loop_func(self) + except SystemExit: + raise + except Exception as original_exc: + if self.enable_fault_tolerance: + self.busy_loop_active.clear() + self.fault_signal_q.put(original_exc) + logger.warning( + "[BusyLoopWrapper] EngineCore busy loop raised an exception. " + "Suspended and waiting for fault tolerance " + "instructions." + ) + + # Put running requests into waiting list. + for req in list(self.scheduler.running): + self.scheduler.preempt_request(preempted_req=req) + self.scheduler.prev_step_scheduled_req_ids.clear() + + try: + # Block until recovery command received + cmd_str = self.cmd_q.get(timeout=self.engine_recovery_timeout) + logger.debug( + "[BusyLoopWrapper] Received fault tolerance command: %s", + cmd_str, + ) + if cmd_str is not None: + method, _, params = deserialize_method_call(cmd_str) + run_method(self, method, args=(), kwargs=params) + # recovery succeeded; restart the busy loop + continue + except queue.Empty: + # No handling instruction received within predefined + # timeout period. + logger.error( + "[BusyLoopWrapper] Fault tolerance instruction not received" + " within timeout. Proceeding with default exception " + "handling." + ) + except Exception as cmd_exc: + raise RuntimeError( + "Fault tolerance execution failed." + ) from cmd_exc + + # Fault tolerance not enabled OR no instruction received + # before timeout. Re-raise the original exception + # for upper level handling. + raise original_exc + + return run_with_fault_tolerance + + class EngineCore: """Inner loop of vLLM's Engine.""" @@ -567,10 +914,6 @@ def __init__( ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() - executor_fail_callback = lambda: self.input_queue.put_nowait( - (EngineCoreRequestType.EXECUTOR_FAILED, b"") - ) - self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False @@ -603,6 +946,50 @@ def __init__( self._init_data_parallel(vllm_config) + # Initialize fault tolerance settings. + ft_config = vllm_config.fault_tolerance_config + self.enable_fault_tolerance = ft_config.enable_fault_tolerance + if self.enable_fault_tolerance: + # Track whether the busy loop is currently active. 
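+ # Cleared by EngineCoreSentinel.pause() (and by the busy-loop wrapper on an + # unhandled exception) to stop the loop; set again each time the wrapped + # busy loop (re)starts.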
+ self.busy_loop_active = threading.Event() + self.fault_signal_q: queue.Queue[Exception] = queue.Queue() + self.cmd_q: queue.Queue[str | None] = queue.Queue(maxsize=1) + self.engine_recovery_timeout = ft_config.engine_recovery_timeout + engine_core_sentinel_ids = addresses.engine_core_sentinel_identities + assert engine_core_sentinel_ids is not None + assert addresses.fault_report_addr is not None + assert addresses.client_cmd_addr is not None + assert addresses.engine_core_cmd_addrs is not None + engine_core_cmd_addr = addresses.engine_core_cmd_addrs[ + vllm_config.parallel_config.data_parallel_rank + ] + self.engine_core_sentinel = EngineCoreSentinel( + engine_index=self.engine_index, + fault_signal_q=self.fault_signal_q, + cmd_q=self.cmd_q, + busy_loop_active=self.busy_loop_active, + engine_input_q=self.input_queue, + fault_report_addr=addresses.fault_report_addr, + client_cmd_addr=addresses.client_cmd_addr, + worker_cmd_addr=engine_core_cmd_addr, + sentinel_identity=engine_core_sentinel_ids[self.engine_index], + tp_size=vllm_config.parallel_config.tensor_parallel_size, + pp_size=vllm_config.parallel_config.pipeline_parallel_size, + dp_size=vllm_config.parallel_config.data_parallel_size, + ) + self.engine_core_sentinel.start() + vllm_config.fault_tolerance_config.engine_core_cmd_addr = ( + engine_core_cmd_addr + ) + # Do not shut down the engine immediately upon failure. + executor_fail_callback = lambda: self.fault_signal_q.put( + RuntimeError(f"Executor on EngineCore {self.engine_index} failed.") + ) + else: + executor_fail_callback = lambda: self.input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b"") + ) + super().__init__( vllm_config, executor_class, log_stats, executor_fail_callback ) @@ -851,16 +1238,23 @@ def signal_handler(signum, frame): def _init_data_parallel(self, vllm_config: VllmConfig): pass + @busy_loop_wrapper def run_busy_loop(self): """Core busy loop of the EngineCore.""" # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. + self._check_busy_loop_active() self._process_input_queue() # 2) Step the engine core and return the outputs. + self._check_busy_loop_active() self._process_engine_step() + def _check_busy_loop_active(self): + if self.enable_fault_tolerance and not self.busy_loop_active.is_set(): + raise EngineLoopPausedError("Engine busy loop is paused.") + def _process_input_queue(self): """Exits when an engine step needs to be performed.""" @@ -907,6 +1301,8 @@ def _handle_client_request( self.add_request(req, request_wave) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) + elif request_type == EngineCoreRequestType.PAUSE: + self._check_busy_loop_active() elif request_type == EngineCoreRequestType.UTILITY: client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) @@ -1099,6 +1495,11 @@ def process_output_sockets( # Limit the number of buffers to reuse. 
reuse_buffers.append(buffer) + def shutdown(self): + super().shutdown() + if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: + self.engine_core_sentinel.shutdown() + class DPEngineCoreProc(EngineCoreProc): """ZMQ-wrapper for running EngineCore in background process @@ -1153,7 +1554,9 @@ def _init_data_parallel(self, vllm_config: VllmConfig): ) self.dp_rank = dp_rank - self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group( + vllm_config.fault_tolerance_config.gloo_comm_timeout, + ) def shutdown(self): super().shutdown() @@ -1201,15 +1604,18 @@ def _maybe_publish_request_counts(self): ) self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats))) + @busy_loop_wrapper def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. + self._check_busy_loop_active() self._process_input_queue() # 2) Step the engine core. + self._check_busy_loop_active() executed = self._process_engine_step() self._maybe_publish_request_counts() @@ -1221,9 +1627,11 @@ def run_busy_loop(self): # We are in a running state and so must execute a dummy pass # if the model didn't execute any ready requests. + self._check_busy_loop_active() self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. + self._check_busy_loop_active() self.engines_running = self._has_global_unfinished_reqs( local_unfinished_reqs ) @@ -1256,6 +1664,14 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + def reinit_dp_group_on_fault_tolerance(self, new_stateless_dp_group_port: int): + stateless_destroy_torch_distributed_process_group(self.dp_group) + self.dp_group = self.vllm_config.parallel_config.stateless_init_dp_group( + self.vllm_config.fault_tolerance_config.gloo_comm_timeout, + new_stateless_dp_group_port, + ) + self.step_counter = 0 + def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest ) -> None: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9d..91a597b3a8a9 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,9 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import contextlib +import json import multiprocessing import queue import sys +import threading +import time import uuid import weakref from abc import ABC, abstractmethod @@ -16,19 +19,23 @@ from typing import Any, TypeAlias, TypeVar import msgspec.msgpack +import regex as re import zmq import zmq.asyncio +from ray.util.state import get_actor from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask from vllm.utils.async_utils import in_loop +from vllm.utils.collection_utils import ThreadSafeDict from vllm.utils.network_utils import ( close_sockets, get_open_port, get_open_zmq_inproc_path, make_zmq_socket, + recv_router_dealer_message, ) from vllm.v1.engine import ( EngineCoreOutputs, @@ -40,14 +47,15 @@ ) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc -from vllm.v1.engine.exceptions import EngineDeadError +from vllm.v1.engine.exceptions import EngineDeadError, FaultInfo from vllm.v1.engine.utils 
import ( CoreEngineActorManager, CoreEngineProcManager, + FaultHandler, launch_core_engines, ) from vllm.v1.executor import Executor -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr, run_method logger = init_logger(__name__) @@ -249,6 +257,12 @@ async def collective_rpc_async( ) -> list[_R]: raise NotImplementedError + async def handle_fault(self, instruction: str, timeout: int) -> bool: + raise NotImplementedError + + async def fault_reporter(self): + raise NotImplementedError + class InprocClient(EngineCoreClient): """ @@ -332,6 +346,131 @@ def dp_engines_running(self) -> bool: return False +class ClientSentinel: + def __init__( + self, + fault_receiver_addr: str, + cmd_addr: str, + engine_registry: dict[int, bytes], + engine_exception_q: queue.Queue[FaultInfo], + fault_pub_addr: str, + engine_status_dict: ThreadSafeDict[int, str], + ): + self.is_faulted = threading.Event() + self.engine_registry = engine_registry + self.zmq_ctx = zmq.Context() + self.fault_receiver_socket = make_zmq_socket( + ctx=self.zmq_ctx, + path=fault_receiver_addr, + socket_type=zmq.ROUTER, + bind=True, + ) + self.cmd_socket = make_zmq_socket( + ctx=self.zmq_ctx, path=cmd_addr, socket_type=zmq.ROUTER, bind=True + ) + + self.fault_pub_socket = make_zmq_socket( + ctx=self.zmq_ctx, path=fault_pub_addr, socket_type=zmq.PUB, bind=True + ) + + self.engine_exception_q: queue.Queue[FaultInfo] = engine_exception_q + + self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict + + self.fault_handler = FaultHandler( + self.cmd_socket, + self.engine_registry, + self.engine_exception_q, + self.engine_status_dict, + ) + + self.logger = self._make_client_sentinel_logger() + + self.client_sentinel_dead = False + Thread( + target=self.fault_receiver, daemon=True, name="EngineCoreFaultReceiver" + ).start() + + def _make_client_sentinel_logger(self): + prefix = "[client_sentinel] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: + """ + Executes fault tolerance measures based on the fault tolerance instructions + received from the api_server. + + This method processes the fault tolerance commands/instructions passed by the + api_server, then implements corresponding fault tolerance strategies or actions + to handle system anomalies, ensuring stable operation or graceful degradation + of the relevant components. + """ + result = await run_method( + self.fault_handler, + "handle_fault", + args=(instruction, timeout), + kwargs=kwargs, + ) + if result: + self.is_faulted.clear() + return result + + def fault_receiver(self): + """ + Continuously listens for exception/error information sent from the engine_core. + + This method maintains a persistent listening state to capture and process + fault-related data, exceptions, or error notifications emitted by the + engine_core component. It is designed to run continuously to ensure no critical + error information from the engine core is missed. 
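+ Each received fault is also re-published on the fault PUB socket as + '<topic>|<json engine-status map>', e.g. 'vllm_fault|{"1": "Dead"}'.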
+ """ + while not self.client_sentinel_dead: + try: + _, sender_identity, message = recv_router_dealer_message( + self.fault_receiver_socket + ) + assert message is not None, ( + "message should not be None at fault tolerance scenario" + ) + + fault_info = FaultInfo.from_json(message) + self.engine_exception_q.put_nowait(fault_info) + engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy" + self.engine_status_dict[int(fault_info.engine_id)] = engine_status + self.fault_pub_socket.send_string( + f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}" + ) + self.is_faulted.set() + # Pause healthy engines on fault. + # Pause can be invoked again during fault-tolerance handling, + # so it's unnecessary to track whether all engines are currently + # paused. + self.fault_handler.submit_fault("pause", 5, soft_pause=False) + except zmq.ZMQError: + # Socket was closed during polling, exit loop. + self.logger( + "Fault receiver socket closed, stopping thread.", level="info" + ) + break + self.logger("Fault receiver thread has stopped.") + + def shutdown_sentinel(self): + self.client_sentinel_dead = True + self.fault_receiver_socket.close() + self.cmd_socket.close() + self.fault_pub_socket.close() + self.zmq_ctx.term() + self.logger("ClientSentinel is closed.", level="info") + + @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -350,6 +489,7 @@ class BackgroundResources: output_queue_task: asyncio.Task | None = None stats_update_task: asyncio.Task | None = None shutdown_path: str | None = None + client_sentinel: ClientSentinel | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. @@ -363,6 +503,8 @@ def __call__(self): self.engine_manager.close() if self.coordinator is not None: self.coordinator.close() + if self.client_sentinel is not None: + self.client_sentinel.shutdown_sentinel() if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. @@ -454,6 +596,7 @@ def __init__( self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) success = False + try: # State used for data parallel. 
self.engines_running = False @@ -536,8 +679,39 @@ def __init__( self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() # Start monitoring engine core processes for unexpected failures - self.start_engine_core_monitor() + if self.vllm_config.parallel_config.data_parallel_backend == "ray": + self.start_engine_core_actor_monitor() + else: + self.start_engine_core_monitor() + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + self.engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() + assert addresses.fault_report_addr is not None, ( + "addresses.fault_report_addr should not be None at fault tolerance" + " scenario" + ) + assert addresses.client_cmd_addr is not None, ( + "addresses.client_cmd_addr should not be None at fault tolerance" + " scenario" + ) + self.engine_registry = addresses.engine_core_sentinel_identities + assert self.engine_registry is not None + assert addresses.fault_pub_socket_addr is not None, ( + "addresses.fault_pub_socket_addr should not be None at" + "fault tolerance scenario" + ) + self.engine_status_dict: ThreadSafeDict[int, str] = ThreadSafeDict() + for engine_id in range(vllm_config.parallel_config.data_parallel_size): + self.engine_status_dict[engine_id] = "Healthy" + self.client_sentinel = ClientSentinel( + addresses.fault_report_addr, + addresses.client_cmd_addr, + self.engine_registry, + self.engine_exception_q, + addresses.fault_pub_socket_addr, + self.engine_status_dict, + ) + self.resources.client_sentinel = self.client_sentinel success = True finally: if not success: @@ -568,6 +742,51 @@ def free_pending_messages(self): def dp_engines_running(self) -> bool: return self.engines_running + def start_engine_core_actor_monitor(self): + engine_manager = self.resources.engine_manager + if ( + not isinstance(engine_manager, CoreEngineActorManager) + or not self.vllm_config.fault_tolerance_config.enable_fault_tolerance + ): + return + + def monitor_actors(): + all_actors = ( + engine_manager.local_engine_actors + engine_manager.remote_engine_actors + ) + if not all_actors: + return + while True: + for actor in all_actors[:]: + actor_id = actor._actor_id.hex() + if actor in engine_manager.local_engine_actors: + actor_index = engine_manager.local_engine_actors.index(actor) + elif actor in engine_manager.remote_engine_actors: + actor_index = engine_manager.remote_engine_actors.index( + actor + ) + len(engine_manager.local_engine_actors) + else: + logger.error("Unknown actor (ID: %s)", actor_id) + continue + + actor_info = get_actor(actor_id) + if actor_info.state == "DEAD": + fault_info = FaultInfo( + type="engine_actor dead", + message=str(actor_info.death_cause), + engine_id=str(actor_index), + additional_info=None, + ) + all_actors.remove(actor) + engine_manager.engine_down_socket.send_multipart( + [b"", fault_info.serialize().encode("utf-8")] + ) + + time.sleep(3) + # Implements the "check once every 3 seconds" frequency control + + Thread(target=monitor_actors, daemon=True, name="MPClientEngineMonitor").start() + def start_engine_core_monitor(self): """Start a monitor thread for engine core processes.""" engine_manager = self.resources.engine_manager @@ -587,19 +806,54 @@ def start_engine_core_monitor(self): # callback to inform the engine. 
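+ # With fault tolerance enabled, dead engine-core procs are reported to the + # ClientSentinel via engine_manager.engine_down_socket rather than shutting + # the whole client down.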
def monitor_engine_cores(): sentinels = [proc.sentinel for proc in engine_processes] - died = multiprocessing.connection.wait(sentinels) + _self = self_ref() - if not _self or _self.resources.engine_dead: - return - _self.resources.engine_dead = True - proc_name = next( - proc.name for proc in engine_processes if proc.sentinel == died[0] - ) - logger.error( - "Engine core proc %s died unexpectedly, shutting down client.", - proc_name, - ) - _self.shutdown() + if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: + while sentinels: + died = multiprocessing.connection.wait(sentinels) + for sentinel in died: + died_proc = next( + proc + for proc in engine_processes + if proc.sentinel == sentinel + ) + + match = re.match(r"EngineCore_DP(\d+)", died_proc.name) + engine_rank = match.group(1) + + fault_info = FaultInfo( + type="engine_core dead", + message=f"Engine core proc {died_proc.name} " + f"(PID: {died_proc.pid}) died unexpectedly.", + engine_id=engine_rank, + additional_info=None, + ) + + engine_manager.engine_down_socket.send_multipart( + [b"", fault_info.serialize().encode("utf-8")] + ) + + sentinels.remove(sentinel) + + logger.error( + "Engine core proc %s died unexpectedly", + died_proc.name, + ) + else: + died = multiprocessing.connection.wait(sentinels) + _self = self_ref() + if not _self or _self.resources.engine_dead: + return + proc_name = next( + proc.name for proc in engine_processes if proc.sentinel == died[0] + ) + logger.error( + "Engine core proc %s died unexpectedly, shutting down client.", + proc_name, + ) + if _self and _self.resources: + _self.resources.engine_dead = True + _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will # cause subsequent operations to raise EngineDeadError @@ -608,6 +862,13 @@ def monitor_engine_cores(): target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor" ).start() + async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: + """Handle a fault on this instance according to the given instruction.""" + return await self.client_sentinel.handle_fault(instruction, timeout, **kwargs) + + async def fault_reporter(self): + return self.engine_status_dict.to_dict() + def _process_utility_output( output: UtilityOutput, utility_results: dict[int, AnyFuture] diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py index d9f79a019e2d..4f24c7fa9080 100644 --- a/vllm/v1/engine/exceptions.py +++ b/vllm/v1/engine/exceptions.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import time +from dataclasses import dataclass + + class EngineGenerateError(Exception): """Raised when a AsyncLLM.generate() fails. Recoverable.""" @@ -16,3 +21,69 @@ def __init__(self, *args, suppress_context: bool = False, **kwargs): # Make stack trace clearer when using with LLMEngine by # silencing irrelevant ZMQError. self.__suppress_context__ = suppress_context + + +class EngineLoopPausedError(Exception): + """ + Raised when the EngineCore loop is temporarily paused on purpose, + e.g., to handle fault-tolerance. + """ + + pass + + +@dataclass +class FaultInfo: + type: str + message: str + engine_id: str + timestamp: str | None = None + additional_info: dict | None = None + + def __post_init__(self): + # Default to the current local time if no timestamp was provided.
+ + local_time = time.localtime(time.time()) + if self.timestamp is None: + self.timestamp = time.strftime("%H:%M:%S", local_time) + + @classmethod + def from_exception( + cls, + exception: Exception, + engine_id: str | int, + additional_info: dict | None = None, + ) -> "FaultInfo": + """Create FaultInfo from an exception.""" + return cls( + type=type(exception).__name__, + message=str(exception), + engine_id=str(engine_id), + additional_info=additional_info or {}, + ) + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "type": self.type, + "message": self.message, + "timestamp": self.timestamp, + "engine_id": self.engine_id, + "additional_info": self.additional_info, + } + + def serialize(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> "FaultInfo": + """Create FaultInfo from JSON string.""" + data = json.loads(json_str) + return cls( + type=data["type"], + message=data["message"], + timestamp=data["timestamp"], + engine_id=data["engine_id"], + additional_info=data["additional_info"], + ) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index d65cad7af03d..907cdb7c9bbf 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import asyncio import contextlib +import json +import multiprocessing import os +import queue +import time +import uuid import weakref from collections.abc import Callable, Iterator from dataclasses import dataclass @@ -13,6 +18,7 @@ from unittest.mock import patch import msgspec +import regex as re import zmq from vllm import envs @@ -20,10 +26,19 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy -from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx +from vllm.utils.collection_utils import ThreadSafeDict +from vllm.utils.network_utils import ( + get_open_port, + get_open_zmq_ipc_path, + make_zmq_socket, + recv_router_dealer_message, + zmq_socket_ctx, +) from vllm.utils.system_utils import get_mp_context from vllm.v1.engine.coordinator import DPCoordinator +from vllm.v1.engine.exceptions import FaultInfo from vllm.v1.executor import Executor +from vllm.v1.serial_utils import run_method, serialize_method_call from vllm.v1.utils import get_engine_client_zmq_addr, shutdown if TYPE_CHECKING: @@ -56,6 +71,8 @@ class EngineZmqAddresses: inputs: list[str] # ZMQ output socket addresses for each front-end client (responses) outputs: list[str] + + engine_core_cmd_addrs: list[str] | None = None # ZMQ input socket address of DP coordinator if applicable coordinator_input: str | None = None # ZMQ output socket address of DP coordinator if applicable @@ -64,6 +81,14 @@ class EngineZmqAddresses: # Not used by engine, just relayed to front-end in handshake response. # Only required for external DP LB case. 
frontend_stats_publish_address: str | None = None + # ZMQ fault_report socket address of client sentinel + fault_report_addr: str | None = None + # ZMQ client_cmd socket address of client sentinel + client_cmd_addr: str | None = None + # identities of engine_core_sentinel + engine_core_sentinel_identities: dict[int, bytes] | None = None + # ZMQ fault_pub_socket address of client sentinel + fault_pub_socket_addr: str | None = None @dataclass @@ -105,7 +130,23 @@ def __init__( "executor_class": executor_class, "log_stats": log_stats, } - + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + zmq_ctx = zmq.Context() + identity = generate_identity_group( + "core_engine_proc_manager", "client_sentinel", "report", 1 + )[0] + zmq_addr = get_engine_client_zmq_addr( + local_only=False, + host=vllm_config.parallel_config.data_parallel_master_ip, + port=vllm_config.fault_tolerance_config.internal_fault_report_port, + ) + self.engine_down_socket = make_zmq_socket( + ctx=zmq_ctx, + path=zmq_addr, + socket_type=zmq.DEALER, + bind=False, + identity=identity, + ) if client_handshake_address: common_kwargs["client_handshake_address"] = client_handshake_address @@ -131,6 +172,8 @@ def __init__( self._finalizer = weakref.finalize(self, shutdown, self.processes) + self.vllm_config = vllm_config + data_parallel = vllm_config.parallel_config.data_parallel_size > 1 try: for proc, local_dp_rank in zip(self.processes, local_dp_ranks): @@ -154,12 +197,53 @@ def __init__( if self.finished_procs(): self.close() + def _report_engine_dead(self, dead_message): + """Send engine dead message to ClientSentinel""" + try: + self.engine_down_socket.send_multipart( + [ + b"", # Empty frame separator + dead_message.encode("utf-8"), + ] + ) + logger.info("Sent message to ClientSentinel: %s", dead_message) + except Exception as e: + logger.error("Failed to send message: %s", e) + def close(self): """Shutdown all procs.""" self._finalizer() + def start_engine_core_monitor(self): + sentinels = [proc.sentinel for proc in self.processes] + while self.processes: + died = multiprocessing.connection.wait(sentinels) + for sentinel in died: + died_proc = next( + proc for proc in self.processes if proc.sentinel == sentinel + ) + + match = re.match(r"EngineCore_DP(\d+)", died_proc.name) + engine_rank = match.group(1) + + fault_info = FaultInfo( + type="engine_core dead", + message=f"Engine core proc {died_proc.name} " + f"(PID: {died_proc.pid}) died unexpectedly.", + engine_id=engine_rank, + additional_info=None, + ) + self.engine_down_socket.send_multipart( + [b"", fault_info.serialize().encode("utf-8")] + ) + if isinstance(sentinel, int) and sentinel in sentinels: + sentinels.remove(sentinel) + logger.error( + "Engine core proc %s died unexpectedly", + died_proc, + ) + def join_first(self): - """Wait for any process to exit.""" connection.wait(proc.sentinel for proc in self.processes) def sentinels(self) -> list: @@ -267,6 +351,24 @@ def __init__( local_engine_count = vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + zmq_ctx = zmq.Context() + zmq_addr = get_engine_client_zmq_addr( + local_only=False, + host=vllm_config.parallel_config.data_parallel_master_ip, + port=vllm_config.fault_tolerance_config.internal_fault_report_port, + ) + identity = generate_identity_group( + "core_engine_actor_manager", "client_sentinel", "report", 1 + )[0] + self.engine_down_socket = make_zmq_socket( + ctx=zmq_ctx, + path=zmq_addr, + socket_type=zmq.DEALER, + bind=False, +
identity=identity, + ) + if ray.is_initialized(): logger.info("Ray is already initialized. Skipping Ray initialization.") else: @@ -332,6 +434,7 @@ def __init__( local_dp_rank=local_index, ) ) + if local_client: self.local_engine_actors.append(actor) else: @@ -808,6 +911,33 @@ def launch_core_engines( ], ) + if vllm_config.fault_tolerance_config.enable_fault_tolerance is True: + addresses.engine_core_cmd_addrs = [ + get_engine_client_zmq_addr(client_local_only, host) for _ in range(dp_size) + ] + addresses.fault_report_addr = get_engine_client_zmq_addr( + local_only=False, + host=vllm_config.parallel_config.data_parallel_master_ip, + port=vllm_config.fault_tolerance_config.internal_fault_report_port, + ) + addresses.client_cmd_addr = get_engine_client_zmq_addr( + local_only=client_local_only, host=host + ) + identity_group = generate_identity_group( + peer1="client", + peer2="engine_core_sentinel", + use="report and cmd", + n=dp_size, + ) + addresses.engine_core_sentinel_identities = { + rank: identity for rank, identity in enumerate(identity_group) + } + addresses.fault_pub_socket_addr = get_engine_client_zmq_addr( + local_only=False, + host="0.0.0.0", + port=vllm_config.fault_tolerance_config.external_fault_notify_port, + ) + # Run the DP Coordinator process with rank 0 when in # online DP mode. run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0 @@ -1070,3 +1200,301 @@ def wait_for_engine_startup( "local" if local else "remote", eng_index, ) + + +def generate_unique_uuids(n: int) -> set[uuid.UUID]: + """Generate a set of unique UUID v4 objects. + + Generates a specified number of unique UUID (version 4) objects. + UUID v4 uses cryptographically strong random numbers, ensuring + an extremely low probability of collisions. + + Args: + n: The number of unique UUIDs to generate + + Returns: + A set containing 'n' unique UUID objects + """ + uuids: set[uuid.UUID] = set() + while len(uuids) < n: + # Generate a random UUID (version 4) and add to the set + uuids.add(uuid.uuid4()) + return uuids + + +def generate_identity_group(peer1, peer2, use, n): + """ + Generate n unique identities for ZMQ ROUTER nodes + + Format: peer1_peer2_use_random number + Return: list with identities in byte type as elements + """ + identitys = list() + uuids = generate_unique_uuids(n) + for id in uuids: + identity_str = f"{peer1}_{peer2}_{use}_{id}".encode() + identitys.append(identity_str) + return identitys + + +def broadcast_instruction( + cmd_socket, + target_identities: set[bytes] | list[bytes], + method_name: str, + method_uuid: str | None = None, + **kwargs, +) -> str: + """ + Broadcast an instruction message to multiple remote endpoints. + It serializes the specified method_name along with its parameters and + dispatches it to all target identities via the provided ZeroMQ socket. + """ + if method_uuid is None: + method_uuid = str(uuid.uuid4()) + + for identity in target_identities: + serialized_instruction = serialize_method_call( + method_name, method_uuid, **kwargs + ) + cmd_socket.send_multipart( + [identity, b"", serialized_instruction.encode("utf-8")] + ) + + return method_uuid + + +def wait_for_instruction_result( + cmd_socket: zmq.Socket, + target_identities: set[bytes] | list[bytes], + method_name: str, + timeout: int, + method_uuid: str, +) -> dict[bytes, dict]: + """ + Wait for acknowledgment or result messages from multiple endpoints. + This function listens for responses corresponding to a previously broadcasted + instruction, identified by the given `method_uuid`. 
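+ Responses carrying a non-matching `method_uuid` are discarded, and endpoints + that do not reply before the timeout are simply absent from the returned mapping.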
+ + Args: + cmd_socket: The socket used to receive responses. + target_identities: Identities that are expected to respond. + method_name: The name of the method_name (used for logging). + timeout: The maximum wait time (in seconds). + method_uuid: The unique identifier associated with the method_name. + + Notes: + - This function does not raise exceptions for timeouts or parsing errors. + Instead, it logs the issue and returns whatever responses have been collected. + """ + start = time.monotonic() + responses: dict[bytes, dict] = {} + + target_identities = set(target_identities) + + while target_identities: + remaining = timeout - (time.monotonic() - start) + if remaining <= 0: + logger.debug( + 'Timeout while waiting for responses of command "%s" ' + "from identities: %s", + method_name, + target_identities, + ) + # Return partial results collected so far + return responses + + try: + has_msg, identity, response = recv_router_dealer_message( + cmd_socket, + use_poller=True, + poll_timeout=int(remaining * 1000), + ) + + # Skip if no message was received during this polling period + if not has_msg: + continue + + assert identity is not None + assert response is not None + response_dict = json.loads(response) + recv_uuid = response_dict.get("method_uuid") + + # Ignore outdated or unrelated messages + if recv_uuid != method_uuid: + logger.debug( + "Discarding outdated response: expected method_uuid=%s, got %s", + method_uuid, + recv_uuid, + ) + continue + + # Record this engine's response + responses[identity] = response_dict + target_identities.discard(identity) + + except Exception as e: + logger.error("Error while processing engine response: %s", e) + # Return partial results even on exception to avoid data loss + return responses + + return responses + + +class FaultHandler: + def __init__( + self, + cmd_socket: zmq.Socket, + client_cmd_registry: dict[int, bytes], + engine_exception_q: queue.Queue[FaultInfo], + engine_status_dict: ThreadSafeDict[int, str], + ) -> None: + self.cmd_socket = cmd_socket + self.engine_exception_q = engine_exception_q + self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict + self.engine_identity_to_index: dict[bytes, int] = { + identity: i for i, identity in client_cmd_registry.items() + } + # ensure handle_fault is executed sequentially + self._task_queue: asyncio.Queue = asyncio.Queue() + self._loop = asyncio.get_event_loop() + self._dispatcher_task = self._loop.create_task(self._dispatcher()) + + self.logger = self._make_fault_handler_logger() + + def _make_fault_handler_logger(self): + prefix = "[FaultHandler] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + async def _dispatcher(self): + while True: + # each elements in the queue contains: + # (instruction, timeout, kwargs, future) + instruction, timeout, kwargs, fut = await self._task_queue.get() + try: + result = await self._handle_fault_internal( + instruction, timeout, **kwargs + ) + if fut: + fut.set_result(result) + except Exception as e: + if fut: + fut.set_exception(e) + + def retry(self, **kwargs): + if "Dead" in self.engine_status_dict.values(): + self.logger( + "Engine core is dead; retry won't work.", + level="warning", + ) + return False, set(), kwargs + + target_engines = set(self.engine_identity_to_index.keys()) + kwargs["new_stateless_dp_group_port"] = get_open_port() + return True, target_engines, kwargs + + def 
pause(self, **kwargs): + self.logger( + "Pause operation is best-effort only. Due to the complexity of " + "collective communications (e.g., timing dependencies and " + "synchronization barriers), pausing may not always succeed. If " + "the process remains unresponsive or collective operations " + "cannot be interrupted, consider shutting down and restarting " + "the instance.", + level="warning", + ) + + alive_engines = { + identity + for identity, index in self.engine_identity_to_index.items() + if self.engine_status_dict.get(index) != "Dead" + } + return True, alive_engines, kwargs + + async def _handle_fault_internal( + self, instruction: str, timeout: int, **kwargs + ) -> bool: + success, target_engines, kwargs = run_method(self, instruction, (), kwargs) + + if not success: + return False + + if timeout is not None: + kwargs["timeout"] = timeout + + method_uuid = broadcast_instruction( + self.cmd_socket, + target_engines, + instruction, + **kwargs, + ) + + engine_responses = wait_for_instruction_result( + self.cmd_socket, target_engines, instruction, timeout, method_uuid + ) + + # check the execution results + all_success = True + for engine_id in target_engines: + engine_index = self.engine_identity_to_index.get(engine_id, "?") + response = engine_responses.get(engine_id) + + if response is None: + self.logger( + "EngineCoreSentinel[%s] did not respond" + ' to command "%s" within timeout.', + engine_index, + instruction, + level="info", + ) + all_success = False + elif not response.get("success", False): + self.logger( + "EngineCoreSentinel[%s] failed to execute " + 'command "%s" (reason: %s)', + engine_index, + instruction, + response.get("reason", "unknown"), + level="error", + ) + all_success = False + + if instruction == "retry" and all_success: + for engine_index, _ in self.engine_status_dict.items(): + self.engine_status_dict[engine_index] = "Healthy" + while not self.engine_exception_q.empty(): + try: + self.engine_exception_q.get_nowait() + except queue.Empty: + break + + return all_success + + async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: + """ + Async interface for run_method, returns a Future that can be awaited. + This method **must be called from the event loop thread** where this + FaultHandler was created. + """ + fut = self._loop.create_future() + await self._task_queue.put((instruction, timeout, kwargs, fut)) + return await fut + + def submit_fault(self, instruction: str, timeout: int, **kwargs) -> None: + """ + thread-safe fire-and-forget submission of a fault handling task. + This method can be called from **any thread** + """ + + def _enqueue(): + fut = self._loop.create_future() + self._task_queue.put_nowait((instruction, timeout, kwargs, fut)) + + self._loop.call_soon_threadsafe(_enqueue) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 0a6806390451..1eead471df47 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -3,7 +3,9 @@ import dataclasses import importlib +import json import pickle +import uuid from collections.abc import Callable, Sequence from functools import partial from inspect import isclass @@ -452,6 +454,54 @@ def ext_hook(self, code: int, data: memoryview) -> Any: raise NotImplementedError(f"Extension type code {code} is not supported") +def deserialize_method_call(json_str: str) -> tuple[str, str, dict[str, Any]]: + """ + Deserialize an encoded method call. + + Args: + json_str (str): JSON string representing a serialized method call. 
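+ (e.g. '{"method": "pause", "method_uuid": "<uuid4>", "timeout": 5}')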
+ + Returns: + tuple[str, dict[str, Any]]: + - method (str): The method name. + - method_uuid (str): The UUID identifying the method call. + - params (dict[str, Any]): Additional method parameters. + """ + try: + payload = json.loads(json_str) + if not isinstance(payload, dict): + raise ValueError("Top-level JSON must be an object") + except Exception as e: + logger.error("Invalid JSON input: %s", e) + raise ValueError(f"Invalid JSON: {e}") from e + + try: + method = payload.pop("method") + method_uuid = payload.pop("method_uuid") + except KeyError as e: + logger.error( + "Missing required field: %s (payload=%s)", e.args[0], json_str[:200] + ) + raise ValueError(f"Missing required field: {e.args[0]}") from e + + # Remaining fields are treated as parameters + params = payload + + return method, method_uuid, params + + +def serialize_method_call( + method: str, method_uuid: str | None = None, **params: Any +) -> str: + """ + Serialize a method invocation into a JSON string. + """ + if method_uuid is None: + method_uuid = str(uuid.uuid4()) + payload = {"method": method, "method_uuid": method_uuid, **params} + return json.dumps(payload) + + def run_method( obj: Any, method: str | bytes | Callable, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0102ca4739ad..ae2ee4cdf809 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,6 +3,7 @@ import gc import itertools +import threading import time from collections import defaultdict from collections.abc import Iterator @@ -98,6 +99,7 @@ split_attn_metadata, ) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.engine.exceptions import EngineLoopPausedError from vllm.v1.kv_cache_interface import ( AttentionSpec, ChunkedLocalAttentionSpec, @@ -574,6 +576,12 @@ def __init__( self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.pause_event = threading.Event() + + def _check_pause_event(self): + if self.pause_event.is_set(): + raise EngineLoopPausedError("Worker is paused.") + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2614,6 +2622,7 @@ def _model_forward( Returns: Model output tensor """ + self._check_pause_event() return self.model( input_ids=input_ids, positions=positions, @@ -3848,6 +3857,7 @@ def _dummy_run( ubatch_slices=ubatch_slices, ), ): + self._check_pause_event() outputs = self.model( input_ids=input_ids, positions=positions, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 315f01b68499..d6fd3e7fd48a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -3,18 +3,28 @@ """A GPU worker class.""" import gc +import json import os +import threading +import time +import traceback +from collections.abc import Callable +from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait from contextlib import AbstractContextManager, nullcontext +from datetime import timedelta +from functools import partial from types import NoneType from typing import TYPE_CHECKING, Any import torch import torch.distributed import torch.nn as nn +import zmq import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import ( + cleanup_dist_env_and_memory, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce, @@ -26,6 +36,8 @@ has_kv_transfer_group, ) from 
vllm.distributed.parallel_state import ( + GroupCoordinator, + get_all_model_groups, get_pp_group, get_tp_group, ) @@ -40,6 +52,7 @@ from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, memory_profiling +from vllm.utils.network_utils import make_zmq_socket, recv_router_dealer_message from vllm.v1.core.sched.output import GrammarOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -48,6 +61,7 @@ DraftTokenIds, ModelRunnerOutput, ) +from vllm.v1.serial_utils import deserialize_method_call, run_method from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -60,6 +74,179 @@ from vllm.v1.core.sched.output import SchedulerOutput +class WorkerSentinel: + def __init__( + self, + vllm_config: VllmConfig, + pause_event: threading.Event, + init_distributed_env_callback: Callable, + clear_input_batch_callback: Callable, + device: torch.cuda.device, + ): + self.vllm_config = vllm_config + self.zmq_ctx = zmq.Context() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.tp_rank = get_tp_group().rank_in_group + self.pp_rank = get_pp_group().rank_in_group + self.init_distributed_env_callback = init_distributed_env_callback + self.clear_input_batch_callback = clear_input_batch_callback + self.device = device + identity = f"{self.pp_rank}_{self.tp_rank}".encode() + worker_cmd_addr = vllm_config.fault_tolerance_config.engine_core_cmd_addr + self.cmd_socket = make_zmq_socket( + ctx=self.zmq_ctx, + path=worker_cmd_addr, + socket_type=zmq.DEALER, + bind=False, + identity=identity, + ) + self.worker_sentinel_dead = False + self.pause_event = pause_event + self.communicator_aborted = False + self.logger = self._make_worker_logger() + threading.Thread( + target=self.run, daemon=True, name="WorkerSentinelCmdReceiver" + ).start() + + def _make_worker_logger(self): + prefix = f"[WorkerSentinel_dp{self.dp_rank}_pp{self.pp_rank}_tp{self.tp_rank}] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + def run(self): + """Run the message receiving loop and handle control commands""" + torch.cuda.set_device(self.device) + while not self.worker_sentinel_dead: + try: + # Use blocking receive - will wait until a message arrives + has_msg, _, cmd_str = recv_router_dealer_message(self.cmd_socket) + if has_msg: + assert cmd_str is not None + method, method_uuid, params = deserialize_method_call(cmd_str) + self.logger("Executing command: %s, %s", method, params) + + try: + success = run_method(self, method, args=(), kwargs=params) + except Exception as e: + self.logger( + "Error executing method %s: %s %s\n Call Stack:\n %s", + method, + type(e).__name__, + e, + "".join(traceback.format_tb(e.__traceback__)), + level="error", + ) + success = False + self._send_execution_result(success, method_uuid) + except zmq.ZMQError: + # Socket was closed, exit loop. 
+ self.logger("Command socket closed, stopping thread.", level="info") + break + self.logger("Worker sentinel thread has stopped.") + + def pause_by_signal(self): + self._set_device_communicator_status(False) + self.pause_event.set() + self.logger("Pause signal sent.") + return True + + def pause_by_abort_communicators(self, worker_timeout=5): + """ + Abort all NCCL communicators and process groups in parallel using a thread pool. + """ + if self.communicator_aborted: + return True + self._set_device_communicator_status(False) + torch.cuda.set_device(self.device) + model_groups = get_all_model_groups() + futures = [] + + def _abort_nccl_comm(group: GroupCoordinator): + if group.device_communicator is not None: + nccl_comm = group.device_communicator.pynccl_comm + nccl_comm.nccl_abort_comm() + + def _abort_process_group(group: GroupCoordinator): + device = torch.device("cuda") + backend = group.device_group._get_backend(device) + backend.abort() + + executor = ThreadPoolExecutor(max_workers=len(model_groups) * 2) + try: + for group in model_groups: + futures.append(executor.submit(_abort_nccl_comm, group)) + futures.append(executor.submit(_abort_process_group, group)) + + done, not_done = wait( + futures, timeout=worker_timeout, return_when=FIRST_EXCEPTION + ) + if not_done: + self.logger( + "%d abort calls did not finish in total %s seconds", + len(not_done), + worker_timeout, + level="warning", + ) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + exception_count = sum(1 for f in done if f.exception() is not None) + self.communicator_aborted = len(not_done) == 0 and exception_count == 0 + if self.communicator_aborted: + cleanup_dist_env_and_memory() + self.logger("Communicators are aborted.") + else: + self.logger( + "Communicator abort failed: %d NCCL comm abort calls timed out," + " %d tasks threw exceptions. This may leave NCCL communicators " + "or process groups in an inconsistent state. 
Subsequent " + "distributed operations could be unsafe.", + len(not_done), + exception_count, + level="error", + ) + return self.communicator_aborted + + def _set_device_communicator_status(self, active: bool): + model_groups = get_all_model_groups() + for group in model_groups: + if group.device_communicator is not None: + nccl_comm = group.device_communicator.pynccl_comm + nccl_comm.available = active + nccl_comm.disabled = not active + + def restore_worker(self): + if self.communicator_aborted: + torch.cuda.set_device(self.device) + with set_current_vllm_config(self.vllm_config): + self.init_distributed_env_callback() + self.communicator_aborted = False + torch.cuda.synchronize() + self.clear_input_batch_callback() + self.pause_event.clear() + return True + + def _send_execution_result(self, success: bool, method_uuid: str): + msg = { + "success": success, + "method_uuid": method_uuid, + } + msg_bytes = json.dumps(msg).encode("utf-8") + self.cmd_socket.send_multipart([b"", msg_bytes]) + + def shutdown(self): + self.worker_sentinel_dead = True + self.cmd_socket.close() + self.zmq_ctx.term() + + class Worker(WorkerBase): def __init__( self, @@ -77,6 +264,7 @@ def __init__( is_driver_worker=is_driver_worker, ) + self.worker_sentinel: WorkerSentinel | None = None if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils.import_utils import init_cached_hf_modules @@ -221,6 +409,7 @@ def init_device(self): # memory snapshot # This ensures NCCL buffers are allocated before we measure # available memory + start = time.time() init_worker_distributed_environment( self.vllm_config, self.rank, @@ -228,6 +417,10 @@ def init_device(self): self.local_rank, current_platform.dist_backend, ) + elapsed = time.time() - start + logger.info_once( + "init distributed environment took %.2f seconds", elapsed, scope="local" + ) # Set random seed. set_random_seed(self.model_config.seed) @@ -265,6 +458,30 @@ def init_device(self): # If usage stat is enabled, collect relevant info. report_usage_stats(self.vllm_config) + if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: + with set_current_vllm_config(self.vllm_config): + init_distributed_env_callback = partial( + init_worker_distributed_environment, + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + ) + + def clear_input_batch_callback(): + input_batch = self.model_runner.input_batch + cached_req_ids = input_batch.req_id_to_index.keys() + for req_id in list(cached_req_ids): + input_batch.remove_request(req_id) + + self.worker_sentinel = WorkerSentinel( + self.vllm_config, + self.model_runner.pause_event, + init_distributed_env_callback, + clear_input_batch_callback, + self.device, + ) + # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. 
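The command/acknowledgement round trip implemented by WorkerSentinel.run() and _send_execution_result() above can be pictured from the controller side with a minimal sketch. Everything below is illustrative only: the address and worker identity are placeholders, the JSON envelope is an assumption (the real wire format is whatever serialize/deserialize_method_call define), and the real sender is the engine-side guard rather than this snippet.

import json
import uuid

import zmq

# Placeholder address; the sentinel reads the real one from
# fault_tolerance_config.engine_core_cmd_addr.
CMD_ADDR = "tcp://127.0.0.1:5555"

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind(CMD_ADDR)

# The WorkerSentinel's DEALER socket connects with identity f"{pp_rank}_{tp_rank}".
worker_identity = b"0_0"
method_uuid = str(uuid.uuid4())
# Assumed command envelope: method name plus a correlation uuid (and any kwargs).
cmd = {"method": "pause_by_signal", "method_uuid": method_uuid}
router.send_multipart([worker_identity, b"", json.dumps(cmd).encode("utf-8")])

# The sentinel executes the method via run_method() and acks with
# {"success": <bool>, "method_uuid": <uuid>} over the same socket.
identity, _, payload = router.recv_multipart()
ack = json.loads(payload.decode("utf-8"))
assert ack["method_uuid"] == method_uuid

The empty delimiter frame mirrors the [b"", msg_bytes] framing the sentinel uses on its DEALER socket, so a ROUTER peer sees the usual [identity, b"", payload] envelope.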
def load_model(self) -> None: @@ -863,6 +1080,8 @@ def save_tensorized_model( def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() + if self.worker_sentinel is not None: + self.worker_sentinel.shutdown() def init_worker_distributed_environment( @@ -879,14 +1098,29 @@ def init_worker_distributed_environment( init_batch_invariance() set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + timeout = timedelta( + seconds=vllm_config.fault_tolerance_config.gloo_comm_timeout + ) + else: + timeout = None + init_distributed_environment( - parallel_config.world_size, rank, distributed_init_method, local_rank, backend + parallel_config.world_size, + rank, + distributed_init_method, + local_rank, + backend, + vllm_config.fault_tolerance_config.enable_fault_tolerance, + timeout, ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, parallel_config.decode_context_parallel_size, + vllm_config.fault_tolerance_config.enable_fault_tolerance, + timeout, ) # Init ec connector here before KV caches init
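To make the timeout plumbing above concrete, here is a minimal standalone sketch that applies a bounded gloo timeout the same way, but through plain torch.distributed rather than vLLM's init_distributed_environment / ensure_model_parallel_initialized wrappers. The flag and the 30-second value are stand-ins for fault_tolerance_config.enable_fault_tolerance and gloo_comm_timeout, chosen only for illustration.

import os
from datetime import timedelta

import torch.distributed as dist

enable_fault_tolerance = True  # stand-in for fault_tolerance_config.enable_fault_tolerance
gloo_comm_timeout = 30         # stand-in value, in seconds

# With fault tolerance on, use a bounded timeout; otherwise keep the backend default.
timeout = timedelta(seconds=gloo_comm_timeout) if enable_fault_tolerance else None

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
kwargs = {"timeout": timeout} if timeout is not None else {}
dist.init_process_group(backend="gloo", rank=0, world_size=1, **kwargs)
# ... collectives that exceed the timeout now raise instead of blocking forever ...
dist.destroy_process_group()

The apparent intent of threading the timeout through both init_distributed_environment and ensure_model_parallel_initialized is that, when fault tolerance is enabled, a stuck collective surfaces as an error the sentinels can react to rather than hanging indefinitely.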