From e2c1f17a33bc51b32489860f56b8cb9c116402cf Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Fri, 7 Nov 2025 20:46:02 +0800 Subject: [PATCH 1/6] Milestone 1 of Internal Process-level Fault Tolerance (#61) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(fault-tolerance): add class skeletons for fault tolerance Signed-off-by: fangyuchu * config: add configuration options for fault tolerance Signed-off-by: fangyuchu * Add generate_identity and generate_identitys functions to generate a unique identity for a ZMQ ROUTER node * add service startup configuration for fault report addr * add init WorkerGuard * add engine_core_cmd_addr, fault_report_addr, client_cmd_addr, engine_core_identitys in EngineZmqAddresses; init engine_core_cmd_addr, fault_report_addr, client_cmd_addr in launch_core_engines func; add _report_engine_dead func in CoreEngineProcManager * init ClientGuard; init EngineZmqAddresses engine_core_identitys * init EngineCoreGuard * change generate_identitys to generate_identity_group * code formatting is optimized * code formatting is optimized * changed code format to ensure every line < 88 chars * changed code format to ensure every line < 88 chars; fix error: Value of type "dict[Any, Any] | None" is not indexable [index] * fix bug Error: vllm/v1/engine/utils.py:122:89: E501 Line too long (117 > 88) Error: vllm/v1/engine/utils.py:1059:9: F402 Import `uuid` from line 6 shadowed by loop variable * fix Error: vllm/v1/engine/utils.py:1045: error: Need type annotation for "uuids" (hint: "uuids: set[] = ...") [var-annotated] * fix error: Value of type "dict[Any, Any] | None" is not indexable [index] * fix error: Value of type "dict[Any, Any] | None" is not indexable [index] Signed-off-by: a798347923 <2645302020@qq.com> * add _send_msg in EngineCoreGuard Signed-off-by: a798347923 <2645302020@qq.com> * add import torch.cuda * add _recv_cmd function docstring that clearly explains the meaning of the return value. * changed recv_fault_msg to recv_msg; add ClientGuard __init__ func parameter types * add engine monitor Signed-off-by: TianZhuo <2770730562@qq.com> * Delete requirements/test.txt~ Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Delete vllm/v1/engine/core_client.py~ Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * simplify _send_msg and _recv_cmd in EngineCoreGuard * simplify recv_msg in ClientGuard * engine: add fault tolerance features for EngineCore. Signed-off-by: fangyuchu * engine: add timeout mechanism in retry. Signed-off-by: fangyuchu * add engine monitor * Delete vllm/v1/engine/exceptions.py~ Signed-off-by: 205150940 <112750056+205150940@users.noreply.github.com> * update actor_index * update EngineDead flag * handle fault and report exception Signed-off-by: w00689259 * fix engine_actor * fix engine_actor fault_info * handle fault and report exception Signed-off-by: w00689259 * delete num_identity * changed try/except * fix debug error * fix one bug. Signed-off-by: fangyuchu * add fault_report_addr in FaultToleranceConfig * add handle_fault & get_fault_info API Signed-off-by: w00689259 * remove fault_report_address in CoreEngineActorManager __init__ Signed-off-by: a798347923 <2645302020@qq.com> * ruff format Signed-off-by: a798347923 <2645302020@qq.com> * add handle_fault & get_fault_info API Signed-off-by: w00689259 * fix one bug.
Signed-off-by: fangyuchu * add fault_report_port in FaultToleranceConfig Signed-off-by: a798347923 <2645302020@qq.com> * add zmq_addr concatenated from fault_report_addr and fault_report_port Signed-off-by: a798347923 <2645302020@qq.com> * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * fix some bugs * fault reporter bug fix Signed-off-by: w00689259 * fault reporter bug fix Signed-off-by: w00689259 * remove fault_report_addr in FaultToleranceConfig Signed-off-by: a798347923 <2645302020@qq.com> * refactor: relocate method serialization functions to serial_util.py Signed-off-by: fangyuchu * fix actor bug * fix actor bug * add engine_core_cmd_addr in FaultToleranceConfig Signed-off-by: a798347923 <2645302020@qq.com> * add and use _stop_worker_execution in EngineCoreGuard Signed-off-by: a798347923 <2645302020@qq.com> * add and use run in WorkerGuard Signed-off-by: a798347923 <2645302020@qq.com> * fix actor bug * fix bug * fix sentinel * fix bug vllm/v1/engine/core.py:847: error: Missing positional argument "tp_size" in call to "EngineCoreGuard" Signed-off-by: a798347923 <2645302020@qq.com> * fix bug error: Missing positional arguments "length", "byteorder" in call to "to_bytes" of "int" Signed-off-by: a798347923 <2645302020@qq.com> * fix bug in fault tolerance mode Signed-off-by: w00689259 * fix bug in fault tolerance mode Signed-off-by: w00689259 * change fault_report_port to internal_fault_report_port; add external_fault_notify_port Signed-off-by: a798347923 <2645302020@qq.com> * change fault_report_port to internal_fault_report_port; add external_fault_notify_port Signed-off-by: a798347923 <2645302020@qq.com> * add _recv_cmd func; use deserialize_method_call and run_method in run func Signed-off-by: a798347923 <2645302020@qq.com> * Update core.py fix bug error: Need type annotation for "kwargs" (hint: "kwargs: dict[, ] = ...") Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * add self.ctx.term() in shutdown() Signed-off-by: a798347923 <2645302020@qq.com> * changed import of deserialize_method_call, serialize_method_call Signed-off-by: a798347923 <2645302020@qq.com> * changed init worker_guard in init_device Signed-off-by: a798347923 <2645302020@qq.com> * Update core.py add import serialize_method_call Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Update gpu_worker.py changed init WorkerGuard in init_device Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Update gpu_worker.py FIX BUG self.worker_guard: WorkerGuard|None = None Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Update gpu_worker.py fix bug error: Argument 1 to "deserialize_method_call" has incompatible type "str | None"; expected "str" [arg-type] Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Update gpu_worker.py ruff format Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Update core.py ruff-format Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * actively send exception information Signed-off-by: w00689259 * actively send exception information Signed-off-by: w00689259 * actively send exception information Signed-off-by: w00689259 * change engine_core_cmd_addr(str) to engine_core_cmd_addrs(list[str]) in
EngineZmqAddresses Signed-off-by: a798347923 <2645302020@qq.com> * change engine_core_cmd_addr(str) to engine_core_cmd_addrs(list[str]) in EngineZmqAddresses Signed-off-by: a798347923 <2645302020@qq.com> * Update utils.py delete engine_core_cmd_addr in EngineZmqAddresses Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> * Remove redundant configuration: fault-pub-port Signed-off-by: fangyuchu * Send pause instructions after receiving fault info in ClientGuard Signed-off-by: fangyuchu * change engine_core_guard_identities from dict[int, bytes] to list[bytes] Signed-off-by: a798347923 <2645302020@qq.com> * fix bug "only the worker guard of engine core 0 can receive messages sent from engine core guard" Signed-off-by: a798347923 <2645302020@qq.com> * change local_rank to rank_in_group in WorkerGuard Signed-off-by: a798347923 <2645302020@qq.com> * changed del self.client_cmd_registry[int(unhealthy_engine.engine_id)] Signed-off-by: a798347923 <2645302020@qq.com> * add gloo communication timeout * fix some bugs * add stateless_process_group gloo_comm_timeout * reconstruct fault receiver & fault handler Signed-off-by: w00689259 * fix some bugs * reconstruct fault receiver & fault handler Signed-off-by: w00689259 * reconstruct fault receiver & fault handler Signed-off-by: w00689259 * fix return format Signed-off-by: w00689259 * fix return format Signed-off-by: w00689259 * fix return format Signed-off-by: w00689259 * add abort request * fix some bugs * fix some bugs * fix some bugs * add dt for client guard Signed-off-by: w00689259 * add dt for client guard Signed-off-by: w00689259 * add dt for client guard Signed-off-by: w00689259 * Implementation of two types of pause: a soft one by using flag signals and a hard one by aborting nccl communicators. Signed-off-by: fangyuchu * Refine certain log forms and fix a minor bug in pause function. Signed-off-by: fangyuchu * Refactor and abstract the recv_msg logic in CG, ECG, WG. Signed-off-by: fangyuchu * Add and check method uuid when sending commands and receiving results. Signed-off-by: fangyuchu * Abstract the logic of sending instructions and waiting for responses from FaultHandler Signed-off-by: fangyuchu * Add options in EngineCoreGuard to recv execution results from WorkerGuard Signed-off-by: fangyuchu * Support worker reinitialization after hard pause; add task queue in FaultHandler to ensure sequential task execution Signed-off-by: fangyuchu * resolve conflicts Signed-off-by: w00689259 * resolve conflicts Signed-off-by: w00689259 * resolve conflicts Signed-off-by: w00689259 * resolve conflicts Signed-off-by: w00689259 * resolve conflicts Signed-off-by: w00689259 * resolve conflicts Signed-off-by: w00689259 * add engine core ut Signed-off-by: w00689259 * add engine core ut Signed-off-by: w00689259 * Ensure WorkerGuard command execution returns result; fix missing set_device when TP>1 Signed-off-by: fangyuchu * rename & format logger Signed-off-by: w00689259 * rename & format logger Signed-off-by: w00689259 * feat(nccl): enable non-blocking NCCL communicators to support ncclCommAbort Signed-off-by: fangyuchu * reinit dp_group * fix bug * fix bug * fix bug * fix bug (#54) * Move requests to waiting queue instead of abandoning them directly.
Signed-off-by: fangyuchu * add annotation Signed-off-by: w00689259 * fix typos Signed-off-by: fangyuchu --------- Signed-off-by: fangyuchu Signed-off-by: a798347923 <2645302020@qq.com> Signed-off-by: TianZhuo <2770730562@qq.com> Signed-off-by: a798347923 <39047817+a798347923@users.noreply.github.com> Signed-off-by: 205150940 <112750056+205150940@users.noreply.github.com> Signed-off-by: w00689259 Signed-off-by: zWaNg3 <37772915+zWaNg3@users.noreply.github.com> Co-authored-by: zWaNg3 <37772915+zWaNg3@users.noreply.github.com> Co-authored-by: a798347923 <2645302020@qq.com> Co-authored-by: TianZhuo <2770730562@qq.com> Co-authored-by: 205150940 <112750056+205150940@users.noreply.github.com> Co-authored-by: a798347923 <39047817+a798347923@users.noreply.github.com> Co-authored-by: w00689259 --- tests/v1/engine/test_client_guard.py | 220 +++++ tests/v1/engine/test_engine_core_guard.py | 125 +++++ vllm/config/__init__.py | 3 + vllm/config/fault_tolerance.py | 71 +++ vllm/config/parallel.py | 8 +- vllm/config/vllm.py | 10 +- .../device_communicators/pynccl.py | 9 +- .../device_communicators/pynccl_wrapper.py | 82 +++- vllm/distributed/parallel_state.py | 101 +++- vllm/distributed/utils.py | 13 +- vllm/engine/arg_utils.py | 45 ++ vllm/engine/protocol.py | 13 + vllm/entrypoints/api_server.py | 28 ++ vllm/entrypoints/cli/serve.py | 5 +- vllm/entrypoints/openai/api_server.py | 84 ++++ vllm/utils/collection_utils.py | 209 +++++++++ vllm/utils/network_utils.py | 44 ++ vllm/v1/core/sched/interface.py | 24 + vllm/v1/core/sched/scheduler.py | 31 +- vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/async_llm.py | 10 + vllm/v1/engine/core.py | 423 ++++++++++++++++- vllm/v1/engine/core_client.py | 288 +++++++++++- vllm/v1/engine/exceptions.py | 71 +++ vllm/v1/engine/utils.py | 438 +++++++++++++++++- vllm/v1/serial_utils.py | 50 ++ vllm/v1/worker/gpu_model_runner.py | 10 + vllm/v1/worker/gpu_worker.py | 242 +++++++++- 28 files changed, 2594 insertions(+), 64 deletions(-) create mode 100644 tests/v1/engine/test_client_guard.py create mode 100644 tests/v1/engine/test_engine_core_guard.py create mode 100644 vllm/config/fault_tolerance.py diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py new file mode 100644 index 000000000000..2448fc3c3ba9 --- /dev/null +++ b/tests/v1/engine/test_client_guard.py @@ -0,0 +1,220 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import threading +import time +from unittest.mock import AsyncMock + +import pytest +import zmq + +from vllm.utils.collection_utils import ThreadSafeDict +from vllm.v1.engine.core_client import ClientGuard +from vllm.v1.engine.utils import FaultHandler, FaultInfo + +FAULT_RECEIVER_ADDR = "tcp://127.0.0.1:8844" +CMD_ADDR = "tcp://127.0.0.1:8845" +FAULT_PUB_ADDR = "tcp://127.0.0.1:8846" +FAULT_PUB_TOPIC = "vllm_fault" + + +def create_test_thread_safe_dict(initial_data=None): + if initial_data is None: + initial_data = {1: "Healthy"} + tsd = ThreadSafeDict() + if initial_data: + for k, v in initial_data.items(): + tsd[k] = v + return tsd + + +def create_client_guard( + engine_exception_q: asyncio.Queue, engine_status_dict: ThreadSafeDict[int, str] ): + return ClientGuard( + fault_receiver_addr=FAULT_RECEIVER_ADDR, + cmd_addr=CMD_ADDR, + engine_registry=[b"engine_identity"], + engine_exception_q=engine_exception_q, + engine_exception_q_lock=asyncio.Lock(),
fault_pub_addr=FAULT_PUB_ADDR, + engine_status_dict=engine_status_dict, + ) + + +def test_client_guard_initialization(): + engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + assert guard.engine_registry == [b"engine_identity"] + assert not guard.client_guard_dead + assert isinstance(guard.fault_handler, FaultHandler) + assert guard.engine_exception_q is engine_exception_q + + assert guard.fault_receiver_socket.type == zmq.ROUTER + assert guard.cmd_socket.type == zmq.ROUTER + assert guard.fault_pub_socket.type == zmq.PUB + + guard.shutdown_guard() + + +@pytest.mark.asyncio +async def test_handle_fault(): + engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + engine_exception_q.put_nowait( + FaultInfo(engine_id="1", message="test exception", type="test") + ) + + guard.fault_handler.handle_fault = AsyncMock(return_value=True) + + result = await guard.handle_fault("pause", 5) + assert result is True + guard.fault_handler.handle_fault.assert_awaited_once_with("pause", 5) + + guard.shutdown_guard() + + +def test_fault_receiver(): + engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + def send_test_message(): + ctx = zmq.Context() + socket = ctx.socket(zmq.DEALER) + socket.setsockopt(zmq.IDENTITY, b"test_sender") + socket.connect(FAULT_RECEIVER_ADDR) + + test_fault = FaultInfo(engine_id="1", type="dead", message="test error") + socket.send_multipart([b"", test_fault.serialize().encode("utf-8")]) + socket.close() + ctx.term() + + sender_thread = threading.Thread(target=send_test_message) + sender_thread.start() + + def check_published_message(): + ctx = zmq.Context() + sub_socket = ctx.socket(zmq.SUB) + sub_socket.connect(FAULT_PUB_ADDR) + sub_socket.setsockopt_string(zmq.SUBSCRIBE, FAULT_PUB_TOPIC) + + message = sub_socket.recv_string() + sub_socket.close() + ctx.term() + + prefix, data = message.split("|", 1) + assert prefix == FAULT_PUB_TOPIC + assert json.loads(data) == {"1": "Dead"} + + check_thread = threading.Thread(target=check_published_message) + check_thread.start() + + time.sleep(0.1) + + assert not engine_exception_q.empty() + received_fault = engine_exception_q.get_nowait() + assert received_fault.engine_id == "1" + assert received_fault.type == "dead" + + assert engine_status_dict[1] == "Dead" + + guard.shutdown_guard() + + +def test_fault_receiver_unhealthy(): + engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + def send_unhealthy_message(): + ctx = zmq.Context() + socket = ctx.socket(zmq.DEALER) + socket.setsockopt(zmq.IDENTITY, b"engine_identity") + socket.connect(FAULT_RECEIVER_ADDR) + + test_fault = FaultInfo(engine_id="1", type="error", message="test error") + socket.send_multipart([b"", test_fault.serialize().encode()]) + socket.close() + ctx.term() + + threading.Thread(target=send_unhealthy_message).start() + time.sleep(0.1) + + assert engine_status_dict[1] == "Unhealthy" + + guard.shutdown_guard() + + +def test_shutdown_guard(): + engine_exception_q: 
asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + original_fault_sock = guard.fault_receiver_socket + original_cmd_sock = guard.cmd_socket + original_pub_sock = guard.fault_pub_socket + original_ctx = guard.zmq_ctx + + guard.shutdown_guard() + + assert guard.client_guard_dead is True + + with pytest.raises(zmq.ZMQError): + original_fault_sock.recv() + + with pytest.raises(zmq.ZMQError): + original_cmd_sock.recv() + + with pytest.raises(zmq.ZMQError): + original_pub_sock.send(b"test") + + assert original_ctx.closed + + +@pytest.mark.asyncio +async def test_handle_fault_async(): + engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_status_dict = create_test_thread_safe_dict({1: "Unhealthy"}) + guard = create_client_guard(engine_exception_q, engine_status_dict) + + time.sleep(0.1) + ctx = zmq.Context.instance() + cmd_socket = ctx.socket(zmq.DEALER) + cmd_socket.setsockopt(zmq.IDENTITY, b"engine_identity") + cmd_socket.connect(CMD_ADDR) + + uuid = None + + def receive_cmd(cmd_socket): + nonlocal uuid + time.sleep(0.1) + + _empty, msg = cmd_socket.recv_multipart() + cmd_dict = json.loads(msg.decode("utf-8")) + assert cmd_dict["method"] == "retry" + assert cmd_dict["timeout"] == 3 + uuid = cmd_dict["method_uuid"] + + def response_cmd(cmd_socket): + nonlocal uuid + while uuid is None: + time.sleep(0.1) + execute_result = {"engine_index": 1, "success": True, "method_uuid": uuid} + cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")]) + + threading.Thread(target=receive_cmd, args=(cmd_socket,)).start() + threading.Thread(target=response_cmd, args=(cmd_socket,)).start() + + result = await guard.handle_fault("retry", 3) + + assert result is True + assert engine_status_dict[1] == "Healthy" + + guard.shutdown_guard() diff --git a/tests/v1/engine/test_engine_core_guard.py b/tests/v1/engine/test_engine_core_guard.py new file mode 100644 index 000000000000..f0c9c2098248 --- /dev/null +++ b/tests/v1/engine/test_engine_core_guard.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import logging +import queue +import threading +import time + +import pytest +import zmq + +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.engine.core import ( + EngineCoreGuard, + EngineLoopPausedError, ) +from vllm.v1.serial_utils import serialize_method_call + +CLIENT_CMD_ADDR = "tcp://127.0.0.1:8844" +WORKER_CMD_ADDR = "tcp://127.0.0.1:8845" +FAULT_REPORT_ADDR = "tcp://127.0.0.1:8846" +GUARD_IDENTITY = b"engine_guard_0" + + +def create_engine_core_guard( + fault_signal_q: queue.Queue, busy_loop_active: threading.Event ): + return EngineCoreGuard( + engine_index=0, + fault_signal_q=fault_signal_q, + cmd_q=queue.Queue(), + busy_loop_active=busy_loop_active, + engine_input_q=queue.Queue(), + client_cmd_addr=CLIENT_CMD_ADDR, + worker_cmd_addr=WORKER_CMD_ADDR, + fault_report_addr=FAULT_REPORT_ADDR, + guard_identity=GUARD_IDENTITY, + tp_size=1, + pp_size=1, + ) + + +def test_engine_core_guard_initialization(): + fault_signal_q: queue.Queue = queue.Queue() + busy_loop_active = threading.Event() + + guard = create_engine_core_guard(fault_signal_q, busy_loop_active) + + assert guard.engine_index == 0 + assert guard.tp_size == 1 + assert guard.pp_size == 1 + assert not guard.communicator_aborted + assert guard.engine_running is True + assert
guard.daemon is True + + assert guard.fault_report_socket.type == zmq.DEALER + assert guard.client_cmd_socket.type == zmq.DEALER + assert guard.worker_cmd_socket.type == zmq.ROUTER + + guard.shutdown() + + +@pytest.mark.parametrize("instruction", ["pause", "retry"]) +def test_run_handle_instruction(instruction): + fault_signal_q: queue.Queue = queue.Queue() + busy_loop_active = threading.Event() + + client_socket = make_zmq_socket( + ctx=zmq.Context(), path=CLIENT_CMD_ADDR, socket_type=zmq.ROUTER, bind=True + ) + + time.sleep(0.1) + + guard = create_engine_core_guard(fault_signal_q, busy_loop_active) + time.sleep(0.1) + + ctx = zmq.Context() + worker_cmd_socket = ctx.socket(zmq.DEALER) + worker_cmd_socket.setsockopt(zmq.IDENTITY, b"0_0") + worker_cmd_socket.connect(WORKER_CMD_ADDR) + + def mock_worker_receiver(cmd_socket): + time.sleep(0.1) + logging.info("start worker") + identity, msg = cmd_socket.recv_multipart() + logging.info(identity) + cmd_dict = json.loads(msg.decode("utf-8")) + assert cmd_dict["method"] == ( + "pause_by_signal" + if instruction == "pause" + else "retry" + ) + response_dict = {"success": True, "method_uuid": cmd_dict["method_uuid"]} + logging.info(identity) + cmd_socket.send_multipart([b"", json.dumps(response_dict).encode("utf-8")]) + + threading.Thread(target=guard.run, daemon=True).start() + time.sleep(0.1) + + param = {"timeout": 3} + if instruction == "pause": + param["soft_pause"] = True + serial_instruction = serialize_method_call(instruction, **param) + client_socket.send_multipart( + [GUARD_IDENTITY, b"", serial_instruction.encode("utf-8")] + ) + if instruction == "pause": + fault_signal_q.put(EngineLoopPausedError(Exception("test error"))) + elif instruction == "retry": + busy_loop_active.set() + + threading.Thread(target=mock_worker_receiver, args=(worker_cmd_socket,)).start() + + time.sleep(0.1) + identity, _, msg = client_socket.recv_multipart() + result_dict = json.loads(msg.decode("utf-8")) + assert result_dict["engine_index"] == 0 + assert result_dict["success"] + + time.sleep(0.1) + + client_socket.close() + worker_cmd_socket.close() + guard.shutdown() diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index dd76a722106e..eb85862bf240 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -10,6 +10,7 @@ ) from vllm.config.device import DeviceConfig from vllm.config.ec_transfer import ECTransferConfig +from vllm.config.fault_tolerance import FaultToleranceConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig @@ -86,6 +87,8 @@ "SpeechToTextConfig", # From vllm.config.structured_outputs "StructuredOutputsConfig", + # From vllm.config.fault_tolerance + "FaultToleranceConfig", # From vllm.config.utils "ConfigType", "SupportsMetricsInfo", diff --git a/vllm/config/fault_tolerance.py b/vllm/config/fault_tolerance.py new file mode 100644 index 000000000000..b12be7af81b0 --- /dev/null +++ b/vllm/config/fault_tolerance.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class FaultToleranceConfig: + """Configuration for fault tolerance.""" + + enable_fault_tolerance: bool = False + """Enable fault tolerance for detailed error recovery, + such as scaling down a faulty DPEngineCore.
+ """ + + engine_recovery_timeout: int = 60 + """Timeout (in seconds) to wait for error handling instructions + before raising an exception. If the EngineCore encounters an + error, it waits up to this many seconds for instructions on how + to handle the error. If no instructions are received within this + time, the original error is raised. + """ + + internal_fault_report_port: int = 22866 + """ + The port to use for internal fault reporting. + """ + + external_fault_notify_port: int = 22867 + """ + The port to use for external fault notify. + """ + + engine_core_cmd_addr: str = "" + """ + The ZMQ address between engine_core_guard and worker_guard. + It will be initialized and assigned in EngineCore, then passed + to the Worker via vllm_config—this is required for the Worker + to spin up the WorkerGuard. + """ + + gloo_comm_timeout: int = 30 + """ + The timeout for gloo communication. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + pass diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9a6326d62e82..466390335327 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -337,7 +337,11 @@ def get_next_dp_init_port(self) -> int: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def stateless_init_dp_group( + self, + gloo_comm_timeout: int = 30, + enable_fault_tolerance: bool = False, + ) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first @@ -362,6 +366,8 @@ def stateless_init_dp_group(self) -> ProcessGroup: self.data_parallel_rank, self.data_parallel_size, backend=current_platform.dist_backend, + gloo_comm_timeout=gloo_comm_timeout, + enable_fault_tolerance=enable_fault_tolerance, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. 
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 672b004c4aa5..9005d309f3a0 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -10,7 +10,7 @@ import threading import time from contextlib import contextmanager -from dataclasses import replace +from dataclasses import field, replace from datetime import datetime from functools import lru_cache from pathlib import Path @@ -30,6 +30,7 @@ from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig from .ec_transfer import ECTransferConfig +from .fault_tolerance import FaultToleranceConfig from .kv_events import KVEventsConfig from .kv_transfer import KVTransferConfig from .load import LoadConfig @@ -107,6 +108,10 @@ class VllmConfig: """The configurations for event publishing.""" ec_transfer_config: ECTransferConfig | None = None """The configurations for distributed EC cache transfer.""" + fault_tolerance_config: FaultToleranceConfig = field( + default_factory=FaultToleranceConfig + ) + """The configurations for fault tolerance.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. @@ -1049,7 +1054,8 @@ def __str__(self): f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}" + f"compilation_config={self.compilation_config!r}, " + f"fault_tolerance_config={self.fault_tolerance_config!r}" ) @model_validator(mode="after") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2fc35e80f591..846a0f6c21cc 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -236,7 +236,7 @@ def all_gatherv( cudaStream_t(stream.cuda_stream), ) split_offset += split_size - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) def reduce_scatter( self, @@ -301,7 +301,7 @@ def reduce_scatterv( cudaStream_t(stream.cuda_stream), ) split_offset += split_size - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: @@ -369,7 +369,10 @@ def group_start(self): self.nccl.ncclGroupStart() def group_end(self): - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd(self.comm) + + def nccl_abort_comm(self): + self.nccl.ncclCommAbort(self.comm) def register_comm_window(self, tensor: torch.Tensor): return self.nccl.ncclCommWindowRegister( diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index b2433d58dc1f..e5c3e981823b 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -136,6 +136,15 @@ class NCCLLibrary: Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t* asyncError) + Function( + "ncclCommGetAsyncError", + ncclResult_t, + [ + ncclComm_t, + ctypes.POINTER(ncclResult_t), + ], + ), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a
pointer type, so the first argument @@ -274,6 +283,8 @@ class NCCLLibrary: # it is better not to call it at all. # ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclCommAbort(ncclComm_t comm) + Function("ncclCommAbort", ncclResult_t, [ncclComm_t]), # ncclResult_t ncclGroupStart(); Function("ncclGroupStart", ncclResult_t, []), # ncclResult_t ncclGroupEnd(); @@ -360,10 +371,32 @@ def __init__(self, so_file: str | None = None): def ncclGetErrorString(self, result: ncclResult_t) -> str: return self._funcs["ncclGetErrorString"](result).decode("utf-8") - def NCCL_CHECK(self, result: ncclResult_t) -> None: - if result != 0: - error_str = self.ncclGetErrorString(result) - raise RuntimeError(f"NCCL error: {error_str}") + def NCCL_CHECK(self, result: ncclResult_t, comm: ncclComm_t | None = None) -> None: + ncclSuccess = 0 + ncclInProgress = 7 + + if result == ncclSuccess: + return + + # Handle non-blocking communicators + if result == ncclInProgress: + if comm is None: + raise RuntimeError( + "NCCL_CHECK: ncclInProgress returned but no communicator " + "provided (required for non-blocking NCCL checks)." + ) + while True: + async_result = self.ncclCommGetAsyncError(comm) + result = async_result.value + if result == ncclSuccess: + # Operation has completed successfully + return + if result != ncclInProgress: + # A definite error occurred + break + + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL_CHECK failed: {error_str} (code={result})") def ncclGetRawVersion(self) -> int: version = ctypes.c_int() @@ -384,6 +417,11 @@ def ncclGetUniqueId(self) -> ncclUniqueId: self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id + def ncclCommGetAsyncError(self, comm: ncclComm_t) -> ncclResult_t: + async_error = ncclResult_t() + self._funcs["ncclCommGetAsyncError"](comm, ctypes.byref(async_error)) + return async_error + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: if len(data) != 128: raise ValueError( @@ -400,7 +438,8 @@ def ncclCommInitRank( self.NCCL_CHECK( self._funcs["ncclCommInitRank"]( ctypes.byref(comm), world_size, unique_id, rank - ) + ), + comm, ) return comm @@ -422,7 +461,8 @@ def ncclAllReduce( self.NCCL_CHECK( self._funcs["ncclAllReduce"]( sendbuff, recvbuff, count, datatype, op, comm, stream - ) + ), + comm, ) def ncclReduce( @@ -444,7 +484,8 @@ def ncclReduce( self.NCCL_CHECK( self._funcs["ncclReduce"]( sendbuff, recvbuff, count, datatype, op, root, comm, stream - ) + ), + comm, ) def ncclReduceScatter( @@ -465,7 +506,8 @@ def ncclReduceScatter( self.NCCL_CHECK( self._funcs["ncclReduceScatter"]( sendbuff, recvbuff, count, datatype, op, comm, stream - ) + ), + comm, ) def ncclAllGather( @@ -484,7 +526,8 @@ def ncclAllGather( self.NCCL_CHECK( self._funcs["ncclAllGather"]( sendbuff, recvbuff, count, datatype, comm, stream - ) + ), + comm, ) def ncclSend( @@ -497,7 +540,7 @@ def ncclSend( stream: cudaStream_t, ) -> None: self.NCCL_CHECK( - self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream), comm ) def ncclRecv( @@ -510,7 +553,7 @@ def ncclRecv( stream: cudaStream_t, ) -> None: self.NCCL_CHECK( - self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream), comm ) def ncclBroadcast( @@ -526,17 +569,21 @@ def ncclBroadcast( self.NCCL_CHECK( self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream - ) + ), + comm, ) def ncclCommDestroy(self, comm: ncclComm_t) -> None: - self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm), comm) + + def ncclCommAbort(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommAbort"](comm), comm) def ncclGroupStart(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) - def ncclGroupEnd(self) -> None: - self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + def ncclGroupEnd(self, comm) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"](), comm) def ncclCommWindowRegister( self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int @@ -545,12 +592,13 @@ def ncclCommWindowRegister( self.NCCL_CHECK( self._funcs["ncclCommWindowRegister"]( comm, buff, size, ctypes.byref(window), win_flags - ) + ), + comm, ) return window def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: - self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window), comm) __all__ = [ diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 852c4c644433..b27125b66830 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -25,6 +25,7 @@ import contextlib import gc +import os import pickle import weakref from collections import namedtuple @@ -312,6 +313,8 @@ def __init__( use_device_communicator: bool, # whether to use device communicator use_message_queue_broadcaster: bool = False, group_name: str | None = None, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, ): group_name = group_name or "anonymous" self.unique_name = _get_unique_name(group_name) @@ -323,13 +326,30 @@ def __init__( self_device_group = None self_cpu_group = None + if torch_distributed_backend == "nccl": + options = torch._C._distributed_c10d.ProcessGroupNCCL.Options() + if enable_fault_tolerance: + # need to set communicators as nonblocking to abort safely + options.config.blocking = 0 + os.environ["NCCL_COMM_BLOCKING"] = "0" + for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + if torch_distributed_backend == "nccl": + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend, pg_options=options + ) + else: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. 
- cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if not enable_fault_tolerance: + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + else: + cpu_group = torch.distributed.new_group( + ranks, backend="gloo", timeout=gloo_comm_timeout + ) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -1039,7 +1059,11 @@ def get_inner_dp_world_group() -> GroupCoordinator: def init_world_group( - ranks: list[int], local_rank: int, backend: str + ranks: list[int], + local_rank: int, + backend: str, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, ) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], @@ -1047,6 +1071,8 @@ def init_world_group( torch_distributed_backend=backend, use_device_communicator=False, group_name="world", + enable_fault_tolerance=enable_fault_tolerance, + gloo_comm_timeout=gloo_comm_timeout, ) @@ -1054,6 +1080,8 @@ def init_model_parallel_group( group_ranks: list[list[int]], local_rank: int, backend: str, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, use_message_queue_broadcaster: bool = False, group_name: str | None = None, use_device_communicator: bool = True, @@ -1065,6 +1093,8 @@ def init_model_parallel_group( use_device_communicator=use_device_communicator, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, + enable_fault_tolerance=enable_fault_tolerance, + gloo_comm_timeout=gloo_comm_timeout, ) @@ -1128,6 +1158,31 @@ def get_pipeline_model_parallel_group(): return get_pp_group() +def get_all_model_groups() -> list[GroupCoordinator]: + group_list = [] + global _TP + if _TP: + group_list.append(_TP) + + global _PP + if _PP: + group_list.append(_PP) + + global _DCP + if _DCP: + group_list.append(_DCP) + + global _DP + if _DP: + group_list.append(_DP) + + global _EP + if _EP: + group_list.append(_EP) + + return group_list + + @contextmanager def graph_capture(device: torch.device): """ @@ -1164,6 +1219,7 @@ def init_distributed_environment( distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", + enable_fault_tolerance: bool = False, timeout: timedelta | None = None, ): logger.debug( @@ -1244,7 +1300,9 @@ def init_distributed_environment( global _WORLD, _NODE_COUNT, _INNER_DP_WORLD if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) - _WORLD = init_world_group(ranks, local_rank, backend) + _WORLD = init_world_group( + ranks, local_rank, backend, enable_fault_tolerance, timeout + ) if config.parallel_config.nnodes > 1: _NODE_COUNT = config.parallel_config.nnodes else: @@ -1276,6 +1334,8 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1339,6 +1399,8 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, + enable_fault_tolerance, + gloo_comm_timeout, use_message_queue_broadcaster=True, group_name="tp", ) @@ -1356,6 +1418,8 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, + enable_fault_tolerance, + gloo_comm_timeout, use_message_queue_broadcaster=True, group_name="dcp", ) @@ -1368,7 +1432,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( - 
group_ranks, get_world_group().local_rank, backend, group_name="pp" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="pp", ) global _DP @@ -1376,7 +1445,12 @@ def initialize_model_parallel( group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] _DP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dp" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="dp", ) global _EP @@ -1388,7 +1462,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] _EP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="ep" + group_ranks, + get_world_group().local_rank, + backend, + enable_fault_tolerance, + gloo_comm_timeout, + group_name="ep", ) logger.info_once( @@ -1407,6 +1486,8 @@ def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, decode_context_model_parallel_size: int | None = 1, + enable_fault_tolerance: bool = False, + gloo_comm_timeout: timedelta | None = None, backend: str | None = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1418,6 +1499,8 @@ def ensure_model_parallel_initialized( initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, + enable_fault_tolerance, + gloo_comm_timeout, decode_context_model_parallel_size, backend, ) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index debf69c49b7d..f453015ac1a2 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -458,7 +458,13 @@ def init_gloo_process_group( def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, backend: str + host: str, + port: int, + rank: int, + world_size: int, + backend: str, + gloo_comm_timeout: int, + enable_fault_tolerance: bool = False, ) -> ProcessGroup: """ A replacement for `torch.distributed.init_process_group` that does not @@ -493,7 +499,10 @@ def stateless_init_torch_distributed_process_group( """ init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string - timeout = _get_default_timeout(backend) + if enable_fault_tolerance: + timeout = timedelta(seconds=gloo_comm_timeout) + else: + timeout = _get_default_timeout(backend) store, rank, world_size = next( rendezvous(init_method, rank, world_size, timeout=timeout) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ab6e5e594c23..54e7d048de93 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -40,6 +40,7 @@ DeviceConfig, ECTransferConfig, EPLBConfig, + FaultToleranceConfig, KVEventsConfig, KVTransferConfig, LoadConfig, @@ -567,6 +568,13 @@ class EngineArgs: kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill + # fault tolerance fields + enable_fault_tolerance: bool = FaultToleranceConfig.enable_fault_tolerance + engine_recovery_timeout: int = FaultToleranceConfig.engine_recovery_timeout + internal_fault_report_port: int = FaultToleranceConfig.internal_fault_report_port + external_fault_notify_port: int = FaultToleranceConfig.external_fault_notify_port + gloo_comm_timeout: int = FaultToleranceConfig.gloo_comm_timeout + kv_offloading_size: float | None = CacheConfig.kv_offloading_size kv_offloading_backend: KVOffloadingBackend | None = ( 
CacheConfig.kv_offloading_backend @@ -1155,6 +1163,32 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) + # fault tolerance arguments + fault_tolerance_kwargs = get_kwargs(FaultToleranceConfig) + fault_tolerance_group = parser.add_argument_group( + title="FaultToleranceConfig", + description=FaultToleranceConfig.__doc__, + ) + fault_tolerance_group.add_argument( + "--enable-fault-tolerance", + **fault_tolerance_kwargs["enable_fault_tolerance"], + ) + fault_tolerance_group.add_argument( + "--engine-recovery-timeout", + **fault_tolerance_kwargs["engine_recovery_timeout"], + ) + fault_tolerance_group.add_argument( + "--internal-fault-report-port", + **fault_tolerance_kwargs["internal_fault_report_port"], + ) + fault_tolerance_group.add_argument( + "--external-fault-notify-port", + **fault_tolerance_kwargs["external_fault_notify_port"], + ) + fault_tolerance_group.add_argument( + "--gloo-comm-timeout", + **fault_tolerance_kwargs["gloo_comm_timeout"], + ) # Other arguments parser.add_argument( "--disable-log-stats", @@ -1738,6 +1772,16 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + fault_tolerance_config = FaultToleranceConfig( + enable_fault_tolerance=self.enable_fault_tolerance, + engine_recovery_timeout=self.engine_recovery_timeout, + internal_fault_report_port=self.internal_fault_report_port + or FaultToleranceConfig.internal_fault_report_port, + external_fault_notify_port=self.external_fault_notify_port + or FaultToleranceConfig.external_fault_notify_port, + gloo_comm_timeout=self.gloo_comm_timeout, + ) + # Compilation config overrides compilation_config = copy.deepcopy(self.compilation_config) if self.cuda_graph_sizes is not None: @@ -1784,6 +1828,7 @@ def create_engine_config( kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, + fault_tolerance_config=fault_tolerance_config, additional_config=self.additional_config, ) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 462d2c4e50e7..33dbcc496a45 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -165,6 +165,19 @@ async def collective_rpc( """Perform a collective RPC call to the given path.""" raise NotImplementedError + async def handle_fault( + self, instruction: str, timeout: int = 300, **kwargs + ) -> bool: + """Send a fault tolerance instruction to the engine.""" + raise NotImplementedError + + async def exception_reporter(self): + """Report exceptions from the engine core.""" + raise NotImplementedError + async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: """Get supported tasks""" raise NotImplementedError + + def shutdown(self) -> None: + raise NotImplementedError diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 154cdeb42a3e..b2034bbada70 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -56,6 +56,34 @@ async def generate(request: Request) -> Response: return await _generate(request_dict, raw_request=request) + +@app.post("/fault_tolerance/apply") +async def send_fault_tolerance_instruction(request: Request) -> Response: + """Apply a fault tolerance instruction to the engine. + + The request should be a JSON object with the following fields: + - fault_tolerance_instruction: the instruction to execute. + - fault_tolerance_timeout: timeout (in seconds) for the instruction. + - kwargs: optional extra parameters for the instruction.
+ """ + request_dict = await request.json() + + fault_tolerance_instruction = request_dict.get("fault_tolerance_instruction") + fault_tolerance_timeout = request_dict.get("fault_tolerance_timeout") + kwargs = request_dict.get("kwargs", {}) + assert engine is not None + return await engine.handle_fault( + fault_tolerance_instruction, fault_tolerance_timeout, **kwargs + ) + + +@app.get("/fault_tolerance/status") +async def get_fault_info() -> Response: + """Health check.""" + assert engine is not None + engine_exception_dict = await engine.exception_reporter() + return Response(json.dumps(engine_exception_dict), status_code=200) + + @with_cancellation async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = request_dict.pop("prompt") diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 96608f360e17..2faaa6013141 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -153,7 +153,10 @@ def signal_handler(signum, frame): ) try: - engine_manager.join_first() + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + engine_manager.start_engine_core_monitor() + else: + engine_manager.join_first() finally: logger.info("Shutting down.") engine_manager.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3cf66fcd27e2..a53bc8fe8cf3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1224,6 +1224,88 @@ async def is_scaling_elastic_ep(raw_request: Request): (PoolingRequest, (pooling, create_pooling)), ] + +@router.post( + "/fault_tolerance/apply", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def send_fault_tolerance_instruction(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e + + client = engine_client(raw_request) + + fault_tolerance_instruction = body.get("fault_tolerance_instruction") + fault_tolerance_timeout = body.get("fault_tolerance_timeout") + dynamic_fault_tolerance_params = body.get("fault_tolerance_params", {}) + + if fault_tolerance_instruction is None or fault_tolerance_timeout is None: + raise HTTPException( + status_code=400, + detail="fault_tolerance_instruction and" + " fault_tolerance_timeout is required", + ) + + if not isinstance(fault_tolerance_instruction, str): + raise HTTPException( + status_code=400, detail="fault_tolerance_instruction must be a str" + ) + # Currently, only two types of instructions are supported: [pause, retry]. + # Additional descaling instructions will be supported in future updates. 
+ elif fault_tolerance_instruction not in ["pause", "retry"]: + raise HTTPException( + status_code=400, detail="not a valid fault_tolerance_instruction" + ) + + if not isinstance(fault_tolerance_timeout, int) or fault_tolerance_timeout <= 0: + raise HTTPException( + status_code=400, detail="fault_tolerance_timeout must be a positive integer" + ) + try: + execute_result = await client.handle_fault( + fault_tolerance_instruction, + fault_tolerance_timeout, + **dynamic_fault_tolerance_params, + ) + if execute_result: + return JSONResponse( + { + "message": "instruction has been executed successfully", + } + ) + else: + logger.error("Fault tolerance failed, shutdown the app.") + client.shutdown() + raise HTTPException( + status_code=400, + detail="Instruction execution failed.", + ) + + except HTTPException: + raise + except Exception as e: + logger.error("Handle fault failed: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail="Handle fault failed", + ) from e + + +@router.get("/fault_tolerance/status") +async def get_fault_info( + raw_request: Request, +): + client = engine_client(raw_request) + engine_exception_dict = await client.exception_reporter() + return JSONResponse(content=engine_exception_dict) + + # NOTE: Construct the TypeAdapters only once INVOCATION_VALIDATORS = [ (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py index 57271311828c..5bb7b418b0c5 100644 --- a/vllm/utils/collection_utils.py +++ b/vllm/utils/collection_utils.py @@ -6,6 +6,7 @@ This is similar in concept to the `collections` module. """ +import threading from collections import UserDict, defaultdict from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from typing import Generic, Literal, TypeVar @@ -137,3 +138,211 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: obj[key1] = v2 else: obj.pop(key1, None) + + +# Define type variables for generic key and value types +KT = TypeVar("KT") # Key type variable +VT = TypeVar("VT") # Value type variable + + +class ThreadSafeDict(Generic[KT, VT]): + """ + A thread-safe generic dictionary implementation. + Supports all basic dictionary operations with proper synchronization + using a reentrant lock, and maintains type safety through generics. + """ + + def __init__(self) -> None: + """Initialize an empty thread-safe dictionary with an internal lock.""" + self._storage: dict[KT, VT] = {} # Underlying storage structure + self._lock = threading.RLock() # Reentrant lock for synchronization + + def __setitem__(self, key: KT, value: VT) -> None: + """ + Thread-safe implementation of dictionary item assignment. + Equivalent to dict[key] = value. + + Args: + key: The key to associate with the value + value: The value to store + """ + with self._lock: + self._storage[key] = value + + def __getitem__(self, key: KT) -> VT: + """ + Thread-safe implementation of dictionary item retrieval. + Equivalent to dict[key]. + + Args: + key: The key to look up + + Returns: + The value associated with the key + + Raises: + KeyError: If the key is not found + """ + with self._lock: + return self._storage[key] + + def __delitem__(self, key: KT) -> None: + """ + Thread-safe implementation of dictionary item deletion. + Equivalent to del dict[key].
+ + Args: + key: The key to remove + + Raises: + KeyError: If the key is not found + """ + with self._lock: + del self._storage[key] + + def get(self, key: KT, default: VT | None = None) -> VT | None: + """ + Thread-safe implementation of dict.get(). + + Args: + key: The key to look up + default: Value to return if key is not found (default: None) + + Returns: + The value associated with the key, or default if not found + """ + with self._lock: + return self._storage.get(key, default) + + def setdefault(self, key: KT, default: VT) -> VT: + """ + Thread-safe implementation of dict.setdefault(). + Inserts key with default value if key is not present. + + Args: + key: The key to check/insert + default: Value to insert if key is not found + + Returns: + The existing value or the inserted default value + """ + with self._lock: + return self._storage.setdefault(key, default) + + def update(self, items: Iterable[tuple[KT, VT]]) -> None: + """ + Thread-safe implementation of dict.update(). + Updates dictionary with multiple key-value pairs. + + Args: + items: Iterable of (key, value) tuples to add/update + """ + with self._lock: + self._storage.update(items) + + def pop(self, key: KT, default: VT | None = None) -> VT | None: + """ + Thread-safe implementation of dict.pop(). + Removes and returns value associated with key. + + Args: + key: The key to remove + default: Value to return if key is not found (optional) + + Returns: + The removed value or default if key not found + + Note: + Never raises KeyError; returns default if the key is missing + """ + with self._lock: + return self._storage.pop(key, default) + + def __contains__(self, key: KT) -> bool: + """ + Thread-safe implementation of 'key in dict' check. + + Args: + key: The key to check for existence + + Returns: + True if key exists, False otherwise + """ + with self._lock: + return key in self._storage + + def __len__(self) -> int: + """ + Thread-safe implementation of len(dict). + + Returns: + Number of key-value pairs in the dictionary + """ + with self._lock: + return len(self._storage) + + def clear(self) -> None: + """Thread-safe implementation of dict.clear(). Removes all items.""" + with self._lock: + self._storage.clear() + + def keys(self) -> list[KT]: + """ + Thread-safe implementation of dict.keys(). + Returns a copy of all keys to prevent concurrent modification issues. + + Returns: + List of all keys in the dictionary + """ + with self._lock: + return list(self._storage.keys()) + + def values(self) -> list[VT]: + """ + Thread-safe implementation of dict.values(). + Returns a copy of all values to prevent concurrent modification issues. + + Returns: + List of all values in the dictionary + """ + with self._lock: + return list(self._storage.values()) + + def items(self) -> list[tuple[KT, VT]]: + """ + Thread-safe implementation of dict.items(). + Returns a copy of all key-value pairs to prevent concurrent modification issues. + + Returns: + List of (key, value) tuples + """ + with self._lock: + return list(self._storage.items()) + + def __str__(self) -> str: + """ + Thread-safe string representation. + + Returns: + String representation of the dictionary + """ + with self._lock: + return str(self._storage) + + def __repr__(self) -> str: + """ + Thread-safe representation for debugging.
+ + Returns: + Debug-friendly string representation + """ + with self._lock: + return f"ThreadSafeDict({self._storage!r})" + + # ------------------------------ + # Critical: JSON serialization support + # ------------------------------ + def to_dict(self) -> dict[KT, VT]: + """Convert ThreadSafeDict to a standard Python dict (thread-safe).""" + with self._lock: + return self._storage.copy() # Return a copy of internal data diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py index 0a68e48ba5e7..5bdff432b872 100644 --- a/vllm/utils/network_utils.py +++ b/vllm/utils/network_utils.py @@ -329,3 +329,47 @@ def zmq_socket_ctx( finally: ctx.destroy(linger=linger) + + +def recv_router_dealer_message( + socket: zmq.Socket, + use_poller: bool = False, + poll_timeout: int = 1000, +) -> tuple[bool, None | bytes, None | str]: + """ + Receive multipart ZMQ messages, automatically inferring message format + based on socket type (ROUTER or DEALER). + + Returns: + (success, identity, message) + - identity is only set for ROUTER sockets + - success=False on timeout or error + """ + sock_type = socket.getsockopt(zmq.TYPE) + + # optional non-blocking receive + if use_poller: + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + socks = dict(poller.poll(poll_timeout)) + if socket not in socks: + return (False, None, None) + + parts = socket.recv_multipart() + + if sock_type == zmq.ROUTER: + # ROUTER message: [identity, empty, message] + assert len(parts) == 3, f"expected 3 parts, got {len(parts)}" + identity_bytes, empty_frame, message_bytes = parts + identity = identity_bytes + elif sock_type == zmq.DEALER: + # DEALER message: [empty, message] + assert len(parts) == 2, f"expected 2 parts, got {len(parts)}" + empty_frame, message_bytes = parts + identity = None + else: + raise ValueError(f"Unsupported socket type: {sock_type}") + + assert empty_frame == b"", f"empty frame invalid: {empty_frame}" + message = message_bytes.decode("utf-8") + return (True, identity, message) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 88d99d940282..ee8af447484e 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -64,6 +64,30 @@ def get_grammar_bitmask( ) -> "GrammarOutput | None": raise NotImplementedError + @abstractmethod + def preempt_request( + self, + scheduled_timestamp: float | None = None, + preempted_req: Request | None = None, + ) -> Request: + """ + Preempt a running request and move it back to the waiting queue. + + This method removes the specified request from the running queue (or the + last running request if none is specified), updates its status and statistics, + and moves it back to the waiting queue. Optionally records a preemption event + if logging is enabled. + + Args: + scheduled_timestamp: Optional timestamp for logging the preemption event. + preempted_req: Specific request to preempt. If None, preempt the last + request in the running queue. + + Returns: + The request that was preempted and returned to the waiting queue. 
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def update_from_output(
         self,
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 4323141c435b..658e98eea75f 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -37,7 +37,11 @@
 )
 from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.utils import check_stop, remove_all
-from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import (
+    EngineCoreEventType,
+    EngineCoreOutput,
+    EngineCoreOutputs,
+)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
 from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -712,6 +716,31 @@ def schedule(self) -> SchedulerOutput:
         self._update_after_schedule(scheduler_output)
         return scheduler_output
 
+    def preempt_request(
+        self,
+        scheduled_timestamp: float | None = None,
+        preempted_req: Request | None = None,
+    ) -> Request:
+        # Preempt a running request and move it back to the waiting queue.
+        if preempted_req is not None:
+            self.running.remove(preempted_req)
+        else:
+            preempted_req = self.running.pop()
+
+        self.kv_cache_manager.free(preempted_req)
+        self.encoder_cache_manager.free(preempted_req)
+        preempted_req.status = RequestStatus.PREEMPTED
+        preempted_req.num_computed_tokens = 0
+        preempted_req.num_preemptions += 1
+        if self.log_stats:
+            preempted_req.record_event(
+                EngineCoreEventType.PREEMPTED, scheduled_timestamp
+            )
+
+        self.waiting.prepend_request(preempted_req)
+
+        return preempted_req
+
     def _update_after_schedule(
         self,
         scheduler_output: SchedulerOutput,
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 3f621d77c024..882b3a97cede 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -187,6 +187,7 @@ class EngineCoreRequestType(enum.Enum):
     UTILITY = b"\x03"
     # Sentinel used within EngineCoreProc.
     EXECUTOR_FAILED = b"\x04"
+    PAUSE = b"\x05"
 
 
 class ReconfigureDistributedRequest(msgspec.Struct):
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index c160c7cbcab4..9a51fe711ff3 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -779,6 +779,16 @@ async def scale_elastic_ep(
             custom_stat_loggers=None,
         )
 
+    async def handle_fault(
+        self, instruction: str, timeout: int = 300, **kwargs
+    ) -> bool:
+        """Send a fault tolerance instruction to the engine core."""
+        return await self.engine_core.handle_fault(instruction, timeout, **kwargs)
+
+    async def exception_reporter(self):
+        """Report the current fault status of the engine cores."""
+        return await self.engine_core.fault_reporter()
+
     @property
     def is_running(self) -> bool:
         # Is None before the loop is started.
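For reviewers, a minimal usage sketch of the two new AsyncLLM entry points (illustrative only, not part of this patch; `engine` stands for an already-initialized AsyncLLM, and the `recover` helper is hypothetical):

    # Hypothetical recovery flow driven from an API layer.
    async def recover(engine) -> bool:
        # Best-effort soft pause of all engine cores first.
        if not await engine.handle_fault("pause", timeout=5, soft_pause=True):
            # Soft pause failed; force-abort the communicators instead.
            await engine.handle_fault("pause", timeout=5, soft_pause=False)
        status = await engine.exception_reporter()  # e.g. {0: "Unhealthy", 1: "Healthy"}
        if "Dead" in status.values():
            return False  # a dead engine core cannot be retried
        return await engine.handle_fault("retry", timeout=300)
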
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index d49eb752d56a..124d220c0a74 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
 import os
 import queue
 import signal
 import threading
 import time
+import traceback
 from collections import deque
 from collections.abc import Callable, Generator
 from concurrent.futures import Future
@@ -31,7 +33,7 @@
     maybe_attach_gc_debug_callback,
 )
 from vllm.utils.hashing import get_hash_fn_by_name
-from vllm.utils.network_utils import make_zmq_socket
+from vllm.utils.network_utils import make_zmq_socket, recv_router_dealer_message
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.core.kv_cache_utils import (
     BlockHash,
@@ -43,25 +45,36 @@
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.engine import (
+    EngineCoreOutput,
     EngineCoreOutputs,
     EngineCoreRequest,
     EngineCoreRequestType,
+    FinishReason,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
     UtilityOutput,
     UtilityResult,
 )
+from vllm.v1.engine.exceptions import EngineLoopPausedError, FaultInfo
 from vllm.v1.engine.utils import (
     EngineHandshakeMetadata,
     EngineZmqAddresses,
+    broadcast_instruction,
     get_device_indices,
+    wait_for_instruction_result,
 )
 from vllm.v1.executor import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
+from vllm.v1.serial_utils import (
+    MsgpackDecoder,
+    MsgpackEncoder,
+    deserialize_method_call,
+    run_method,
+    serialize_method_call,
+)
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -73,6 +86,318 @@
 _R = TypeVar("_R")  # Return type for collective_rpc
 
 
+class EngineCoreGuard(threading.Thread):
+    """
+    EngineCoreGuard monitors a single EngineCore instance, responsible for:
+    1. Receiving fault signals (exceptions raised in EngineCore busy loop)
+    2. Receiving and executing commands from ClientGuard
+    3. Reporting execution results or faults back to the ClientGuard
+    """
+
+    def __init__(
+        self,
+        engine_index: int,
+        fault_signal_q: queue.Queue,
+        cmd_q: queue.Queue,
+        busy_loop_active: threading.Event,
+        engine_input_q: queue.Queue,
+        client_cmd_addr: str,
+        worker_cmd_addr: str,
+        fault_report_addr: str,
+        guard_identity: bytes,
+        tp_size: int,
+        pp_size: int,
+        dp_size: int,
+    ):
+        super().__init__(daemon=True)
+        self.engine_index = engine_index
+        self.fault_signal_q = fault_signal_q
+        self.cmd_q = cmd_q
+        self.busy_loop_active = busy_loop_active
+        self.engine_input_q = engine_input_q
+        self.tp_size = tp_size
+        self.pp_size = pp_size
+        self.dp_size = dp_size
+
+        self.ctx = zmq.Context()
+        # Client <-> EngineCoreGuard sockets
+        self.fault_report_socket = make_zmq_socket(
+            self.ctx,
+            fault_report_addr,
+            zmq.DEALER,
+            bind=False,
+            identity=guard_identity,
+        )
+
+        self.client_cmd_socket = make_zmq_socket(
+            self.ctx, client_cmd_addr, zmq.DEALER, bind=False, identity=guard_identity
+        )
+        # EngineCoreGuard <-> WorkerGuard sockets
+        self.worker_cmd_socket = make_zmq_socket(
+            self.ctx, worker_cmd_addr, zmq.ROUTER, bind=True
+        )
+        self.poller = zmq.Poller()
+        self.communicator_aborted = False
+        self.engine_running = True
+        self.engine_core_guard_dead = False
+        self.logger = self._make_engine_core_guard_logger()
+
+    def _make_engine_core_guard_logger(self):
+        prefix = f"[EngineCoreGuard_{self.engine_index}] "
+
+        def log(msg, *args, level="info", **kwargs):
+            """
+            level: "info", "warning", "error", "debug"
+            msg: log message
+            """
+            getattr(logger, level)(prefix + msg, *args, **kwargs)
+
+        return log
+
+    def run(self) -> None:
+        """
+        Run the main monitoring loop for EngineCoreGuard.
+        """
+        poll_timeout_ms = 100
+        while not self.engine_core_guard_dead:
+            # Check for engine fault signals
+            try:
+                engine_exception = self.fault_signal_q.get_nowait()
+                if isinstance(engine_exception, EngineLoopPausedError):
+                    # The busy loop paused on purpose (e.g., on a pause
+                    # command); this is an acknowledgment, not a fault.
+                    self.logger("Engine paused", level="info")
+                else:
+                    self.logger(
+                        "Detected exception %s: %s\n Call Stack:\n%s",
+                        type(engine_exception).__name__,
+                        engine_exception,
+                        "".join(traceback.format_tb(engine_exception.__traceback__)),
+                        level="error",
+                    )
+                    self._report_client_exception(engine_exception)
+                    self.engine_running = False
+            except queue.Empty:
+                pass
+
+            if self.client_cmd_socket.closed:
+                self.logger("Client socket closed", level="info")
+                break
+            has_msg, _, cmd_str = recv_router_dealer_message(
+                self.client_cmd_socket, use_poller=True, poll_timeout=poll_timeout_ms
+            )
+            if has_msg:
+                self.logger("Received cmd: %s", cmd_str, level="info")
+                self._execute_cmd(cmd_str)
+
+    def _stop_worker_execution(self, soft_pause: bool, timeout: int = 2) -> bool:
+        if soft_pause:
+            pause_method = "pause_by_signal"
+        else:
+            pause_method = "pause_by_abort_communicators"
+            self.communicator_aborted = True
+
+        success = self._execute_worker_method(pause_method, timeout=timeout)
+        return success
+
+    def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> bool:
+        identities = set()
+        for tp_rank in range(self.tp_size):
+            for pp_rank in range(self.pp_size):
+                identity = f"{tp_rank}_{pp_rank}".encode()
+                identities.add(identity)
+
+        method_uuid = broadcast_instruction(
+            self.worker_cmd_socket, identities, method_name, **kwargs
+        )
+
+        all_success = True
+        worker_responses = wait_for_instruction_result(
+            self.worker_cmd_socket, identities, method_name, timeout, method_uuid
+        )
+
+        for identity in identities:
+            response = worker_responses.get(identity)
+            if response is None or not response.get("success", False):
+                all_success = False
+
+        return all_success
+
+    def _report_client_exception(self, exception: Exception) -> None:
+        msg = FaultInfo.from_exception(exception, self.engine_index).serialize()
+        msg_bytes = msg.encode("utf-8")
+        self.fault_report_socket.send_multipart([b"", msg_bytes])
+
+    def _execute_cmd(self, cmd_str):
+        """
+        Execute a command received from ClientGuard.
+        """
+        method, method_uuid, method_params = deserialize_method_call(cmd_str)
+        self.logger("Executing command: %s", method, level="info")
+        try:
+            success = run_method(self, method, args=(), kwargs=method_params)
+            self.logger("Command (%s) succeeded: %s", method, success, level="info")
+
+        except Exception as e:
+            self.logger(
+                "Error executing method %s: %s %s",
+                method,
+                type(e).__name__,
+                e,
+                level="error",
+            )
+            success = False
+
+        self._send_execution_result(success, method_uuid)
+
+    def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool:
+        """
+        Pause the busy loop safely.
+        Args:
+            timeout: seconds to wait for the busy loop to acknowledge the
+                pause signal
+            soft_pause: if True, perform a soft pause using a flag; otherwise
+                abort the communicators
+        """
+        self.logger("Start pausing EngineCore", level="info")
+        start_time = time.monotonic()
+        if self.engine_running:
+            # Clear the flag to signal busy loop should pause
+            self.busy_loop_active.clear()
+            # Put a sentinel (empty request) to unblock the busy loop
+            # if it's blocked on input_queue.get()
+            self.engine_input_q.put((EngineCoreRequestType.PAUSE, None))
+            success = self._stop_worker_execution(
+                soft_pause=soft_pause,
+                timeout=timeout,
+            )
+            elapsed = time.monotonic() - start_time
+            if success:
+                remaining_timeout = max(0, timeout - elapsed)
+                try:
+                    # Wait for engine to acknowledge the pause via fault_signal_q
+                    exception = self.fault_signal_q.get(timeout=remaining_timeout)
+                    self.fault_signal_q.put(exception)
+                    success = True
+                    self.engine_running = False
+                except queue.Empty:
+                    # Timeout waiting for pause acknowledgment
+                    success = False
+        else:
+            # already paused
+            success = True
+            if not soft_pause:
+                # abort the communicators
+                self._stop_worker_execution(soft_pause=False, timeout=timeout)
+        return success
+
+    def retry(self, timeout: int = 1):
+        """
+        Handle the retry instruction from the ClientGuard.
+        This instruction tells the EngineCore to continue its busy loop
+        after being suspended due to an exception.
+        """
+        start_time = time.monotonic()
+
+        success = self._execute_worker_method("restart_worker", timeout=timeout)
+        if not success:
+            return success
+
+        if self.dp_size > 1:
+            # If the Gloo communication times out,
+            # the data parallel group (dp_group) needs to be reinitialized
+            command = "reinit_dp_group_on_fault_tolerance"
+            self.cmd_q.put(serialize_method_call(command))
+        else:
+            self.cmd_q.put(None)
+
+        # Ensure busy loop has been recovered.
+ elapsed = time.monotonic() - start_time + remaining_timeout = max(0, timeout - elapsed) + success = self.busy_loop_active.wait(timeout=remaining_timeout) + self.engine_running = success + return success + + def _send_execution_result(self, success: bool, method_uuid: str): + msg = { + "engine_index": self.engine_index, + "success": success, + "method_uuid": method_uuid, + } + msg_bytes = json.dumps(msg).encode("utf-8") + self.client_cmd_socket.send_multipart([b"", msg_bytes]) + + def shutdown(self): + if self.fault_report_socket is not None: + self.fault_report_socket.close() + if self.client_cmd_socket is not None: + self.client_cmd_socket.close() + if self.worker_cmd_socket is not None: + self.worker_cmd_socket.close() + if self.ctx is not None: + self.ctx.term() + self.engine_core_guard_dead = True + + +def busy_loop_wrapper(busy_loop_func): + """ + Wrap the busy loop function to perform fault tolerance. + """ + + def run_with_fault_tolerance(self): + while True: + try: + if self.enable_fault_tolerance: + self.busy_loop_active.set() + busy_loop_func(self) + except SystemExit: + raise + except Exception as original_exc: + if self.enable_fault_tolerance: + self.busy_loop_active.clear() + self.fault_signal_q.put(original_exc) + logger.warning( + "[BusyLoopWrapper] EngineCore busy loop raised an exception. " + "Suspended and waiting for fault tolerance " + "instructions." + ) + + # Put running requests into waiting list. + for req in list(self.scheduler.running): + self.scheduler.preempt_request(preempted_req=req) + self.scheduler.prev_step_scheduled_req_ids.clear() + + try: + # Block until recovery command received + cmd_str = self.cmd_q.get(timeout=self.engine_recovery_timeout) + logger.debug( + "[BusyLoopWrapper] Received fault tolerance command: %s", + cmd_str, + ) + if cmd_str is not None: + method, _, params = deserialize_method_call(cmd_str) + run_method(self, method, args=(), kwargs=params) + # recovery succeeded; restart the busy loop + continue + except queue.Empty: + # No handling instruction received within predefined + # timeout period. + logger.error( + "[BusyLoopWrapper] Fault tolerance instruction not received" + " within timeout. Proceeding with default exception " + "handling." + ) + except Exception as cmd_exc: + raise RuntimeError( + "Fault tolerance execution failed." + ) from cmd_exc + + # Fault tolerance not enabled OR no instruction received + # before timeout. Re-raise the original exception + # for upper level handling. + raise original_exc + + return run_with_fault_tolerance + + class EngineCore: """Inner loop of vLLM's Engine.""" @@ -567,10 +892,6 @@ def __init__( ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() - executor_fail_callback = lambda: self.input_queue.put_nowait( - (EngineCoreRequestType.EXECUTOR_FAILED, b"") - ) - self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False @@ -603,6 +924,50 @@ def __init__( self._init_data_parallel(vllm_config) + # Initialize fault tolerance settings. + ft_config = vllm_config.fault_tolerance_config + self.enable_fault_tolerance = ft_config.enable_fault_tolerance + if self.enable_fault_tolerance: + # Track whether the busy loop is currently active. 
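+            # Cleared by EngineCoreGuard.pause() or when the busy loop hits a
+            # fault; set again each time the wrapped busy loop (re)starts.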
+ self.busy_loop_active = threading.Event() + self.fault_signal_q: queue.Queue[Exception] = queue.Queue() + self.cmd_q: queue.Queue[str | None] = queue.Queue() + self.engine_recovery_timeout = ft_config.engine_recovery_timeout + engine_core_guard_ids = addresses.engine_core_guard_identities + assert engine_core_guard_ids is not None + assert addresses.fault_report_addr is not None + assert addresses.client_cmd_addr is not None + assert addresses.engine_core_cmd_addrs is not None + engine_core_cmd_addr = addresses.engine_core_cmd_addrs[ + vllm_config.parallel_config.data_parallel_rank + ] + self.engine_core_guard = EngineCoreGuard( + engine_index=self.engine_index, + fault_signal_q=self.fault_signal_q, + cmd_q=self.cmd_q, + busy_loop_active=self.busy_loop_active, + engine_input_q=self.input_queue, + fault_report_addr=addresses.fault_report_addr, + client_cmd_addr=addresses.client_cmd_addr, + worker_cmd_addr=engine_core_cmd_addr, + guard_identity=engine_core_guard_ids[self.engine_index], + tp_size=vllm_config.parallel_config.tensor_parallel_size, + pp_size=vllm_config.parallel_config.pipeline_parallel_size, + dp_size=vllm_config.parallel_config.data_parallel_size, + ) + self.engine_core_guard.start() + vllm_config.fault_tolerance_config.engine_core_cmd_addr = ( + engine_core_cmd_addr + ) + # Do not shut down the engine immediately upon failure. + executor_fail_callback = lambda: self.fault_signal_q.put( + RuntimeError(f"Executor on EngineCore {self.engine_index} failed.") + ) + else: + executor_fail_callback = lambda: self.input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b"") + ) + super().__init__( vllm_config, executor_class, log_stats, executor_fail_callback ) @@ -851,16 +1216,23 @@ def signal_handler(signum, frame): def _init_data_parallel(self, vllm_config: VllmConfig): pass + @busy_loop_wrapper def run_busy_loop(self): """Core busy loop of the EngineCore.""" # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. + self._check_busy_loop_active() self._process_input_queue() # 2) Step the engine core and return the outputs. + self._check_busy_loop_active() self._process_engine_step() + def _check_busy_loop_active(self): + if self.enable_fault_tolerance and not self.busy_loop_active.is_set(): + raise EngineLoopPausedError("Engine busy loop is paused.") + def _process_input_queue(self): """Exits when an engine step needs to be performed.""" @@ -907,6 +1279,8 @@ def _handle_client_request( self.add_request(req, request_wave) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) + elif request_type == EngineCoreRequestType.PAUSE: + self._check_busy_loop_active() elif request_type == EngineCoreRequestType.UTILITY: client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) @@ -1099,6 +1473,26 @@ def process_output_sockets( # Limit the number of buffers to reuse. 
reuse_buffers.append(buffer) + def engine_finish_requests(self): + assert isinstance(self.scheduler, V1Scheduler) + engine_finish_outputs = EngineCoreOutputs() + engine_finish_outputs.engine_index = self.engine_index + for request_id in list(self.scheduler.requests.keys()): + self.scheduler.finish_requests(request_id, RequestStatus.FINISHED_ABORTED) + engine_finish_outputs.outputs.append( + EngineCoreOutput( + request_id=request_id, + finish_reason=FinishReason.ABORT, + new_token_ids=[], + ) + ) + self.output_queue.put((0, engine_finish_outputs)) + + def shutdown(self): + super().shutdown() + if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: + self.engine_core_guard.shutdown() + class DPEngineCoreProc(EngineCoreProc): """ZMQ-wrapper for running EngineCore in background process @@ -1153,7 +1547,10 @@ def _init_data_parallel(self, vllm_config: VllmConfig): ) self.dp_rank = dp_rank - self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group( + vllm_config.fault_tolerance_config.gloo_comm_timeout, + vllm_config.fault_tolerance_config.enable_fault_tolerance, + ) def shutdown(self): super().shutdown() @@ -1201,15 +1598,18 @@ def _maybe_publish_request_counts(self): ) self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats))) + @busy_loop_wrapper def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. + self._check_busy_loop_active() self._process_input_queue() # 2) Step the engine core. + self._check_busy_loop_active() executed = self._process_engine_step() self._maybe_publish_request_counts() @@ -1221,9 +1621,11 @@ def run_busy_loop(self): # We are in a running state and so must execute a dummy pass # if the model didn't execute any ready requests. + self._check_busy_loop_active() self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. 
+ self._check_busy_loop_active() self.engines_running = self._has_global_unfinished_reqs( local_unfinished_reqs ) @@ -1256,6 +1658,13 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + def reinit_dp_group_on_fault_tolerance(self): + stateless_destroy_torch_distributed_process_group(self.dp_group) + self.dp_group = self.vllm_config.parallel_config.stateless_init_dp_group( + self.vllm_config.fault_tolerance_config.gloo_comm_timeout, + self.vllm_config.fault_tolerance_config.enable_fault_tolerance, + ) + def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest ) -> None: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9d..626ff97354e8 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import contextlib +import json import multiprocessing import queue import sys +import time import uuid import weakref from abc import ABC, abstractmethod @@ -16,19 +18,23 @@ from typing import Any, TypeAlias, TypeVar import msgspec.msgpack +import regex as re import zmq import zmq.asyncio +from ray.util.state import get_actor from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask from vllm.utils.async_utils import in_loop +from vllm.utils.collection_utils import ThreadSafeDict from vllm.utils.network_utils import ( close_sockets, get_open_port, get_open_zmq_inproc_path, make_zmq_socket, + recv_router_dealer_message, ) from vllm.v1.engine import ( EngineCoreOutputs, @@ -40,14 +46,15 @@ ) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc -from vllm.v1.engine.exceptions import EngineDeadError +from vllm.v1.engine.exceptions import EngineDeadError, FaultInfo from vllm.v1.engine.utils import ( CoreEngineActorManager, CoreEngineProcManager, + FaultHandler, launch_core_engines, ) from vllm.v1.executor import Executor -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr, run_method logger = init_logger(__name__) @@ -249,6 +256,12 @@ async def collective_rpc_async( ) -> list[_R]: raise NotImplementedError + async def handle_fault(self, instruction: str, timeout: int) -> bool: + raise NotImplementedError + + async def fault_reporter(self): + raise NotImplementedError + class InprocClient(EngineCoreClient): """ @@ -332,6 +345,127 @@ def dp_engines_running(self) -> bool: return False +class ClientGuard: + def __init__( + self, + fault_receiver_addr: str, + cmd_addr: str, + engine_registry: list[bytes], + engine_exception_q: asyncio.Queue[FaultInfo], + engine_exception_q_lock: asyncio.Lock, + fault_pub_addr: str, + engine_status_dict: ThreadSafeDict[int, str], + ): + self.engine_registry = engine_registry + self.zmq_ctx = zmq.Context() + self.fault_receiver_socket = make_zmq_socket( + ctx=self.zmq_ctx, + path=fault_receiver_addr, + socket_type=zmq.ROUTER, + bind=True, + ) + self.cmd_socket = make_zmq_socket( + ctx=self.zmq_ctx, path=cmd_addr, socket_type=zmq.ROUTER, bind=True + ) + + self.fault_pub_socket = make_zmq_socket( + ctx=self.zmq_ctx, path=fault_pub_addr, socket_type=zmq.PUB, bind=True + ) + + self.engine_exception_q: asyncio.Queue[FaultInfo] = engine_exception_q + + 
self.engine_exception_q_lock = engine_exception_q_lock + + self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict + + self.fault_handler = FaultHandler( + self.cmd_socket, + self.engine_registry, + self.engine_exception_q, + self.engine_exception_q_lock, + self.engine_status_dict, + ) + + self.logger = self._make_client_guard_logger() + + self.client_guard_dead = False + Thread( + target=self.fault_receiver, daemon=True, name="EngineCoreFaultReceiver" + ).start() + + def _make_client_guard_logger(self): + prefix = "[client_guard] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: + """ + Executes fault tolerance measures based on the fault tolerance instructions + received from the api_server. + + This method processes the fault tolerance commands/instructions passed by the + api_server, then implements corresponding fault tolerance strategies or actions + to handle system anomalies, ensuring stable operation or graceful degradation + of the relevant components. + """ + return await run_method( + self.fault_handler, + "handle_fault", + args=(instruction, timeout), + kwargs=kwargs, + ) + + def fault_receiver(self): + """ + Continuously listens for exception/error information sent from the engine_core. + + This method maintains a persistent listening state to capture and process + fault-related data, exceptions, or error notifications emitted by the + engine_core component. It is designed to run continuously to ensure no critical + error information from the engine core is missed. + """ + while True: + _, sender_identity, message = recv_router_dealer_message( + self.fault_receiver_socket + ) + if self.client_guard_dead: + self.logger("client guard dead, stop receiving fault") + break + assert message is not None, ( + "message should not be None at fault tolerance scenario" + ) + + fault_info = FaultInfo.from_json(message) + self.engine_exception_q.put_nowait(fault_info) + engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy" + self.engine_status_dict[int(fault_info.engine_id)] = engine_status + self.fault_pub_socket.send_string( + f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}" + ) + # TODO Asynchronous issuance of pause commands and design of engine + # core status + # Pause healthy engines on fault. + # Pause will be invoked again during fault-tolerance handling, + # so it's unnecessary to track whether all engines are currently + # paused. + self.fault_handler.submit_fault("pause", 5, soft_pause=False) + + def shutdown_guard(self): + self.client_guard_dead = True + self.fault_receiver_socket.close() + self.cmd_socket.close() + self.fault_pub_socket.close() + self.zmq_ctx.term() + self.logger("ClientGuard is closed.", level="info") + + @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -350,6 +484,7 @@ class BackgroundResources: output_queue_task: asyncio.Task | None = None stats_update_task: asyncio.Task | None = None shutdown_path: str | None = None + client_guard: ClientGuard | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. 
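For context, a minimal sketch (not part of this patch) of an external watchdog consuming the new fault PUB socket; ClientGuard publishes a `vllm_fault|<json status dict>` string on every fault it receives, and the address here is an assumption:

    import json
    import zmq

    def watch_faults(fault_pub_addr: str) -> None:
        ctx = zmq.Context()
        sub = ctx.socket(zmq.SUB)
        # e.g. tcp://<host>:<external_fault_notify_port>
        sub.connect(fault_pub_addr)
        sub.setsockopt_string(zmq.SUBSCRIBE, "vllm_fault")
        while True:
            # e.g. 'vllm_fault|{"0": "Unhealthy", "1": "Healthy"}'
            _, payload = sub.recv_string().split("|", 1)
            print("engine status:", json.loads(payload))
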
@@ -363,6 +498,8 @@ def __call__(self):
             self.engine_manager.close()
         if self.coordinator is not None:
             self.coordinator.close()
+        if self.client_guard is not None:
+            self.client_guard.shutdown_guard()
 
         if isinstance(self.output_socket, zmq.asyncio.Socket):
             # Async case.
@@ -454,6 +591,7 @@ def __init__(
         self.resources = BackgroundResources(ctx=sync_ctx)
         self._finalizer = weakref.finalize(self, self.resources)
         success = False
+
         try:
             # State used for data parallel.
             self.engines_running = False
@@ -536,8 +674,41 @@ def __init__(
         self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
 
         # Start monitoring engine core processes for unexpected failures
-        self.start_engine_core_monitor()
+        if self.vllm_config.parallel_config.data_parallel_backend == "ray":
+            self.start_engine_core_actor_monitor()
+        else:
+            self.start_engine_core_monitor()
 
+        if vllm_config.fault_tolerance_config.enable_fault_tolerance:
+            self.engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue()
+            assert addresses.fault_report_addr is not None, (
+                "addresses.fault_report_addr should not be None in the fault"
+                " tolerance scenario"
+            )
+            assert addresses.client_cmd_addr is not None, (
+                "addresses.client_cmd_addr should not be None in the fault"
+                " tolerance scenario"
+            )
+            self.engine_registry = addresses.engine_core_guard_identities
+            self.engine_exception_q_lock = asyncio.Lock()
+            assert self.engine_registry is not None
+            assert addresses.fault_pub_socket_addr is not None, (
+                "addresses.fault_pub_socket_addr should not be None in the"
+                " fault tolerance scenario"
+            )
+            self.engine_status_dict: ThreadSafeDict[int, str] = ThreadSafeDict()
+            for engine_id in range(vllm_config.parallel_config.data_parallel_size):
+                self.engine_status_dict[engine_id] = "Healthy"
+            self.client_guard = ClientGuard(
+                addresses.fault_report_addr,
+                addresses.client_cmd_addr,
+                self.engine_registry,
+                self.engine_exception_q,
+                self.engine_exception_q_lock,
+                addresses.fault_pub_socket_addr,
+                self.engine_status_dict,
+            )
+            self.resources.client_guard = self.client_guard
             success = True
         finally:
             if not success:
@@ -568,6 +739,51 @@ def free_pending_messages(self):
     def dp_engines_running(self) -> bool:
         return self.engines_running
 
+    def start_engine_core_actor_monitor(self):
+        engine_manager = self.resources.engine_manager
+        if (
+            not isinstance(engine_manager, CoreEngineActorManager)
+            or not self.vllm_config.fault_tolerance_config.enable_fault_tolerance
+        ):
+            return
+
+        def monitor_actors():
+            all_actors = (
+                engine_manager.local_engine_actors + engine_manager.remote_engine_actors
+            )
+            if not all_actors:
+                return
+            while True:
+                for actor in all_actors:
+                    actor_id = actor._actor_id.hex()
+                    if actor in engine_manager.local_engine_actors:
+                        actor_index = engine_manager.local_engine_actors.index(actor)
+                    elif actor in engine_manager.remote_engine_actors:
+                        actor_index = engine_manager.remote_engine_actors.index(
+                            actor
+                        ) + len(engine_manager.local_engine_actors)
+                    else:
+                        logger.error("Unknown actor (ID: %s)", actor_id)
+                        continue
+
+                    actor_info = get_actor(actor_id)
+                    if actor_info.state == "DEAD":
+                        fault_info = FaultInfo(
+                            type="engine_actor dead",
+                            message=str(actor_info.death_cause),
+                            engine_id=str(actor_index),
+                            additional_info=None,
+                        )
+                        all_actors.remove(actor)
+                        engine_manager.engine_down_socket.send_multipart(
+                            [b"", fault_info.serialize().encode("utf-8")]
+                        )
+
+                # Check actor states once every 3 seconds.
+                time.sleep(3)
+
+        Thread(target=monitor_actors, daemon=True, name="MPClientEngineMonitor").start()
+
     def start_engine_core_monitor(self):
        """Start a monitor thread for engine core processes."""
        engine_manager = self.resources.engine_manager
@@ -587,19 +803,54 @@
         # callback to inform the engine.
         def monitor_engine_cores():
             sentinels = [proc.sentinel for proc in engine_processes]
-            died = multiprocessing.connection.wait(sentinels)
+
             _self = self_ref()
-            if not _self or _self.resources.engine_dead:
-                return
-            _self.resources.engine_dead = True
-            proc_name = next(
-                proc.name for proc in engine_processes if proc.sentinel == died[0]
-            )
-            logger.error(
-                "Engine core proc %s died unexpectedly, shutting down client.",
-                proc_name,
-            )
-            _self.shutdown()
+            if self.vllm_config.fault_tolerance_config.enable_fault_tolerance:
+                while sentinels:
+                    died = multiprocessing.connection.wait(sentinels)
+                    for sentinel in died:
+                        died_proc = next(
+                            proc
+                            for proc in engine_processes
+                            if proc.sentinel == sentinel
+                        )
+
+                        match = re.match(r"EngineCore_DP(\d+)", died_proc.name)
+                        engine_rank = match.group(1)
+
+                        fault_info = FaultInfo(
+                            type="engine_core dead",
+                            message=f"Engine core proc {died_proc.name} "
+                            f"(PID: {died_proc.pid}) died unexpectedly.",
+                            engine_id=engine_rank,
+                            additional_info=None,
+                        )
+
+                        engine_manager.engine_down_socket.send_multipart(
+                            [b"", fault_info.serialize().encode("utf-8")]
+                        )
+
+                        sentinels.remove(sentinel)
+
+                        logger.error(
+                            "Engine core proc %s died unexpectedly",
+                            died_proc.name,
+                        )
+            else:
+                died = multiprocessing.connection.wait(sentinels)
+                _self = self_ref()
+                if not _self or _self.resources.engine_dead:
+                    return
+                proc_name = next(
+                    proc.name for proc in engine_processes if proc.sentinel == died[0]
+                )
+                logger.error(
+                    "Engine core proc %s died unexpectedly, shutting down client.",
+                    proc_name,
+                )
+                if _self and _self.resources:
+                    _self.resources.engine_dead = True
+                    _self.shutdown()
             # Note: For MPClient, we don't have a failure callback mechanism
             # like MultiprocExecutor, but we set engine_dead flag which will
             # cause subsequent operations to raise EngineDeadError
@@ -608,6 +859,13 @@ def monitor_engine_cores():
             target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
         ).start()
 
+    async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool:
+        """Handle a fault on this instance according to the given instruction."""
+        return await self.client_guard.handle_fault(instruction, timeout, **kwargs)
+
+    async def fault_reporter(self):
+        return self.engine_status_dict.to_dict()
+
 
 def _process_utility_output(
     output: UtilityOutput, utility_results: dict[int, AnyFuture]
diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py
index d9f79a019e2d..4f24c7fa9080 100644
--- a/vllm/v1/engine/exceptions.py
+++ b/vllm/v1/engine/exceptions.py
@@ -1,5 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
+import time
+from dataclasses import dataclass
+
+
 class EngineGenerateError(Exception):
     """Raised when a AsyncLLM.generate() fails. Recoverable."""
 
@@ -16,3 +21,69 @@ def __init__(self, *args, suppress_context: bool = False, **kwargs):
         # Make stack trace clearer when using with LLMEngine by
         # silencing irrelevant ZMQError.
         self.__suppress_context__ = suppress_context
+
+
+class EngineLoopPausedError(Exception):
+    """
+    Raised when the EngineCore loop is temporarily paused on purpose,
+    e.g., to handle fault-tolerance.
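+    EngineCoreGuard treats this as an expected pause acknowledgment rather
+    than a fault to report.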
+    """
+
+    pass
+
+
+@dataclass
+class FaultInfo:
+    type: str
+    message: str
+    engine_id: str
+    timestamp: str | None = None
+    additional_info: dict | None = None
+
+    def __post_init__(self):
+        # If no timestamp is specified, the current local time is used by
+        # default.
+        local_time = time.localtime(time.time())
+        if self.timestamp is None:
+            self.timestamp = time.strftime("%H:%M:%S", local_time)
+
+    @classmethod
+    def from_exception(
+        cls,
+        exception: Exception,
+        engine_id: str | int,
+        additional_info: dict | None = None,
+    ) -> "FaultInfo":
+        """Create FaultInfo from an exception."""
+        return cls(
+            type=type(exception).__name__,
+            message=str(exception),
+            engine_id=str(engine_id),
+            additional_info=additional_info or {},
+        )
+
+    def to_dict(self) -> dict:
+        """Convert to dictionary for serialization."""
+        return {
+            "type": self.type,
+            "message": self.message,
+            "timestamp": self.timestamp,
+            "engine_id": self.engine_id,
+            "additional_info": self.additional_info,
+        }
+
+    def serialize(self) -> str:
+        """Serialize to JSON string."""
+        return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "FaultInfo":
+        """Create FaultInfo from JSON string."""
+        data = json.loads(json_str)
+        return cls(
+            type=data["type"],
+            message=data["message"],
+            timestamp=data["timestamp"],
+            engine_id=data["engine_id"],
+            additional_info=data["additional_info"],
+        )
diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py
index d65cad7af03d..aa9e97edd58d 100644
--- a/vllm/v1/engine/utils.py
+++ b/vllm/v1/engine/utils.py
@@ -1,8 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import asyncio
 import contextlib
+import json
+import multiprocessing
 import os
+import time
+import uuid
 import weakref
 from collections.abc import Callable, Iterator
 from dataclasses import dataclass
@@ -13,6 +17,7 @@
 from unittest.mock import patch
 
 import msgspec
+import regex as re
 import zmq
 
 from vllm import envs
@@ -20,10 +25,18 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
-from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
+from vllm.utils.collection_utils import ThreadSafeDict
+from vllm.utils.network_utils import (
+    get_open_zmq_ipc_path,
+    make_zmq_socket,
+    recv_router_dealer_message,
+    zmq_socket_ctx,
+)
 from vllm.utils.system_utils import get_mp_context
 from vllm.v1.engine.coordinator import DPCoordinator
+from vllm.v1.engine.exceptions import FaultInfo
 from vllm.v1.executor import Executor
+from vllm.v1.serial_utils import serialize_method_call
 from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
 
 if TYPE_CHECKING:
@@ -56,6 +69,8 @@ class EngineZmqAddresses:
     inputs: list[str]
     # ZMQ output socket addresses for each front-end client (responses)
     outputs: list[str]
+    # ZMQ command socket addresses of each EngineCoreGuard
+    engine_core_cmd_addrs: list[str] | None = None
     # ZMQ input socket address of DP coordinator if applicable
     coordinator_input: str | None = None
     # ZMQ output socket address of DP coordinator if applicable
@@ -64,6 +79,14 @@
     # Not used by engine, just relayed to front-end in handshake response.
     # Only required for external DP LB case.
     frontend_stats_publish_address: str | None = None
+    # ZMQ fault report socket address of client guard
+    fault_report_addr: str | None = None
+    # ZMQ client_cmd socket address of client guard
+    client_cmd_addr: str | None = None
+    # identities of engine_core_guard
+    engine_core_guard_identities: list[bytes] | None = None
+    # ZMQ fault_pub_socket address of client guard
+    fault_pub_socket_addr: str | None = None
 
 
 @dataclass
@@ -105,7 +128,23 @@ def __init__(
             "executor_class": executor_class,
             "log_stats": log_stats,
         }
-
+        if vllm_config.fault_tolerance_config.enable_fault_tolerance:
+            zmq_ctx = zmq.Context()
+            identity = generate_identity_group(
+                "core_engine_proc_manager", "client_guard", "report", 1
+            )[0]
+            zmq_addr = get_engine_client_zmq_addr(
+                local_only=False,
+                host=vllm_config.parallel_config.data_parallel_master_ip,
+                port=vllm_config.fault_tolerance_config.internal_fault_report_port,
+            )
+            self.engine_down_socket = make_zmq_socket(
+                ctx=zmq_ctx,
+                path=zmq_addr,
+                socket_type=zmq.DEALER,
+                bind=False,
+                identity=identity,
+            )
         if client_handshake_address:
             common_kwargs["client_handshake_address"] = client_handshake_address
@@ -131,6 +170,8 @@ def __init__(
 
         self._finalizer = weakref.finalize(self, shutdown, self.processes)
 
+        self.vllm_config = vllm_config
+
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
@@ -154,12 +195,53 @@ def __init__(
         if self.finished_procs():
             self.close()
 
+    def _report_engine_dead(self, dead_message):
+        """Send engine dead message to ClientGuard"""
+        try:
+            self.engine_down_socket.send_multipart(
+                [
+                    b"",  # Empty frame separator
+                    dead_message.encode("utf-8"),
+                ]
+            )
+            logger.info("Sent message to ClientGuard: %s", dead_message)
+        except Exception as e:
+            logger.error("Failed to send message: %s", e)
+
     def close(self):
         """Shutdown all procs."""
         self._finalizer()
 
+    def start_engine_core_monitor(self):
+        sentinels = [proc.sentinel for proc in self.processes]
+        while self.processes:
+            died = multiprocessing.connection.wait(sentinels)
+            for sentinel in died:
+                died_proc = next(
+                    proc for proc in self.processes if proc.sentinel == sentinel
+                )
+
+                match = re.match(r"EngineCore_DP(\d+)", died_proc.name)
+                engine_rank = match.group(1)
+
+                fault_info = FaultInfo(
+                    type="engine_core dead",
+                    message=f"Engine core proc {died_proc.name} "
+                    f"(PID: {died_proc.pid}) died unexpectedly.",
+                    engine_id=engine_rank,
+                    additional_info=None,
+                )
+                self.engine_down_socket.send_multipart(
+                    [b"", fault_info.serialize().encode("utf-8")]
+                )
+                if isinstance(sentinel, int) and sentinel in sentinels:
+                    sentinels.remove(sentinel)
+                logger.error(
+                    "Engine core proc %s died unexpectedly",
+                    died_proc,
+                )
+
     def join_first(self):
-        """Wait for any process to exit."""
         connection.wait(proc.sentinel for proc in self.processes)
 
     def sentinels(self) -> list:
@@ -267,6 +349,24 @@ def __init__(
         local_engine_count = vllm_config.parallel_config.data_parallel_size_local
         world_size = vllm_config.parallel_config.world_size
 
+        if vllm_config.fault_tolerance_config.enable_fault_tolerance:
+            zmq_ctx = zmq.Context()
+            zmq_addr = get_engine_client_zmq_addr(
+                local_only=False,
+                host=vllm_config.parallel_config.data_parallel_master_ip,
+                port=vllm_config.fault_tolerance_config.internal_fault_report_port,
+            )
+            identity = generate_identity_group(
+                "core_engine_actor_manager", "client_guard", "report", 1
+            )[0]
+            self.engine_down_socket = make_zmq_socket(
+                ctx=zmq_ctx,
+                path=zmq_addr,
+                socket_type=zmq.DEALER,
+                bind=False,
+                identity=identity,
+            )
+
         if ray.is_initialized():
             logger.info("Ray is already initialized. Skipping Ray initialization.")
         else:
@@ -332,6 +432,7 @@ def __init__(
                     local_dp_rank=local_index,
                 )
             )
+
             if local_client:
                 self.local_engine_actors.append(actor)
             else:
@@ -808,6 +909,30 @@ def launch_core_engines(
         ],
     )
 
+    if vllm_config.fault_tolerance_config.enable_fault_tolerance:
+        addresses.engine_core_cmd_addrs = [
+            get_engine_client_zmq_addr(client_local_only, host) for _ in range(dp_size)
+        ]
+        addresses.fault_report_addr = get_engine_client_zmq_addr(
+            local_only=False,
+            host=vllm_config.parallel_config.data_parallel_master_ip,
+            port=vllm_config.fault_tolerance_config.internal_fault_report_port,
+        )
+        addresses.client_cmd_addr = get_engine_client_zmq_addr(
+            local_only=client_local_only, host=host
+        )
+        addresses.engine_core_guard_identities = generate_identity_group(
+            peer1="client",
+            peer2="engine_core_guard",
+            use="report and cmd",
+            n=dp_size,
+        )
+        addresses.fault_pub_socket_addr = get_engine_client_zmq_addr(
+            local_only=False,
+            host="0.0.0.0",
+            port=vllm_config.fault_tolerance_config.external_fault_notify_port,
+        )
+
     # Run the DP Coordinator process with rank 0 when in
     # online DP mode.
     run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0
@@ -1070,3 +1195,308 @@ def wait_for_engine_startup(
         "local" if local else "remote",
         eng_index,
     )
+
+
+def generate_unique_uuids(n: int) -> set[uuid.UUID]:
+    """Generate a set of unique UUID v4 objects.
+
+    Generates a specified number of unique UUID (version 4) objects.
+    UUID v4 uses cryptographically strong random numbers, ensuring
+    an extremely low probability of collisions.
+
+    Args:
+        n: The number of unique UUIDs to generate
+
+    Returns:
+        A set containing 'n' unique UUID objects
+    """
+    uuids: set[uuid.UUID] = set()
+    while len(uuids) < n:
+        # Generate a random UUID (version 4) and add to the set
+        uuids.add(uuid.uuid4())
+    return uuids
+
+
+def generate_identity_group(peer1, peer2, use, n):
+    """
+    Generate n unique identities for ZMQ ROUTER nodes.
+
+    Format: "{peer1}_{peer2}_{use}_{uuid4}"
+    Return: list of identities, each encoded as bytes
+    """
+    identities = []
+    uuids = generate_unique_uuids(n)
+    for uid in uuids:
+        identity_str = f"{peer1}_{peer2}_{use}_{uid}".encode()
+        identities.append(identity_str)
+    return identities
+
+
+async def get_queue_snapshot(queue: asyncio.Queue, queue_lock: asyncio.Lock) -> list:
+    """Thread-safe snapshot of the exception queue."""
+    async with queue_lock:
+        items = []
+        # drain all items first
+        while not queue.empty():
+            item = queue.get_nowait()
+            items.append(item)
+        # then put them back
+        for item in items:
+            queue.put_nowait(item)
+        return items
+
+
+def broadcast_instruction(
+    cmd_socket,
+    target_identities: set[bytes] | list[bytes],
+    method_name: str,
+    method_uuid: str | None = None,
+    **kwargs,
+) -> str:
+    """
+    Broadcast an instruction message to multiple remote endpoints.
+    It serializes the specified method_name along with its parameters and
+    dispatches it to all target identities via the provided ZeroMQ socket.
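+    Returns the method_uuid of the broadcast, which callers pass to
+    wait_for_instruction_result to match replies to this instruction.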
+    """
+    if method_uuid is None:
+        method_uuid = str(uuid.uuid4())
+
+    for identity in target_identities:
+        serialized_instruction = serialize_method_call(
+            method_name, method_uuid, **kwargs
+        )
+        cmd_socket.send_multipart(
+            [identity, b"", serialized_instruction.encode("utf-8")]
+        )
+
+    return method_uuid
+
+
+def wait_for_instruction_result(
+    cmd_socket: zmq.Socket,
+    target_identities: set[bytes] | list[bytes],
+    method_name: str,
+    timeout: int,
+    method_uuid: str,
+) -> dict[bytes, dict]:
+    """
+    Wait for acknowledgment or result messages from multiple endpoints.
+    This function listens for responses corresponding to a previously broadcasted
+    instruction, identified by the given `method_uuid`.
+
+    Args:
+        cmd_socket: The socket used to receive responses.
+        target_identities: Identities that are expected to respond.
+        method_name: The name of the method (used for logging).
+        timeout: The maximum wait time (in seconds).
+        method_uuid: The unique identifier associated with the method call.
+
+    Notes:
+        - This function does not raise exceptions for timeouts or parsing errors.
+          Instead, it logs the issue and returns whatever responses have been
+          collected.
+    """
+    start = time.monotonic()
+    responses: dict[bytes, dict] = {}
+
+    target_identities = set(target_identities)
+
+    while target_identities:
+        remaining = timeout - (time.monotonic() - start)
+        if remaining <= 0:
+            logger.debug(
+                'Timeout while waiting for responses of command "%s" '
+                "from identities: %s",
+                method_name,
+                target_identities,
+            )
+            # Return partial results collected so far
+            return responses
+
+        try:
+            has_msg, identity, response = recv_router_dealer_message(
+                cmd_socket,
+                use_poller=True,
+                poll_timeout=int(remaining * 1000),
+            )
+
+            # Skip if no message was received during this polling period
+            if not has_msg:
+                continue
+
+            assert identity is not None
+            assert response is not None
+            response_dict = json.loads(response)
+            recv_uuid = response_dict.get("method_uuid")
+
+            # Ignore outdated or unrelated messages
+            if recv_uuid != method_uuid:
+                logger.debug(
+                    "Discarding outdated response: expected method_uuid=%s, got %s",
+                    method_uuid,
+                    recv_uuid,
+                )
+                continue
+
+            # Record this engine's response
+            responses[identity] = response_dict
+            target_identities.discard(identity)
+
+        except Exception as e:
+            logger.error("Error while processing engine response: %s", e)
+            # Return partial results even on exception to avoid data loss
+            return responses
+
+    return responses
+
+
+class FaultHandler:
+    def __init__(
+        self,
+        cmd_socket: zmq.Socket,
+        client_cmd_registry: list[bytes],
+        engine_exception_q: asyncio.Queue[FaultInfo],
+        engine_exception_q_lock: asyncio.Lock,
+        engine_status_dict: ThreadSafeDict[int, str],
+    ) -> None:
+        self.cmd_socket = cmd_socket
+        self.engine_exception_q = engine_exception_q
+        self.engine_exception_q_lock = engine_exception_q_lock
+        self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict
+        self.engine_identity_to_index: dict[bytes, int] = {
+            identity: i for i, identity in enumerate(client_cmd_registry)
+        }
+        # ensure handle_fault is executed sequentially
+        self._task_queue: asyncio.Queue = asyncio.Queue()
+        self._loop = asyncio.get_event_loop()
+        self._dispatcher_task = self._loop.create_task(self._dispatcher())
+
+        self.logger = self._make_fault_handler_logger()
+
+    def _make_fault_handler_logger(self):
+        prefix = "[FaultHandler] "
+
+        def log(msg, *args, level="info", **kwargs):
+            """
+            level: "info", "warning", "error", "debug"
+            msg: log message
+            """
+            getattr(logger, level)(prefix + msg, *args, **kwargs)
+
+        return log
+
+    async def _dispatcher(self):
+        while True:
+            # each element in the queue contains:
+            # (instruction, timeout, kwargs, future)
+            instruction, timeout, kwargs, fut = await self._task_queue.get()
+            try:
+                result = await self._handle_fault_internal(
+                    instruction, timeout, **kwargs
+                )
+                if fut:
+                    fut.set_result(result)
+            except Exception as e:
+                if fut:
+                    fut.set_exception(e)
+
+    async def _handle_fault_internal(
+        self, instruction: str, timeout: int, **kwargs
+    ) -> bool:
+        if instruction == "retry" and "Dead" in self.engine_status_dict.values():
+            self.logger(
+                "engine_core died unexpectedly; retry is impossible, "
+                "shutdown will be performed",
+                level="info",
+            )
+            return False
+
+        if instruction == "pause":
+            logger.warning(
+                "Pause operation is best-effort only. Due to the complexity of "
+                "collective communications (e.g., timing dependencies and "
+                "synchronization barriers), pausing may not always succeed. If "
+                "the process remains unresponsive or collective operations "
+                "cannot be interrupted, consider shutting down and restarting "
+                "the instance."
+            )
+
+            dead_engine_indices = {
+                index
+                for index, status in self.engine_status_dict.items()
+                if status == "Dead"
+            }
+
+            target_engines = {
+                identity
+                for identity, index in self.engine_identity_to_index.items()
+                if index not in dead_engine_indices
+            }
+        else:
+            target_engines = set(self.engine_identity_to_index.keys())
+
+        if timeout is not None:
+            kwargs["timeout"] = timeout
+
+        method_uuid = broadcast_instruction(
+            self.cmd_socket,
+            target_engines,
+            instruction,
+            **kwargs,
+        )
+
+        engine_responses = wait_for_instruction_result(
+            self.cmd_socket, target_engines, instruction, timeout, method_uuid
+        )
+
+        # check the execution results
+        all_success = True
+        for engine_id in target_engines:
+            engine_index = self.engine_identity_to_index.get(engine_id, "?")
+            response = engine_responses.get(engine_id)
+
+            if response is None:
+                self.logger(
+                    "EngineCoreGuard[%s] did not respond"
+                    ' to command "%s" within timeout.',
+                    engine_index,
+                    instruction,
+                    level="info",
+                )
+                all_success = False
+            elif not response.get("success", False):
+                self.logger(
+                    'EngineCoreGuard[%s] failed to execute command "%s" (reason: %s)',
+                    engine_index,
+                    instruction,
+                    response.get("reason", "unknown"),
+                    level="error",
+                )
+                all_success = False
+
+        if instruction == "retry" and all_success:
+            for engine_index, _ in self.engine_status_dict.items():
+                self.engine_status_dict[engine_index] = "Healthy"
+            # todo: should we also clear the engine_exception_q here?
+        return all_success
+
+    async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool:
+        """
+        Async entry point for fault handling: enqueues the instruction and
+        awaits its result. This method **must be called from the event loop
+        thread** where this FaultHandler was created.
+        """
+        fut = self._loop.create_future()
+        await self._task_queue.put((instruction, timeout, kwargs, fut))
+        return await fut
+
+    def submit_fault(self, instruction: str, timeout: int, **kwargs) -> None:
+        """
+        Thread-safe fire-and-forget submission of a fault handling task.
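+        The result future is created internally and discarded, unlike
+        handle_fault, which awaits it.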
+        This method can be called from **any thread**.
+        """
+
+        def _enqueue():
+            fut = self._loop.create_future()
+            self._task_queue.put_nowait((instruction, timeout, kwargs, fut))
+
+        self._loop.call_soon_threadsafe(_enqueue)
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 0a6806390451..1eead471df47 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -3,7 +3,9 @@
 
 import dataclasses
 import importlib
+import json
 import pickle
+import uuid
 from collections.abc import Callable, Sequence
 from functools import partial
 from inspect import isclass
@@ -452,6 +454,54 @@ def ext_hook(self, code: int, data: memoryview) -> Any:
         raise NotImplementedError(f"Extension type code {code} is not supported")
 
 
+def deserialize_method_call(json_str: str) -> tuple[str, str, dict[str, Any]]:
+    """
+    Deserialize an encoded method call.
+
+    Args:
+        json_str (str): JSON string representing a serialized method call.
+
+    Returns:
+        tuple[str, str, dict[str, Any]]:
+            - method (str): The method name.
+            - method_uuid (str): The UUID identifying the method call.
+            - params (dict[str, Any]): Additional method parameters.
+    """
+    try:
+        payload = json.loads(json_str)
+        if not isinstance(payload, dict):
+            raise ValueError("Top-level JSON must be an object")
+    except Exception as e:
+        logger.error("Invalid JSON input: %s", e)
+        raise ValueError(f"Invalid JSON: {e}") from e
+
+    try:
+        method = payload.pop("method")
+        method_uuid = payload.pop("method_uuid")
+    except KeyError as e:
+        logger.error(
+            "Missing required field: %s (payload=%s)", e.args[0], json_str[:200]
+        )
+        raise ValueError(f"Missing required field: {e.args[0]}") from e
+
+    # Remaining fields are treated as parameters
+    params = payload
+
+    return method, method_uuid, params
+
+
+def serialize_method_call(
+    method: str, method_uuid: str | None = None, **params: Any
+) -> str:
+    """
+    Serialize a method invocation into a JSON string.
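+
+    Example: serialize_method_call("pause", soft_pause=False) produces
+    '{"method": "pause", "method_uuid": "<random uuid4>", "soft_pause": false}'.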
+ """ + if method_uuid is None: + method_uuid = str(uuid.uuid4()) + payload = {"method": method, "method_uuid": method_uuid, **params} + return json.dumps(payload) + + def run_method( obj: Any, method: str | bytes | Callable, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0102ca4739ad..ae2ee4cdf809 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,6 +3,7 @@ import gc import itertools +import threading import time from collections import defaultdict from collections.abc import Iterator @@ -98,6 +99,7 @@ split_attn_metadata, ) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.engine.exceptions import EngineLoopPausedError from vllm.v1.kv_cache_interface import ( AttentionSpec, ChunkedLocalAttentionSpec, @@ -574,6 +576,12 @@ def __init__( self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.pause_event = threading.Event() + + def _check_pause_event(self): + if self.pause_event.is_set(): + raise EngineLoopPausedError("Worker is paused.") + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2614,6 +2622,7 @@ def _model_forward( Returns: Model output tensor """ + self._check_pause_event() return self.model( input_ids=input_ids, positions=positions, @@ -3848,6 +3857,7 @@ def _dummy_run( ubatch_slices=ubatch_slices, ), ): + self._check_pause_event() outputs = self.model( input_ids=input_ids, positions=positions, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 315f01b68499..ad50474ecae3 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -3,18 +3,28 @@ """A GPU worker class.""" import gc +import json import os +import threading +import time +import traceback +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed from contextlib import AbstractContextManager, nullcontext +from datetime import timedelta +from functools import partial from types import NoneType from typing import TYPE_CHECKING, Any import torch import torch.distributed import torch.nn as nn +import zmq import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import ( + cleanup_dist_env_and_memory, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce, @@ -26,6 +36,8 @@ has_kv_transfer_group, ) from vllm.distributed.parallel_state import ( + GroupCoordinator, + get_all_model_groups, get_pp_group, get_tp_group, ) @@ -40,6 +52,7 @@ from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, memory_profiling +from vllm.utils.network_utils import make_zmq_socket, recv_router_dealer_message from vllm.v1.core.sched.output import GrammarOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -48,6 +61,7 @@ DraftTokenIds, ModelRunnerOutput, ) +from vllm.v1.serial_utils import deserialize_method_call, run_method from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -60,6 +74,183 @@ from vllm.v1.core.sched.output import SchedulerOutput +class WorkerGuard: + def __init__( + self, + vllm_config: VllmConfig, + 
pause_event: threading.Event, + init_distributed_env_callback: Callable, + clear_input_batch_callback: Callable, + device: torch.cuda.device, + ): + self.vllm_config = vllm_config + self.zmq_ctx = zmq.Context() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.tp_rank = get_tp_group().rank_in_group + self.pp_rank = get_pp_group().rank_in_group + self.init_distributed_env_callback = init_distributed_env_callback + self.clear_input_batch_callback = clear_input_batch_callback + self.device = device + identity = f"{self.tp_rank}_{self.pp_rank}".encode() + worker_cmd_addr = vllm_config.fault_tolerance_config.engine_core_cmd_addr + self.cmd_socket = make_zmq_socket( + ctx=self.zmq_ctx, + path=worker_cmd_addr, + socket_type=zmq.DEALER, + bind=False, + identity=identity, + ) + self.worker_guard_dead = False + self.pause_event = pause_event + self.communicator_aborted = False + self.logger = self._make_worker_logger() + threading.Thread( + target=self.run, daemon=True, name="WorkerGuardCmdReceiver" + ).start() + + def _make_worker_logger(self): + prefix = f"[WorkerGuard_dp{self.dp_rank}_tp{self.tp_rank}_pp{self.pp_rank}] " + + def log(msg, *args, level="info", **kwargs): + """ + level: "info", "warning", "error", "debug" + msg: log message + """ + getattr(logger, level)(prefix + msg, *args, **kwargs) + + return log + + def run(self): + """Run the message receiving loop and handle control commands""" + torch.cuda.set_device(self.device) + while True: + # Use blocking receive - will wait until a message arrives + has_msg, _, cmd_str = recv_router_dealer_message(self.cmd_socket) + if self.worker_guard_dead: + self.logger("Worker guard dead, exiting") + break + if has_msg: + assert cmd_str is not None + method, method_uuid, params = deserialize_method_call(cmd_str) + self.logger("Executing command: %s", method) + try: + success = run_method(self, method, args=(), kwargs=params) + except Exception as e: + self.logger( + "Error executing method %s: %s %s\n Call Stack:\n %s", + method, + type(e).__name__, + e, + "".join(traceback.format_tb(e.__traceback__)), + level="error", + ) + success = False + self._send_execution_result(success, method_uuid) + + def pause_by_signal(self): + self._set_device_communicator_status(False) + self.pause_event.set() + self.logger("Pause signal sent.") + return True + + def pause_by_abort_communicators(self, timeout=5): + """ + Abort all NCCL communicators and process groups in parallel using a thread pool. 
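+        Returns True only when every abort call finishes within `timeout`
+        seconds (or if the communicators were already aborted).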
+ """ + if self.communicator_aborted: + return True + self._set_device_communicator_status(False) + torch.cuda.set_device(self.device) + model_groups = get_all_model_groups() + futures = [] + start_time = time.time() + + def _abort_nccl_comm(group: GroupCoordinator): + if group.device_communicator is not None: + nccl_comm = group.device_communicator.pynccl_comm + nccl_comm.nccl_abort_comm() + + def _abort_process_group(group: GroupCoordinator): + device = torch.device("cuda") + backend = group.device_group._get_backend(device) + backend.abort() + + with ThreadPoolExecutor(max_workers=len(model_groups) * 2) as executor: + for group in model_groups: + futures.append(executor.submit(_abort_nccl_comm, group)) + futures.append(executor.submit(_abort_process_group, group)) + + done, not_done = [], [] + for future in as_completed(futures): + elapsed = time.time() - start_time + remaining = max(timeout - elapsed, 0) + if remaining == 0: + self.logger( + "Timeout while waiting for abort operations", level="warning" + ) + break + try: + # Wait at most 'remaining' seconds for this future + future.result(timeout=remaining) + done.append(future) + except TimeoutError: + not_done.append(future) + except Exception as e: + self.logger("Abort call raised exception: %s", e, level="warning") + not_done.append(future) + + # Add any futures that were not processed yet + not_done.extend([f for f in futures if f not in done and f not in not_done]) + if not_done: + self.logger( + "%d abort calls did not finish in total %s seconds", + len(not_done), + timeout, + level="warning", + ) + + self.communicator_aborted = True + success = len(not_done) == 0 + if success: + cleanup_dist_env_and_memory() + self.logger("Communicators are aborted.") + else: + self.logger("Communicators did not abort in time.", level="warning") + return success + + def _set_device_communicator_status(self, active: bool): + model_groups = get_all_model_groups() + for group in model_groups: + if group.device_communicator is not None: + nccl_comm = group.device_communicator.pynccl_comm + nccl_comm.available = active + nccl_comm.disabled = not active + + def restart_worker(self): + if self.communicator_aborted: + torch.cuda.set_device(self.device) + with set_current_vllm_config(self.vllm_config): + self.init_distributed_env_callback() + self.communicator_aborted = False + torch.cuda.synchronize() + self.clear_input_batch_callback() + self.pause_event.clear() + return True + + def _send_execution_result(self, success: bool, method_uuid: str): + msg = { + "success": success, + "method_uuid": method_uuid, + } + msg_bytes = json.dumps(msg).encode("utf-8") + self.cmd_socket.send_multipart([b"", msg_bytes]) + + def shutdown(self): + self.worker_guard_dead = True + self.cmd_socket.close() + self.zmq_ctx.term() + + class Worker(WorkerBase): def __init__( self, @@ -77,6 +268,7 @@ def __init__( is_driver_worker=is_driver_worker, ) + self.worker_guard: WorkerGuard | None = None if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils.import_utils import init_cached_hf_modules @@ -221,6 +413,7 @@ def init_device(self): # memory snapshot # This ensures NCCL buffers are allocated before we measure # available memory + start = time.time() init_worker_distributed_environment( self.vllm_config, self.rank, @@ -228,6 +421,10 @@ def init_device(self): self.local_rank, current_platform.dist_backend, ) + elapsed = time.time() - start + logger.info_once( + "init distributed environment took %.2f seconds", 
elapsed, scope="local" + ) # Set random seed. set_random_seed(self.model_config.seed) @@ -265,6 +462,30 @@ def init_device(self): # If usage stat is enabled, collect relevant info. report_usage_stats(self.vllm_config) + if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: + with set_current_vllm_config(self.vllm_config): + init_distributed_env_callback = partial( + init_worker_distributed_environment, + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + ) + + def clear_input_batch_callback(): + input_batch = self.model_runner.input_batch + cached_req_ids = input_batch.req_id_to_index.keys() + for req_id in list(cached_req_ids): + input_batch.remove_request(req_id) + + self.worker_guard = WorkerGuard( + self.vllm_config, + self.model_runner.pause_event, + init_distributed_env_callback, + clear_input_batch_callback, + self.device, + ) + # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: @@ -863,6 +1084,8 @@ def save_tensorized_model( def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() + if self.worker_guard is not None: + self.worker_guard.shutdown() def init_worker_distributed_environment( @@ -879,14 +1102,29 @@ def init_worker_distributed_environment( init_batch_invariance() set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + if vllm_config.fault_tolerance_config.enable_fault_tolerance: + timeout = timedelta( + seconds=vllm_config.fault_tolerance_config.gloo_comm_timeout + ) + else: + timeout = None + init_distributed_environment( - parallel_config.world_size, rank, distributed_init_method, local_rank, backend + parallel_config.world_size, + rank, + distributed_init_method, + local_rank, + backend, + vllm_config.fault_tolerance_config.enable_fault_tolerance, + timeout, ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, parallel_config.decode_context_parallel_size, + vllm_config.fault_tolerance_config.enable_fault_tolerance, + timeout, ) # Init ec connector here before KV caches caches init From 3b203d6cb84f3de932ebc39d040c1ea74be3d765 Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Tue, 11 Nov 2025 21:39:56 +0800 Subject: [PATCH 2/6] Fix DT and zmq socket closing issues, updated names per feedback and reinitialize dp_group with new port Signed-off-by: fangyuchu --- tests/v1/engine/test_client_guard.py | 6 +-- tests/v1/engine/test_engine_core_guard.py | 3 ++ vllm/config/parallel.py | 10 ++-- vllm/distributed/utils.py | 10 ++-- vllm/v1/core/sched/interface.py | 4 +- vllm/v1/engine/core.py | 52 ++++++++----------- vllm/v1/engine/core_client.py | 52 ++++++++++--------- vllm/v1/engine/utils.py | 61 ++++++++++++----------- vllm/v1/worker/gpu_worker.py | 53 ++++++++++---------- 9 files changed, 129 insertions(+), 122 deletions(-) diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py index 2448fc3c3ba9..bfa5d370e7ba 100644 --- a/tests/v1/engine/test_client_guard.py +++ b/tests/v1/engine/test_client_guard.py @@ -183,7 +183,7 @@ def test_shutdown_guard(): @pytest.mark.asyncio async def test_handle_fault_async(): engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() - engine_status_dict = create_test_thread_safe_dict({1: "Unhealthy"}) + engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) 
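The command channel this test drives is an ordinary ZMQ ROUTER/DEALER pair: the
ClientGuard binds a ROUTER socket, each engine-side guard connects a DEALER
socket under a fixed identity, commands travel as the JSON envelope produced by
serialize_method_call ({"method", "method_uuid", **params}), and results come
back as a small JSON dict echoing the same method_uuid. A minimal
self-contained sketch of that round-trip, assuming nothing beyond pyzmq (the
address and identity below are illustrative, not vLLM's):

    import json
    import time
    import uuid
    import zmq

    ctx = zmq.Context.instance()
    router = ctx.socket(zmq.ROUTER)            # plays the ClientGuard side
    router.bind("tcp://127.0.0.1:5599")
    dealer = ctx.socket(zmq.DEALER)            # plays the engine-guard side
    dealer.setsockopt(zmq.IDENTITY, b"engine_identity")
    dealer.connect("tcp://127.0.0.1:5599")
    time.sleep(0.1)  # let the ROUTER learn the DEALER identity (slow joiner)

    # ROUTER -> DEALER frames: [identity, empty delimiter, payload].
    cmd = {"method": "pause", "method_uuid": str(uuid.uuid4()), "soft_pause": False}
    router.send_multipart([b"engine_identity", b"", json.dumps(cmd).encode()])

    # The DEALER sees [empty, payload] and answers with a result envelope.
    _, payload = dealer.recv_multipart()
    request = json.loads(payload)
    result = {"engine_index": 0, "success": True, "method_uuid": request["method_uuid"]}
    dealer.send_multipart([b"", json.dumps(result).encode()])

    # The ROUTER sees [identity, empty, payload]; method_uuid correlates
    # the reply with the command, exactly as the test asserts.
    identity, _, payload = router.recv_multipart()
    assert json.loads(payload)["method_uuid"] == cmd["method_uuid"]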
time.sleep(0.1) @@ -208,7 +208,7 @@ def response_cmd(cmd_socket): nonlocal uuid while uuid is None: time.sleep(0.1) - execute_result = {"engine_index": 1, "success": True, "method_uuid": uuid} + execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid} cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")]) threading.Thread(target=receive_cmd, args=(cmd_socket,)).start() @@ -217,6 +217,6 @@ def response_cmd(cmd_socket): result = await guard.handle_fault("retry", 3) assert result is True - assert engine_status_dict[1] == "Healthy" + assert engine_status_dict[0] == "Healthy" guard.shutdown_guard() diff --git a/tests/v1/engine/test_engine_core_guard.py b/tests/v1/engine/test_engine_core_guard.py index f0c9c2098248..29f3e8e0f049 100644 --- a/tests/v1/engine/test_engine_core_guard.py +++ b/tests/v1/engine/test_engine_core_guard.py @@ -38,6 +38,7 @@ def create_engine_core_guard( guard_identity=GUARD_IDENTITY, tp_size=1, pp_size=1, + dp_size=1, ) @@ -101,6 +102,8 @@ def mock_worker_receiver(cmd_socket): param = {"timeout": 3} if instruction == "pause": param["soft_pause"] = True + elif instruction == "retry": + param["new_stateless_dp_group_port"] = 23456 serial_instruction = serialize_method_call(instruction, **param) client_socket.send_multipart( [GUARD_IDENTITY, b"", serial_instruction.encode("utf-8")] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 466390335327..634b92101afb 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -339,8 +339,8 @@ def get_next_dp_init_port(self) -> int: def stateless_init_dp_group( self, - gloo_comm_timeout: int = 30, - enable_fault_tolerance: bool = False, + gloo_comm_timeout: int | None = None, + dp_init_port: int | None = None, ) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race @@ -357,23 +357,25 @@ def stateless_init_dp_group( max_retries = 5 last_exc: Exception | None = None + if dp_init_port is None: + dp_init_port = self.get_next_dp_init_port() for _ in range(max_retries): try: # use gloo since the engine process might not have cuda device return stateless_init_torch_distributed_process_group( self.data_parallel_master_ip, - self.get_next_dp_init_port(), + dp_init_port, self.data_parallel_rank, self.data_parallel_size, backend=current_platform.dist_backend, gloo_comm_timeout=gloo_comm_timeout, - enable_fault_tolerance=enable_fault_tolerance, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. if "EADDRINUSE" in str(e): logger.warning("Address already in use. 
Retrying with a new port.") last_exc = e + dp_init_port = self.get_next_dp_init_port() continue # try again with a new port raise e diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index f453015ac1a2..abffa84e0c04 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -463,8 +463,7 @@ def stateless_init_torch_distributed_process_group( rank: int, world_size: int, backend: str, - gloo_comm_timeout: int, - enable_fault_tolerance: bool = False, + gloo_comm_timeout: int | None, ) -> ProcessGroup: """ A replacement for `torch.distributed.init_process_group` that does not @@ -499,10 +498,11 @@ def stateless_init_torch_distributed_process_group( """ init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string - if enable_fault_tolerance: - timeout = timedelta(seconds=gloo_comm_timeout) - else: + + if gloo_comm_timeout is None: timeout = _get_default_timeout(backend) + else: + timeout = timedelta(seconds=gloo_comm_timeout) store, rank, world_size = next( rendezvous(init_method, rank, world_size, timeout=timeout) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index ee8af447484e..010d4d56ac11 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -68,8 +68,8 @@ def get_grammar_bitmask( def preempt_request( self, scheduled_timestamp: float | None = None, - preempted_req: Request | None = None, - ) -> Request: + preempted_req: Optional["Request"] = None, + ) -> "Request": """ Preempt a running request and move it back to the waiting queue. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 124d220c0a74..2743caef4d1c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -45,11 +45,9 @@ from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ( - EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, - FinishReason, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, @@ -179,13 +177,16 @@ def run(self) -> None: self.engine_running = False except queue.Empty: pass - - if self.client_cmd_socket.closed: - self.logger("Client socket closed", level="info") + try: + has_msg, _, cmd_str = recv_router_dealer_message( + self.client_cmd_socket, + use_poller=True, + poll_timeout=poll_timeout_ms, + ) + except zmq.ZMQError: + self.logger("Socket closed, terminating EngineCoreGuard", level="info") break - has_msg, _, cmd_str = recv_router_dealer_message( - self.client_cmd_socket, use_poller=True, poll_timeout=poll_timeout_ms - ) + if has_msg: self.logger("Received cmd: %s", cmd_str, level="info") self._execute_cmd(cmd_str) @@ -204,7 +205,7 @@ def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> boo identities = set() for tp_rank in range(self.tp_size): for pp_rank in range(self.pp_size): - identity = f"{tp_rank}_{pp_rank}".encode() + identity = f"{pp_rank}_{tp_rank}".encode() identities.add(identity) method_uuid = broadcast_instruction( @@ -286,10 +287,10 @@ def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool: success = True if not soft_pause: # abort the communicators - self._stop_worker_execution(soft_pause=False, timeout=timeout) + success = self._stop_worker_execution(soft_pause=False, timeout=timeout) return success - def retry(self, timeout: int = 1): + def retry(self, new_stateless_dp_group_port: int, timeout: int = 1): """ Handle the retry instruction from the ClientGuard. 
This instruction tells the EngineCore to continue its busy loop @@ -297,7 +298,7 @@ def retry(self, timeout: int = 1): """ start_time = time.monotonic() - success = self._execute_worker_method("restart_worker", timeout=timeout) + success = self._execute_worker_method("restore_worker", timeout=timeout) if not success: return success @@ -305,7 +306,11 @@ def retry(self, timeout: int = 1): # If the Gloo communication times out # the data parallel group (dp_group) needs to be reinitialized command = "reinit_dp_group_on_fault_tolerance" - self.cmd_q.put(serialize_method_call(command)) + self.cmd_q.put( + serialize_method_call( + command, new_stateless_dp_group_port=new_stateless_dp_group_port + ) + ) else: self.cmd_q.put(None) @@ -1473,21 +1478,6 @@ def process_output_sockets( # Limit the number of buffers to reuse. reuse_buffers.append(buffer) - def engine_finish_requests(self): - assert isinstance(self.scheduler, V1Scheduler) - engine_finish_outputs = EngineCoreOutputs() - engine_finish_outputs.engine_index = self.engine_index - for request_id in list(self.scheduler.requests.keys()): - self.scheduler.finish_requests(request_id, RequestStatus.FINISHED_ABORTED) - engine_finish_outputs.outputs.append( - EngineCoreOutput( - request_id=request_id, - finish_reason=FinishReason.ABORT, - new_token_ids=[], - ) - ) - self.output_queue.put((0, engine_finish_outputs)) - def shutdown(self): super().shutdown() if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: @@ -1549,7 +1539,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig): self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group( vllm_config.fault_tolerance_config.gloo_comm_timeout, - vllm_config.fault_tolerance_config.enable_fault_tolerance, ) def shutdown(self): @@ -1658,12 +1647,13 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) - def reinit_dp_group_on_fault_tolerance(self): + def reinit_dp_group_on_fault_tolerance(self, new_stateless_dp_group_port: int): stateless_destroy_torch_distributed_process_group(self.dp_group) self.dp_group = self.vllm_config.parallel_config.stateless_init_dp_group( self.vllm_config.fault_tolerance_config.gloo_comm_timeout, - self.vllm_config.fault_tolerance_config.enable_fault_tolerance, + new_stateless_dp_group_port, ) + self.step_counter = 0 def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 626ff97354e8..6104e0a0631f 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -431,31 +431,35 @@ def fault_receiver(self): engine_core component. It is designed to run continuously to ensure no critical error information from the engine core is missed. 
""" - while True: - _, sender_identity, message = recv_router_dealer_message( - self.fault_receiver_socket - ) - if self.client_guard_dead: - self.logger("client guard dead, stop receiving fault") - break - assert message is not None, ( - "message should not be None at fault tolerance scenario" - ) + while not self.client_guard_dead: + try: + _, sender_identity, message = recv_router_dealer_message( + self.fault_receiver_socket + ) + assert message is not None, ( + "message should not be None at fault tolerance scenario" + ) - fault_info = FaultInfo.from_json(message) - self.engine_exception_q.put_nowait(fault_info) - engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy" - self.engine_status_dict[int(fault_info.engine_id)] = engine_status - self.fault_pub_socket.send_string( - f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}" - ) - # TODO Asynchronous issuance of pause commands and design of engine - # core status - # Pause healthy engines on fault. - # Pause will be invoked again during fault-tolerance handling, - # so it's unnecessary to track whether all engines are currently - # paused. - self.fault_handler.submit_fault("pause", 5, soft_pause=False) + fault_info = FaultInfo.from_json(message) + self.engine_exception_q.put_nowait(fault_info) + engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy" + self.engine_status_dict[int(fault_info.engine_id)] = engine_status + self.fault_pub_socket.send_string( + f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}" + ) + + # Pause healthy engines on fault. + # Pause will be invoked again during fault-tolerance handling, + # so it's unnecessary to track whether all engines are currently + # paused. + self.fault_handler.submit_fault("pause", 5, soft_pause=False) + except zmq.ZMQError: + # Socket was closed during polling, exit loop. 
+ self.logger( + "Fault receiver socket closed, stopping thread.", level="info" + ) + break + self.logger("Fault receiver thread has stopped.") def shutdown_guard(self): self.client_guard_dead = True diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index aa9e97edd58d..22064048ed71 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -27,6 +27,7 @@ from vllm.ray.ray_env import get_env_vars_to_copy from vllm.utils.collection_utils import ThreadSafeDict from vllm.utils.network_utils import ( + get_open_port, get_open_zmq_ipc_path, make_zmq_socket, recv_router_dealer_message, @@ -36,7 +37,7 @@ from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.exceptions import FaultInfo from vllm.v1.executor import Executor -from vllm.v1.serial_utils import serialize_method_call +from vllm.v1.serial_utils import run_method, serialize_method_call from vllm.v1.utils import get_engine_client_zmq_addr, shutdown if TYPE_CHECKING: @@ -1399,40 +1400,44 @@ async def _dispatcher(self): if fut: fut.set_exception(e) - async def _handle_fault_internal( - self, instruction: str, timeout: int, **kwargs - ) -> bool: - if instruction == "retry" and "Dead" in self.engine_status_dict.values(): + def retry(self, **kwargs): + if "Dead" in self.engine_status_dict.values(): self.logger( "engine_core dead unexpectedly, retry is impossible," "shutdown will be performed", level="info", ) - return False + return False, set(), kwargs + + target_engines = set(self.engine_identity_to_index.keys()) + kwargs["new_stateless_dp_group_port"] = get_open_port() + return True, target_engines, kwargs + + def pause(self, **kwargs): + self.logger( + "Pause operation is best-effort only. Due to the complexity of " + "collective communications (e.g., timing dependencies and " + "synchronization barriers), pausing may not always succeed. If " + "the process remains unresponsive or collective operations " + "cannot be interrupted, consider shutting down and restarting " + "the instance.", + level="warning", + ) - if instruction == "pause": - logger.warning( - "Pause operation is best-effort only. Due to the complexity of " - "collective communications (e.g., timing dependencies and " - "synchronization barriers), pausing may not always succeed. If " - "the process remains unresponsive or collective operations " - "cannot be interrupted, consider shutting down and restarting " - "the instance." 
- ) + alive_engines = { + identity + for identity, index in self.engine_identity_to_index.items() + if self.engine_status_dict.get(index) != "Dead" + } + return True, alive_engines, kwargs - dead_engine_indices = { - index - for index, status in self.engine_status_dict.items() - if status == "Dead" - } - - target_engines = { - identity - for identity, index in self.engine_identity_to_index.items() - if index not in dead_engine_indices - } - else: - target_engines = set(self.engine_identity_to_index.keys()) + async def _handle_fault_internal( + self, instruction: str, timeout: int, **kwargs + ) -> bool: + success, target_engines, kwargs = run_method(self, instruction, (), kwargs) + + if not success: + return False if timeout is not None: kwargs["timeout"] = timeout diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ad50474ecae3..892b9740c1b3 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -91,7 +91,7 @@ def __init__( self.init_distributed_env_callback = init_distributed_env_callback self.clear_input_batch_callback = clear_input_batch_callback self.device = device - identity = f"{self.tp_rank}_{self.pp_rank}".encode() + identity = f"{self.pp_rank}_{self.tp_rank}".encode() worker_cmd_addr = vllm_config.fault_tolerance_config.engine_core_cmd_addr self.cmd_socket = make_zmq_socket( ctx=self.zmq_ctx, @@ -109,7 +109,7 @@ def __init__( ).start() def _make_worker_logger(self): - prefix = f"[WorkerGuard_dp{self.dp_rank}_tp{self.tp_rank}_pp{self.pp_rank}] " + prefix = f"[WorkerGuard_dp{self.dp_rank}_pp{self.pp_rank}_tp{self.tp_rank}] " def log(msg, *args, level="info", **kwargs): """ @@ -123,29 +123,32 @@ def log(msg, *args, level="info", **kwargs): def run(self): """Run the message receiving loop and handle control commands""" torch.cuda.set_device(self.device) - while True: - # Use blocking receive - will wait until a message arrives - has_msg, _, cmd_str = recv_router_dealer_message(self.cmd_socket) - if self.worker_guard_dead: - self.logger("Worker guard dead, exiting") + while not self.worker_guard_dead: + try: + # Use blocking receive - will wait until a message arrives + has_msg, _, cmd_str = recv_router_dealer_message(self.cmd_socket) + if has_msg: + assert cmd_str is not None + method, method_uuid, params = deserialize_method_call(cmd_str) + self.logger("Executing command: %s", method) + try: + success = run_method(self, method, args=(), kwargs=params) + except Exception as e: + self.logger( + "Error executing method %s: %s %s\n Call Stack:\n %s", + method, + type(e).__name__, + e, + "".join(traceback.format_tb(e.__traceback__)), + level="error", + ) + success = False + self._send_execution_result(success, method_uuid) + except zmq.ZMQError: + # Socket was closed, exit loop. 
+ self.logger("Command socket closed, stopping thread.", level="info") break - if has_msg: - assert cmd_str is not None - method, method_uuid, params = deserialize_method_call(cmd_str) - self.logger("Executing command: %s", method) - try: - success = run_method(self, method, args=(), kwargs=params) - except Exception as e: - self.logger( - "Error executing method %s: %s %s\n Call Stack:\n %s", - method, - type(e).__name__, - e, - "".join(traceback.format_tb(e.__traceback__)), - level="error", - ) - success = False - self._send_execution_result(success, method_uuid) + self.logger("Worker guard thread has stopped.") def pause_by_signal(self): self._set_device_communicator_status(False) @@ -226,7 +229,7 @@ def _set_device_communicator_status(self, active: bool): nccl_comm.available = active nccl_comm.disabled = not active - def restart_worker(self): + def restore_worker(self): if self.communicator_aborted: torch.cuda.set_device(self.device) with set_current_vllm_config(self.vllm_config): From 20a5a5ae21a71bfb97f0102a07c01763a3ec6a2b Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Wed, 12 Nov 2025 12:06:45 +0800 Subject: [PATCH 3/6] Improve documentation and logging in API server Signed-off-by: fangyuchu --- tests/v1/engine/test_client_guard.py | 4 ++-- vllm/engine/arg_utils.py | 6 ++---- vllm/entrypoints/api_server.py | 31 ++++++++++++++++++++------- vllm/entrypoints/openai/api_server.py | 29 +++++++++++++------------ 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py index bfa5d370e7ba..401af6bf7756 100644 --- a/tests/v1/engine/test_client_guard.py +++ b/tests/v1/engine/test_client_guard.py @@ -23,8 +23,7 @@ def create_test_thread_safe_dict(initial_data=None): if initial_data is None: initial_data = {1: "Healthy"} - if initial_data is None: - initial_data = {1: "Healthy"} + tsd = ThreadSafeDict() if initial_data: for k, v in initial_data.items(): @@ -219,4 +218,5 @@ def response_cmd(cmd_socket): assert result is True assert engine_status_dict[0] == "Healthy" + cmd_socket.close() guard.shutdown_guard() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 54e7d048de93..ca072fbca6e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1775,10 +1775,8 @@ def create_engine_config( fault_tolerance_config = FaultToleranceConfig( enable_fault_tolerance=self.enable_fault_tolerance, engine_recovery_timeout=self.engine_recovery_timeout, - internal_fault_report_port=self.internal_fault_report_port - or FaultToleranceConfig.internal_fault_report_port, - external_fault_notify_port=self.external_fault_notify_port - or FaultToleranceConfig.external_fault_notify_port, + internal_fault_report_port=self.internal_fault_report_port, + external_fault_notify_port=self.external_fault_notify_port, gloo_comm_timeout=self.gloo_comm_timeout, ) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index b2034bbada70..ce81ceffb604 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -15,7 +15,7 @@ from collections.abc import AsyncGenerator from typing import Any -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse import vllm.envs as envs @@ -57,23 +57,38 @@ async def generate(request: Request) -> Response: @app.post("/fault_tolerance/apply") -async def send_fault_tolerance_instruction(request: Request) -> Response: - """Generate 
completion for the request. +async def process_fault_tolerance_instruction(request: Request) -> Response: + """Apply fault tolerance instructions to the engine. + + This endpoint handles fault recovery operations such as retrying operations. The request should be a JSON object with the following fields: - - prompt: the prompt to use for the generation. - - stream: whether to stream the results or not. - - other fields: the sampling parameters (See `SamplingParams` for details). + - fault_tolerance_instruction: The name of fault tolerance method. + - fault_tolerance_timeout: Timeout in seconds for the operation to complete. + - fault_tolerance_params: dict, optional. Additional dynamic parameters for + the fault tolerance operation. """ request_dict = await request.json() fault_tolerance_instruction = request_dict.get("fault_tolerance_instruction") fault_tolerance_timeout = request_dict.get("fault_tolerance_timeout") - kwargs = request_dict.get("kwargs", {}) + kwargs = request_dict.get("fault_tolerance_params", {}) assert engine is not None - return await engine.handle_fault( + success = await engine.handle_fault( fault_tolerance_instruction, fault_tolerance_timeout, **kwargs ) + if success: + return JSONResponse( + status_code=200, + content={"message": "Instruction executed successfully."}, + ) + + logger.error("Fault tolerance operation failed. Shutting down the engine.") + engine.shutdown() + raise HTTPException( + status_code=400, + detail="Instruction execution failed.", + ) @app.get("/fault_tolerance/status") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a53bc8fe8cf3..c296109a05ef 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1235,7 +1235,7 @@ async def is_scaling_elastic_ep(raw_request: Request): HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, }, ) -async def send_fault_tolerance_instruction(raw_request: Request): +async def process_fault_tolerance_instruction(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: @@ -1250,39 +1250,40 @@ async def send_fault_tolerance_instruction(raw_request: Request): if fault_tolerance_instruction is None or fault_tolerance_timeout is None: raise HTTPException( status_code=400, - detail="fault_tolerance_instruction and" - " fault_tolerance_timeout is required", + detail="Both 'fault_tolerance_instruction' and " + "'fault_tolerance_timeout' are required.", ) if not isinstance(fault_tolerance_instruction, str): raise HTTPException( - status_code=400, detail="fault_tolerance_instruction must be a str" + status_code=400, detail="'fault_tolerance_instruction' must be a string." ) - # Currently, only two types of instructions are supported: [pause, retry]. - # Additional descaling instructions will be supported in future updates. + # Supported instructions: ["pause", "retry"]. + # More instruction types may be added in future updates. elif fault_tolerance_instruction not in ["pause", "retry"]: raise HTTPException( - status_code=400, detail="not a valid fault_tolerance_instruction" + status_code=400, detail="Invalid 'fault_tolerance_instruction' value." 
) if not isinstance(fault_tolerance_timeout, int) or fault_tolerance_timeout <= 0: raise HTTPException( - status_code=400, detail="fault_tolerance_timeout must be a positive integer" + status_code=400, + detail="'fault_tolerance_timeout' must be a positive integer.", ) try: - execute_result = await client.handle_fault( + success = await client.handle_fault( fault_tolerance_instruction, fault_tolerance_timeout, **dynamic_fault_tolerance_params, ) - if execute_result: + if success: return JSONResponse( { - "message": "instruction has been executed successfully", + "message": "Instruction executed successfully.", } ) else: - logger.error("Fault tolerance failed, shutdown the app.") + logger.error("Fault tolerance failed. Shutting down the application.") client.shutdown() raise HTTPException( status_code=400, @@ -1290,10 +1291,10 @@ async def send_fault_tolerance_instruction(raw_request: Request): ) except Exception as e: - logger.error("Handle fault failed: %s", e) + logger.error("Failed to handle fault: %s", e) raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail="Handle fault failed", + detail="Failed to handle fault.", ) from e From b5be237c05361411e5b849905c7e2dc6fd0e7da1 Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Thu, 13 Nov 2025 14:19:56 +0800 Subject: [PATCH 4/6] Fix hanging issue in DT; fix hang when aborting communicators from Python side; use queue.Queue for engine_exception_q Signed-off-by: fangyuchu --- tests/v1/engine/test_client_guard.py | 29 ++++++------ tests/v1/engine/test_engine_core_guard.py | 4 +- vllm/v1/engine/core.py | 23 ++++++--- vllm/v1/engine/core_client.py | 16 ++----- vllm/v1/engine/utils.py | 31 ++++-------- vllm/v1/worker/gpu_worker.py | 57 ++++++++++------------- 6 files changed, 75 insertions(+), 85 deletions(-) diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py index 401af6bf7756..20d363876303 100644 --- a/tests/v1/engine/test_client_guard.py +++ b/tests/v1/engine/test_client_guard.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import json +import queue import threading import time from unittest.mock import AsyncMock @@ -32,21 +32,20 @@ def create_test_thread_safe_dict(initial_data=None): def create_client_guard( - engine_exception_q: asyncio.Queue, engine_status_dict: ThreadSafeDict[int, str] + engine_exception_q: queue.Queue, engine_status_dict: ThreadSafeDict[int, str] ): return ClientGuard( fault_receiver_addr=FAULT_RECEIVER_ADDR, cmd_addr=CMD_ADDR, engine_registry=[b"engine_identity"], engine_exception_q=engine_exception_q, - engine_exception_q_lock=asyncio.Lock(), fault_pub_addr=FAULT_PUB_ADDR, engine_status_dict=engine_status_dict, ) def test_client_guard_initialization(): - engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -64,7 +63,7 @@ def test_client_guard_initialization(): @pytest.mark.asyncio async def test_handle_fault(): - engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -82,7 +81,7 @@ async def test_handle_fault(): def test_fault_receiver(): - engine_exception_q: 
asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -97,7 +96,7 @@ def send_test_message(): socket.close() ctx.term() - sender_thread = threading.Thread(target=send_test_message) + sender_thread = threading.Thread(target=send_test_message, daemon=True) sender_thread.start() def check_published_message(): @@ -114,7 +113,7 @@ def check_published_message(): assert prefix == FAULT_PUB_TOPIC assert json.loads(data) == {"1": "Dead"} - check_thread = threading.Thread(target=check_published_message) + check_thread = threading.Thread(target=check_published_message, daemon=True) check_thread.start() time.sleep(0.1) @@ -130,7 +129,7 @@ def check_published_message(): def test_fault_receiver_unhealthy(): - engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -145,7 +144,7 @@ def send_unhealthy_message(): socket.close() ctx.term() - threading.Thread(target=send_unhealthy_message).start() + threading.Thread(target=send_unhealthy_message, daemon=True).start() time.sleep(0.1) assert engine_status_dict[1] == "Unhealthy" @@ -154,7 +153,7 @@ def send_unhealthy_message(): def test_shutdown_guard(): - engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -181,7 +180,7 @@ def test_shutdown_guard(): @pytest.mark.asyncio async def test_handle_fault_async(): - engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"}) guard = create_client_guard(engine_exception_q, engine_status_dict) @@ -190,6 +189,7 @@ async def test_handle_fault_async(): cmd_socket = ctx.socket(zmq.DEALER) cmd_socket.setsockopt(zmq.IDENTITY, b"engine_identity") cmd_socket.connect(CMD_ADDR) + time.sleep(0.1) uuid = None @@ -210,8 +210,8 @@ def response_cmd(cmd_socket): execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid} cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")]) - threading.Thread(target=receive_cmd, args=(cmd_socket,)).start() - threading.Thread(target=response_cmd, args=(cmd_socket,)).start() + threading.Thread(target=receive_cmd, args=(cmd_socket,), daemon=True).start() + threading.Thread(target=response_cmd, args=(cmd_socket,), daemon=True).start() result = await guard.handle_fault("retry", 3) @@ -219,4 +219,5 @@ def response_cmd(cmd_socket): assert engine_status_dict[0] == "Healthy" cmd_socket.close() + ctx.term() guard.shutdown_guard() diff --git a/tests/v1/engine/test_engine_core_guard.py b/tests/v1/engine/test_engine_core_guard.py index 29f3e8e0f049..20c4d20fc617 100644 --- a/tests/v1/engine/test_engine_core_guard.py +++ b/tests/v1/engine/test_engine_core_guard.py @@ -113,7 +113,9 @@ def mock_worker_receiver(cmd_socket): elif instruction == "retry": busy_loop_active.set() - threading.Thread(target=mock_worker_receiver, args=(worker_cmd_socket,)).start() + threading.Thread( + target=mock_worker_receiver, args=(worker_cmd_socket,), daemon=True + ).start() 
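The mock above plays the worker side of EngineCoreGuard._execute_worker_method,
which fans a single serialized command out to one DEALER identity per worker
(f"{pp_rank}_{tp_rank}") and then gathers per-identity acknowledgements until a
deadline. A sketch of that fan-out/ack pattern, with hypothetical broadcast and
collect_acks helpers standing in for broadcast_instruction in
vllm/v1/engine/utils.py, whose exact signature is not shown in this series:

    import json
    import time
    import uuid
    import zmq

    def broadcast(router: zmq.Socket, identities: set[bytes],
                  method: str, **params) -> str:
        """Send one JSON command envelope to every worker identity."""
        method_uuid = str(uuid.uuid4())
        payload = json.dumps({"method": method, "method_uuid": method_uuid, **params})
        for identity in identities:
            router.send_multipart([identity, b"", payload.encode()])
        return method_uuid

    def collect_acks(router: zmq.Socket, identities: set[bytes],
                     method_uuid: str, timeout: float) -> bool:
        """True only if every identity acks this method_uuid in time."""
        pending, deadline = set(identities), time.monotonic() + timeout
        while pending:
            remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
            if not router.poll(remaining_ms):
                return False  # at least one worker never answered
            identity, _, payload = router.recv_multipart()
            reply = json.loads(payload)
            if reply.get("method_uuid") == method_uuid and reply.get("success"):
                pending.discard(identity)
        return True

    # Identities mirror the pp-major scheme this series settles on:
    # {f"{pp}_{tp}".encode() for pp in range(pp_size) for tp in range(tp_size)}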
time.sleep(0.1) identity, _, msg = client_socket.recv_multipart() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2743caef4d1c..dd03fdf4413b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -198,7 +198,9 @@ def _stop_worker_execution(self, soft_pause: bool, timeout: int = 2) -> bool: pause_method = "pause_by_abort_communicators" self.communicator_aborted = True - success = self._execute_worker_method(pause_method, timeout=timeout) + success = self._execute_worker_method( + pause_method, timeout=timeout, worker_timeout=timeout + ) return success def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> bool: @@ -237,18 +239,19 @@ def _execute_cmd(self, cmd_str): try: success = run_method(self, method, args=(), kwargs=method_params) self.logger("Command (%s) succeeded: %s", method, success, level="info") - + reason = None except Exception as e: self.logger( - "Error executing method %s: %s %s", + "Error executing method %s: %s, %s", method, type(e).__name__, e, level="error", ) success = False + reason = f"{type(e).__name__}: {e}" - self._send_execution_result(success, method_uuid) + self._send_execution_result(success, method_uuid, reason) def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool: """ @@ -296,6 +299,9 @@ def retry(self, new_stateless_dp_group_port: int, timeout: int = 1): This instruction tells the EngineCore to continue its busy loop after being suspended due to an exception. """ + if self.engine_running: + return True + start_time = time.monotonic() success = self._execute_worker_method("restore_worker", timeout=timeout) @@ -319,14 +325,19 @@ def retry(self, new_stateless_dp_group_port: int, timeout: int = 1): remaining_timeout = max(0, timeout - elapsed) success = self.busy_loop_active.wait(timeout=remaining_timeout) self.engine_running = success + assert self.cmd_q.empty(), "cmd_q must be empty after execution" return success - def _send_execution_result(self, success: bool, method_uuid: str): + def _send_execution_result( + self, success: bool, method_uuid: str, reason: str | None + ): msg = { "engine_index": self.engine_index, "success": success, "method_uuid": method_uuid, } + if not success and reason is not None: + msg["reason"] = reason msg_bytes = json.dumps(msg).encode("utf-8") self.client_cmd_socket.send_multipart([b"", msg_bytes]) @@ -936,7 +947,7 @@ def __init__( # Track whether the busy loop is currently active. 
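        # retry() on the guard waits on this event with whatever remains of
        # its timeout after restoring the workers, so a successful wait is
        # the signal that the busy loop actually resumed.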
self.busy_loop_active = threading.Event() self.fault_signal_q: queue.Queue[Exception] = queue.Queue() - self.cmd_q: queue.Queue[str | None] = queue.Queue() + self.cmd_q: queue.Queue[str | None] = queue.Queue(maxsize=1) self.engine_recovery_timeout = ft_config.engine_recovery_timeout engine_core_guard_ids = addresses.engine_core_guard_identities assert engine_core_guard_ids is not None diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 6104e0a0631f..f72a825f568b 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -351,8 +351,7 @@ def __init__( fault_receiver_addr: str, cmd_addr: str, engine_registry: list[bytes], - engine_exception_q: asyncio.Queue[FaultInfo], - engine_exception_q_lock: asyncio.Lock, + engine_exception_q: queue.Queue[FaultInfo], fault_pub_addr: str, engine_status_dict: ThreadSafeDict[int, str], ): @@ -372,9 +371,7 @@ def __init__( ctx=self.zmq_ctx, path=fault_pub_addr, socket_type=zmq.PUB, bind=True ) - self.engine_exception_q: asyncio.Queue[FaultInfo] = engine_exception_q - - self.engine_exception_q_lock = engine_exception_q_lock + self.engine_exception_q: queue.Queue[FaultInfo] = engine_exception_q self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict @@ -382,7 +379,6 @@ def __init__( self.cmd_socket, self.engine_registry, self.engine_exception_q, - self.engine_exception_q_lock, self.engine_status_dict, ) @@ -449,7 +445,7 @@ def fault_receiver(self): ) # Pause healthy engines on fault. - # Pause will be invoked again during fault-tolerance handling, + # Pause can be invoked again during fault-tolerance handling, # so it's unnecessary to track whether all engines are currently # paused. self.fault_handler.submit_fault("pause", 5, soft_pause=False) @@ -684,7 +680,7 @@ def __init__( self.start_engine_core_monitor() if vllm_config.fault_tolerance_config.enable_fault_tolerance: - self.engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue() + self.engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() assert addresses.fault_report_addr is not None, ( "addresses.fault_report_addr should not be None at fault tolerance" " scenario" @@ -694,7 +690,6 @@ def __init__( " scenario" ) self.engine_registry = addresses.engine_core_guard_identities - self.engine_exception_q_lock = asyncio.Lock() assert self.engine_registry is not None assert addresses.fault_pub_socket_addr is not None, ( "addresses.fault_pub_socket_addr should not be None at" @@ -708,7 +703,6 @@ def __init__( addresses.client_cmd_addr, self.engine_registry, self.engine_exception_q, - self.engine_exception_q_lock, addresses.fault_pub_socket_addr, self.engine_status_dict, ) @@ -758,7 +752,7 @@ def monitor_actors(): if not all_actors: return while True: - for actor in all_actors: + for actor in all_actors[:]: actor_id = actor._actor_id.hex() if actor in engine_manager.local_engine_actors: actor_index = engine_manager.local_engine_actors.index(actor) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 22064048ed71..6ac4107c9e7b 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -5,6 +5,7 @@ import json import multiprocessing import os +import queue import time import uuid import weakref @@ -1233,20 +1234,6 @@ def generate_identity_group(peer1, peer2, use, n): return identitys -async def get_queue_snapshot(queue: asyncio.Queue, queue_lock: asyncio.Lock) -> list: - """Thread-safe snapshot of the exception queue.""" - async with queue_lock: - items = [] - # get item at first - while not queue.empty(): - item 
= queue.get_nowait() - items.append(item) - # put item into queue again - for item in items: - queue.put_nowait(item) - return items - - def broadcast_instruction( cmd_socket, target_identities: set[bytes] | list[bytes], @@ -1355,13 +1342,11 @@ def __init__( self, cmd_socket: zmq.Socket, client_cmd_registry: list[bytes], - engine_exception_q: asyncio.Queue[FaultInfo], - engine_exception_q_lock: asyncio.Lock, + engine_exception_q: queue.Queue[FaultInfo], engine_status_dict: ThreadSafeDict[int, str], ) -> None: self.cmd_socket = cmd_socket self.engine_exception_q = engine_exception_q - self.engine_exception_q_lock = engine_exception_q_lock self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict self.engine_identity_to_index: dict[bytes, int] = { identity: i for i, identity in enumerate(client_cmd_registry) @@ -1403,9 +1388,8 @@ async def _dispatcher(self): def retry(self, **kwargs): if "Dead" in self.engine_status_dict.values(): self.logger( - "engine_core dead unexpectedly, retry is impossible," - "shutdown will be performed", - level="info", + "Engine core is dead; retry won't work.", + level="warning", ) return False, set(), kwargs @@ -1481,7 +1465,12 @@ async def _handle_fault_internal( if instruction == "retry" and all_success: for engine_index, _ in self.engine_status_dict.items(): self.engine_status_dict[engine_index] = "Healthy" - # todo: should we also clear the engine_exception_q here? + while not self.engine_exception_q.empty(): + try: + self.engine_exception_q.get_nowait() + except queue.Empty: + break + return all_success async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 892b9740c1b3..a9f684e2e977 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -9,7 +9,7 @@ import time import traceback from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed +from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait from contextlib import AbstractContextManager, nullcontext from datetime import timedelta from functools import partial @@ -130,7 +130,8 @@ def run(self): if has_msg: assert cmd_str is not None method, method_uuid, params = deserialize_method_call(cmd_str) - self.logger("Executing command: %s", method) + self.logger("Executing command: %s, %s", method, params) + try: success = run_method(self, method, args=(), kwargs=params) except Exception as e: @@ -156,7 +157,7 @@ def pause_by_signal(self): self.logger("Pause signal sent.") return True - def pause_by_abort_communicators(self, timeout=5): + def pause_by_abort_communicators(self, worker_timeout=5): """ Abort all NCCL communicators and process groups in parallel using a thread pool. 
""" @@ -166,7 +167,6 @@ def pause_by_abort_communicators(self, timeout=5): torch.cuda.set_device(self.device) model_groups = get_all_model_groups() futures = [] - start_time = time.time() def _abort_nccl_comm(group: GroupCoordinator): if group.device_communicator is not None: @@ -178,48 +178,41 @@ def _abort_process_group(group: GroupCoordinator): backend = group.device_group._get_backend(device) backend.abort() - with ThreadPoolExecutor(max_workers=len(model_groups) * 2) as executor: + executor = ThreadPoolExecutor(max_workers=len(model_groups) * 2) + try: for group in model_groups: futures.append(executor.submit(_abort_nccl_comm, group)) futures.append(executor.submit(_abort_process_group, group)) - done, not_done = [], [] - for future in as_completed(futures): - elapsed = time.time() - start_time - remaining = max(timeout - elapsed, 0) - if remaining == 0: - self.logger( - "Timeout while waiting for abort operations", level="warning" - ) - break - try: - # Wait at most 'remaining' seconds for this future - future.result(timeout=remaining) - done.append(future) - except TimeoutError: - not_done.append(future) - except Exception as e: - self.logger("Abort call raised exception: %s", e, level="warning") - not_done.append(future) - - # Add any futures that were not processed yet - not_done.extend([f for f in futures if f not in done and f not in not_done]) + done, not_done = wait( + futures, timeout=worker_timeout, return_when=FIRST_EXCEPTION + ) if not_done: self.logger( "%d abort calls did not finish in total %s seconds", len(not_done), - timeout, + worker_timeout, level="warning", ) + finally: + executor.shutdown(wait=False, cancel_futures=True) - self.communicator_aborted = True - success = len(not_done) == 0 - if success: + exception_count = sum(1 for f in done if f.exception() is not None) + self.communicator_aborted = len(not_done) == 0 and exception_count == 0 + if self.communicator_aborted: cleanup_dist_env_and_memory() self.logger("Communicators are aborted.") else: - self.logger("Communicators did not abort in time.", level="warning") - return success + self.logger( + "Communicator abort failed: %d NCCL comm abort calls timed out," + " %d tasks threw exceptions. This may leave NCCL communicators " + "or process groups in an inconsistent state. 
Subsequent " + "distributed operations could be unsafe.", + len(not_done), + exception_count, + level="error", + ) + return self.communicator_aborted def _set_device_communicator_status(self, active: bool): model_groups = get_all_model_groups() From 9485d2b938a3a70493f148f9965aade02a2e3310 Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Thu, 13 Nov 2025 23:11:23 +0800 Subject: [PATCH 5/6] Refactor fault tolerance modules by renaming classes to Sentinel and converting engine_registry to a dict Signed-off-by: fangyuchu --- tests/v1/engine/test_client_guard.py | 68 +++++++++++------------ tests/v1/engine/test_engine_core_guard.py | 42 +++++++------- vllm/config/fault_tolerance.py | 4 +- vllm/v1/engine/core.py | 58 ++++++++++--------- vllm/v1/engine/core_client.py | 34 ++++++------ vllm/v1/engine/utils.py | 32 ++++++----- vllm/v1/worker/gpu_worker.py | 22 ++++---- 7 files changed, 135 insertions(+), 125 deletions(-) diff --git a/tests/v1/engine/test_client_guard.py b/tests/v1/engine/test_client_guard.py index 20d363876303..a64ccc7ed420 100644 --- a/tests/v1/engine/test_client_guard.py +++ b/tests/v1/engine/test_client_guard.py @@ -11,7 +11,7 @@ import zmq from vllm.utils.collection_utils import ThreadSafeDict -from vllm.v1.engine.core_client import ClientGuard +from vllm.v1.engine.core_client import ClientSentinel from vllm.v1.engine.utils import FaultHandler, FaultInfo FAULT_RECEIVER_ADDR = "tcp://127.0.0.1:8844" @@ -31,59 +31,59 @@ def create_test_thread_safe_dict(initial_data=None): return tsd -def create_client_guard( +def create_client_sentinel( engine_exception_q: queue.Queue, engine_status_dict: ThreadSafeDict[int, str] ): - return ClientGuard( + return ClientSentinel( fault_receiver_addr=FAULT_RECEIVER_ADDR, cmd_addr=CMD_ADDR, - engine_registry=[b"engine_identity"], + engine_registry={0: b"engine_identity"}, engine_exception_q=engine_exception_q, fault_pub_addr=FAULT_PUB_ADDR, engine_status_dict=engine_status_dict, ) -def test_client_guard_initialization(): +def test_client_sentinel_initialization(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) - assert guard.engine_registry == [b"engine_identity"] - assert not guard.client_guard_dead - assert isinstance(guard.fault_handler, FaultHandler) - assert guard.engine_exception_q is engine_exception_q + assert sentinel.engine_registry[0] == b"engine_identity" + assert not sentinel.client_sentinel_dead + assert isinstance(sentinel.fault_handler, FaultHandler) + assert sentinel.engine_exception_q is engine_exception_q - assert guard.fault_receiver_socket.type == zmq.ROUTER - assert guard.cmd_socket.type == zmq.ROUTER - assert guard.fault_pub_socket.type == zmq.PUB + assert sentinel.fault_receiver_socket.type == zmq.ROUTER + assert sentinel.cmd_socket.type == zmq.ROUTER + assert sentinel.fault_pub_socket.type == zmq.PUB - guard.shutdown_guard() + sentinel.shutdown_sentinel() @pytest.mark.asyncio async def test_handle_fault(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) engine_exception_q.put_nowait( FaultInfo(engine_id="1", message="test exception", type="test") ) - guard.fault_handler.handle_fault = 
AsyncMock(return_value=True) + sentinel.fault_handler.handle_fault = AsyncMock(return_value=True) - result = await guard.handle_fault("pause", 5) + result = await sentinel.handle_fault("pause", 5) assert result is True - guard.fault_handler.handle_fault.assert_awaited_once_with("pause", 5) + sentinel.fault_handler.handle_fault.assert_awaited_once_with("pause", 5) - guard.shutdown_guard() + sentinel.shutdown_sentinel() def test_fault_receiver(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) def send_test_message(): ctx = zmq.Context() @@ -125,13 +125,13 @@ def check_published_message(): assert engine_status_dict[1] == "Dead" - guard.shutdown_guard() + sentinel.shutdown_sentinel() def test_fault_receiver_unhealthy(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) def send_unhealthy_message(): ctx = zmq.Context() @@ -149,22 +149,22 @@ def send_unhealthy_message(): assert engine_status_dict[1] == "Unhealthy" - guard.shutdown_guard() + sentinel.shutdown_sentinel() -def test_shutdown_guard(): +def test_shutdown_sentinel(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({1: "Healthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) - original_fault_sock = guard.fault_receiver_socket - original_cmd_sock = guard.cmd_socket - original_pub_sock = guard.fault_pub_socket - original_ctx = guard.zmq_ctx + original_fault_sock = sentinel.fault_receiver_socket + original_cmd_sock = sentinel.cmd_socket + original_pub_sock = sentinel.fault_pub_socket + original_ctx = sentinel.zmq_ctx - guard.shutdown_guard() + sentinel.shutdown_sentinel() - assert guard.client_guard_dead is True + assert sentinel.client_sentinel_dead is True with pytest.raises(zmq.ZMQError): original_fault_sock.recv() @@ -182,7 +182,7 @@ def test_shutdown_guard(): async def test_handle_fault_async(): engine_exception_q: queue.Queue[FaultInfo] = queue.Queue() engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"}) - guard = create_client_guard(engine_exception_q, engine_status_dict) + sentinel = create_client_sentinel(engine_exception_q, engine_status_dict) time.sleep(0.1) ctx = zmq.Context().instance() @@ -213,11 +213,11 @@ def response_cmd(cmd_socket): threading.Thread(target=receive_cmd, args=(cmd_socket,), daemon=True).start() threading.Thread(target=response_cmd, args=(cmd_socket,), daemon=True).start() - result = await guard.handle_fault("retry", 3) + result = await sentinel.handle_fault("retry", 3) assert result is True assert engine_status_dict[0] == "Healthy" cmd_socket.close() ctx.term() - guard.shutdown_guard() + sentinel.shutdown_sentinel() diff --git a/tests/v1/engine/test_engine_core_guard.py b/tests/v1/engine/test_engine_core_guard.py index 20c4d20fc617..f55a903db770 100644 --- a/tests/v1/engine/test_engine_core_guard.py +++ b/tests/v1/engine/test_engine_core_guard.py @@ -12,7 +12,7 @@ from vllm.utils.network_utils import make_zmq_socket from vllm.v1.engine.core import ( - EngineCoreGuard, + EngineCoreSentinel, 
EngineLoopPausedError, ) from vllm.v1.serial_utils import serialize_method_call @@ -20,13 +20,13 @@ CLIENT_CMD_ADDR = "tcp://127.0.0.1:8844" WORKER_CMD_ADDR = "tcp://127.0.0.1:8845" FAULT_REPORT_ADDR = "tcp://127.0.0.1:8846" -GUARD_IDENTITY = b"engine_guard_0" +SENTINEL_IDENTITY = b"engine_sentinel_0" -def create_engine_core_guard( +def create_engine_core_sentinel( fault_signal_q: queue.Queue, busy_loop_active: threading.Event ): - return EngineCoreGuard( + return EngineCoreSentinel( engine_index=0, fault_signal_q=fault_signal_q, cmd_q=queue.Queue(), @@ -35,31 +35,31 @@ def create_engine_core_guard( client_cmd_addr=CLIENT_CMD_ADDR, worker_cmd_addr=WORKER_CMD_ADDR, fault_report_addr=FAULT_REPORT_ADDR, - guard_identity=GUARD_IDENTITY, + sentinel_identity=SENTINEL_IDENTITY, tp_size=1, pp_size=1, dp_size=1, ) -def test_engine_core_guard_initialization(): +def test_engine_core_sentinel_initialization(): fault_signal_q: queue.Queue = queue.Queue() busy_loop_active = threading.Event() - guard = create_engine_core_guard(fault_signal_q, busy_loop_active) + sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active) - assert guard.engine_index == 0 - assert guard.tp_size == 1 - assert guard.pp_size == 1 - assert not guard.communicator_aborted - assert guard.engine_running is True - assert guard.daemon is True + assert sentinel.engine_index == 0 + assert sentinel.tp_size == 1 + assert sentinel.pp_size == 1 + assert not sentinel.communicator_aborted + assert sentinel.engine_running is True + assert sentinel.daemon is True - assert guard.fault_report_socket.type == zmq.DEALER - assert guard.client_cmd_socket.type == zmq.DEALER - assert guard.worker_cmd_socket.type == zmq.ROUTER + assert sentinel.fault_report_socket.type == zmq.DEALER + assert sentinel.client_cmd_socket.type == zmq.DEALER + assert sentinel.worker_cmd_socket.type == zmq.ROUTER - guard.shutdown() + sentinel.shutdown() @pytest.mark.parametrize("instruction", ["pause", "retry"]) @@ -73,7 +73,7 @@ def test_run_handle_instruction(instruction): time.sleep(0.1) - guard = create_engine_core_guard(fault_signal_q, busy_loop_active) + sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active) time.sleep(0.1) ctx = zmq.Context() @@ -96,7 +96,7 @@ def mock_worker_receiver(cmd_socket): logging.info(identity) cmd_socket.send_multipart([b"", json.dumps(response_dict).encode("utf-8")]) - threading.Thread(target=guard.run, daemon=True).start() + threading.Thread(target=sentinel.run, daemon=True).start() time.sleep(0.1) param = {"timeout": 3} @@ -106,7 +106,7 @@ def mock_worker_receiver(cmd_socket): param["new_stateless_dp_group_port"] = 23456 serial_instruction = serialize_method_call(instruction, **param) client_socket.send_multipart( - [GUARD_IDENTITY, b"", serial_instruction.encode("utf-8")] + [SENTINEL_IDENTITY, b"", serial_instruction.encode("utf-8")] ) if instruction == "pause": fault_signal_q.put(EngineLoopPausedError(Exception("test error"))) @@ -127,4 +127,4 @@ def mock_worker_receiver(cmd_socket): client_socket.close() worker_cmd_socket.close() - guard.shutdown() + sentinel.shutdown() diff --git a/vllm/config/fault_tolerance.py b/vllm/config/fault_tolerance.py index b12be7af81b0..24fbd6f1f259 100644 --- a/vllm/config/fault_tolerance.py +++ b/vllm/config/fault_tolerance.py @@ -39,10 +39,10 @@ class FaultToleranceConfig: engine_core_cmd_addr: str = "" """ - The ZMQ address between engine_core_guard and worker_guard. + The ZMQ address between engine_core_sentinel and worker_sentinel. 
It will be initialized and assigned in EngineCore, then passed to the Worker via vllm_config—this is required for the Worker - to spin up the WorkerGuard. + to spin up the WorkerSentinel. """ gloo_comm_timeout: int = 30 diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index dd03fdf4413b..8f3c6cef1a14 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -84,12 +84,12 @@ _R = TypeVar("_R") # Return type for collective_rpc -class EngineCoreGuard(threading.Thread): # changed +class EngineCoreSentinel(threading.Thread): """ - EngineCoreGuard monitors a single EngineCore instance, responsible for: + EngineCoreSentinel monitors a single EngineCore instance, responsible for: 1. Receiving fault signals (exceptions raised in EngineCore busy loop) - 2. Receiving and executing commands from ClientGuard - 3. Reporting execution results or faults back to the ClientGuard + 2. Receiving and executing commands from ClientSentinel + 3. Reporting execution results or faults back to the ClientSentinel """ def __init__( @@ -102,7 +102,7 @@ def __init__( client_cmd_addr: str, worker_cmd_addr: str, fault_report_addr: str, - guard_identity: bytes, + sentinel_identity: bytes, tp_size: int, pp_size: int, dp_size: int, @@ -118,30 +118,34 @@ def __init__( self.dp_size = dp_size self.ctx = zmq.Context() - # Client <-> EngineCoreGuard sockets + # Client <-> EngineCoreSentinel sockets self.fault_report_socket = make_zmq_socket( self.ctx, fault_report_addr, zmq.DEALER, bind=False, - identity=guard_identity, + identity=sentinel_identity, ) self.client_cmd_socket = make_zmq_socket( - self.ctx, client_cmd_addr, zmq.DEALER, bind=False, identity=guard_identity + self.ctx, + client_cmd_addr, + zmq.DEALER, + bind=False, + identity=sentinel_identity, ) - # EngineCoreGuard <-> WorkerGuard sockets + # EngineCoreSentinel <-> WorkerSentinel sockets self.worker_cmd_socket = make_zmq_socket( self.ctx, worker_cmd_addr, zmq.ROUTER, bind=True ) self.poller = zmq.Poller() self.communicator_aborted = False self.engine_running = True - self.engine_core_guard_dead = False - self.logger = self._make_engine_core_guard_logger() + self.engine_core_sentinel_dead = False + self.logger = self._make_engine_core_sentinel_logger() - def _make_engine_core_guard_logger(self): - prefix = f"[EngineCoreGuard_{self.engine_index}] " + def _make_engine_core_sentinel_logger(self): + prefix = f"[EngineCoreSentinel_{self.engine_index}] " def log(msg, *args, level="info", **kwargs): """ @@ -154,10 +158,10 @@ def log(msg, *args, level="info", **kwargs): def run(self) -> None: """ - Run the main monitoring loop for EngineCoreGuard. + Run the main monitoring loop for EngineCoreSentinel. 
""" poll_timeout_ms = 100 - while not self.engine_core_guard_dead: + while not self.engine_core_sentinel_dead: # Check for engine fault signals try: engine_exception = self.fault_signal_q.get_nowait() @@ -167,7 +171,7 @@ def run(self) -> None: self.logger("Engine paused", level="info") else: self.logger( - "[EngineCoreGuard] Detected exception %s: %s\n Call Stack:\n%s", + "Detected exception %s: %s\n Call Stack:\n%s", type(engine_exception).__name__, engine_exception, "".join(traceback.format_tb(engine_exception.__traceback__)), @@ -184,7 +188,9 @@ def run(self) -> None: poll_timeout=poll_timeout_ms, ) except zmq.ZMQError: - self.logger("Socket closed, terminating EngineCoreGuard", level="info") + self.logger( + "Socket closed, terminating EngineCoreSentinel", level="info" + ) break if has_msg: @@ -232,7 +238,7 @@ def _report_client_exception(self, exception: Exception) -> None: def _execute_cmd(self, cmd_str): """ - Execute a command received from ClientGuard. + Execute a command received from ClientSentinel. """ method, method_uuid, method_params = deserialize_method_call(cmd_str) self.logger("Executing command: %s", method, level="info") @@ -295,7 +301,7 @@ def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool: def retry(self, new_stateless_dp_group_port: int, timeout: int = 1): """ - Handle the retry instruction from the ClientGuard. + Handle the retry instruction from the ClientSentinel. This instruction tells the EngineCore to continue its busy loop after being suspended due to an exception. """ @@ -350,7 +356,7 @@ def shutdown(self): self.worker_cmd_socket.close() if self.ctx is not None: self.ctx.term() - self.engine_core_guard_dead = True + self.engine_core_sentinel_dead = True def busy_loop_wrapper(busy_loop_func): @@ -949,15 +955,15 @@ def __init__( self.fault_signal_q: queue.Queue[Exception] = queue.Queue() self.cmd_q: queue.Queue[str | None] = queue.Queue(maxsize=1) self.engine_recovery_timeout = ft_config.engine_recovery_timeout - engine_core_guard_ids = addresses.engine_core_guard_identities - assert engine_core_guard_ids is not None + engine_core_sentinel_ids = addresses.engine_core_sentinel_identities + assert engine_core_sentinel_ids is not None assert addresses.fault_report_addr is not None assert addresses.client_cmd_addr is not None assert addresses.engine_core_cmd_addrs is not None engine_core_cmd_addr = addresses.engine_core_cmd_addrs[ vllm_config.parallel_config.data_parallel_rank ] - self.engine_core_guard = EngineCoreGuard( + self.engine_core_sentinel = EngineCoreSentinel( engine_index=self.engine_index, fault_signal_q=self.fault_signal_q, cmd_q=self.cmd_q, @@ -966,12 +972,12 @@ def __init__( fault_report_addr=addresses.fault_report_addr, client_cmd_addr=addresses.client_cmd_addr, worker_cmd_addr=engine_core_cmd_addr, - guard_identity=engine_core_guard_ids[self.engine_index], + sentinel_identity=engine_core_sentinel_ids[self.engine_index], tp_size=vllm_config.parallel_config.tensor_parallel_size, pp_size=vllm_config.parallel_config.pipeline_parallel_size, dp_size=vllm_config.parallel_config.data_parallel_size, ) - self.engine_core_guard.start() + self.engine_core_sentinel.start() vllm_config.fault_tolerance_config.engine_core_cmd_addr = ( engine_core_cmd_addr ) @@ -1492,7 +1498,7 @@ def process_output_sockets( def shutdown(self): super().shutdown() if self.vllm_config.fault_tolerance_config.enable_fault_tolerance: - self.engine_core_guard.shutdown() + self.engine_core_sentinel.shutdown() class DPEngineCoreProc(EngineCoreProc): diff --git 
a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f72a825f568b..efad7fc120bc 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -345,12 +345,12 @@ def dp_engines_running(self) -> bool: return False -class ClientGuard: +class ClientSentinel: def __init__( self, fault_receiver_addr: str, cmd_addr: str, - engine_registry: list[bytes], + engine_registry: dict[int, bytes], engine_exception_q: queue.Queue[FaultInfo], fault_pub_addr: str, engine_status_dict: ThreadSafeDict[int, str], @@ -382,15 +382,15 @@ def __init__( self.engine_status_dict, ) - self.logger = self._make_client_guard_logger() + self.logger = self._make_client_sentinel_logger() - self.client_guard_dead = False + self.client_sentinel_dead = False Thread( target=self.fault_receiver, daemon=True, name="EngineCoreFaultReceiver" ).start() - def _make_client_guard_logger(self): - prefix = "[client_guard] " + def _make_client_sentinel_logger(self): + prefix = "[client_sentinel] " def log(msg, *args, level="info", **kwargs): """ @@ -427,7 +427,7 @@ def fault_receiver(self): engine_core component. It is designed to run continuously to ensure no critical error information from the engine core is missed. """ - while not self.client_guard_dead: + while not self.client_sentinel_dead: try: _, sender_identity, message = recv_router_dealer_message( self.fault_receiver_socket @@ -457,13 +457,13 @@ def fault_receiver(self): break self.logger("Fault receiver thread has stopped.") - def shutdown_guard(self): - self.client_guard_dead = True + def shutdown_sentinel(self): + self.client_sentinel_dead = True self.fault_receiver_socket.close() self.cmd_socket.close() self.fault_pub_socket.close() self.zmq_ctx.term() - self.logger("ClientGuard is closed.", level="info") + self.logger("ClientSentinel is closed.", level="info") @dataclass @@ -484,7 +484,7 @@ class BackgroundResources: output_queue_task: asyncio.Task | None = None stats_update_task: asyncio.Task | None = None shutdown_path: str | None = None - client_guard: ClientGuard | None = None + client_sentinel: ClientSentinel | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. @@ -498,8 +498,8 @@ def __call__(self): self.engine_manager.close() if self.coordinator is not None: self.coordinator.close() - if self.client_guard is not None: - self.client_guard.shutdown_guard() + if self.client_sentinel is not None: + self.client_sentinel.shutdown_sentinel() if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. 
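The sockets touched by these renames all follow one ZMQ pattern: a sentinel connects a DEALER socket created with a fixed identity (the sentinel_identity and engine_core_sentinel_identities values above), while its peer binds a ROUTER socket that uses that identity to route replies back. Below is a minimal pyzmq sketch of that wiring, including the empty delimiter frame visible in calls such as send_multipart([b"", ...]); the inproc address and payloads are invented for illustration, and the real code goes through make_zmq_socket and recv_router_dealer_message rather than raw sockets.

# Minimal ROUTER/DEALER sketch of the sentinel fault-report channel.
# Assumptions: plain pyzmq, an inproc transport, and made-up payloads.
import zmq

ADDR = "inproc://fault-report"
IDENTITY = b"engine_sentinel_0"  # fixed identity, as in the tests above

ctx = zmq.Context()

router = ctx.socket(zmq.ROUTER)  # ClientSentinel side: binds, sees identities
router.bind(ADDR)

dealer = ctx.socket(zmq.DEALER)  # EngineCoreSentinel side: connects with identity
dealer.setsockopt(zmq.IDENTITY, IDENTITY)
dealer.connect(ADDR)

# DEALER sends [empty delimiter, payload]; ROUTER receives
# [identity, empty delimiter, payload], so the receiver knows the sender.
dealer.send_multipart([b"", b"fault|engine 0 hit an exception"])
identity, _, payload = router.recv_multipart()
assert identity == IDENTITY

# Replies are routed back by prefixing the stored identity.
router.send_multipart([identity, b"", b"pause"])
_, cmd = dealer.recv_multipart()
assert cmd == b"pause"

dealer.close()
router.close()
ctx.term()

This is also why engine_core_sentinel_identities becomes a dict[int, bytes] rather than a list[bytes]: the ROUTER side only ever sees identity frames, so keeping an explicit rank-to-identity map lets FaultHandler invert it (engine_identity_to_index) and attribute each fault report to an engine index.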
@@ -689,7 +689,7 @@ def __init__( "addresses.client_cmd_addr should not be None at fault tolerance" " scenario" ) - self.engine_registry = addresses.engine_core_guard_identities + self.engine_registry = addresses.engine_core_sentinel_identities assert self.engine_registry is not None assert addresses.fault_pub_socket_addr is not None, ( "addresses.fault_pub_socket_addr should not be None at" @@ -698,7 +698,7 @@ def __init__( self.engine_status_dict: ThreadSafeDict[int, str] = ThreadSafeDict() for engine_id in range(vllm_config.parallel_config.data_parallel_size): self.engine_status_dict[engine_id] = "Healthy" - self.client_guard = ClientGuard( + self.client_sentinel = ClientSentinel( addresses.fault_report_addr, addresses.client_cmd_addr, self.engine_registry, @@ -706,7 +706,7 @@ def __init__( addresses.fault_pub_socket_addr, self.engine_status_dict, ) - self.resources.client_guard = self.client_guard + self.resources.client_sentinel = self.client_sentinel success = True finally: if not success: @@ -859,7 +859,7 @@ def monitor_engine_cores(): async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: """handle fault of current instance by instruction""" - return await self.client_guard.handle_fault(instruction, timeout, **kwargs) + return await self.client_sentinel.handle_fault(instruction, timeout, **kwargs) async def fault_reporter(self): return self.engine_status_dict.to_dict() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 6ac4107c9e7b..907cdb7c9bbf 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -83,11 +83,11 @@ class EngineZmqAddresses: frontend_stats_publish_address: str | None = None # fault_report_addr: str | None = None - # ZMQ client_cmd socket address of client guard + # ZMQ client_cmd socket address of client sentinel client_cmd_addr: str | None = None - # identities of engine_core_guard - engine_core_guard_identities: list[bytes] | None = None - # ZMQ fault_pub_socket address of client guard + # identities of engine_core_sentinel + engine_core_sentinel_identities: dict[int, bytes] | None = None + # ZMQ fault_pub_socket address of client sentinel fault_pub_socket_addr: str | None = None @@ -133,7 +133,7 @@ def __init__( if vllm_config.fault_tolerance_config.enable_fault_tolerance: zmq_ctx = zmq.Context() identity = generate_identity_group( - "core_engine_proc_manager", "client_guard", "report", 1 + "core_engine_proc_manager", "client_sentinel", "report", 1 )[0] zmq_addr = get_engine_client_zmq_addr( local_only=False, @@ -198,7 +198,7 @@ def __init__( self.close() def _report_engine_dead(self, dead_message): - """Send engine dead message to ClientGuard""" + """Send engine dead message to ClientSentinel""" try: self.engine_down_socket.send_multipart( [ @@ -206,7 +206,7 @@ def __init__( dead_message.encode("utf-8"), ] ) - logger.info("Sent message to ClientGuard: %s", dead_message) + logger.info("Sent message to ClientSentinel: %s", dead_message) except Exception as e: logger.error("Failed to send message: %s", e) @@ -359,7 +359,7 @@ def __init__( port=vllm_config.fault_tolerance_config.internal_fault_report_port, ) identity = generate_identity_group( - "core_engine_actor_manager", "clinet_guard", "report", 1 + "core_engine_actor_manager", "client_sentinel", "report", 1 )[0] self.engine_down_socket = make_zmq_socket( ctx=zmq_ctx, @@ -923,12 +923,15 @@ def launch_core_engines( addresses.client_cmd_addr = get_engine_client_zmq_addr( local_only=client_local_only, host=host ) -
addresses.engine_core_guard_identities = generate_identity_group( + identity_group = generate_identity_group( peer1="client", - peer2="engine_core_guard", + peer2="engine_core_sentinel", use="report and cmd", n=dp_size, ) + addresses.engine_core_sentinel_identities = { + rank: identity for rank, identity in enumerate(identity_group) + } addresses.fault_pub_socket_addr = get_engine_client_zmq_addr( local_only=False, host="0.0.0.0", @@ -1341,7 +1344,7 @@ class FaultHandler: def __init__( self, cmd_socket: zmq.Socket, - client_cmd_registry: list[bytes], + client_cmd_registry: dict[int, bytes], engine_exception_q: queue.Queue[FaultInfo], engine_status_dict: ThreadSafeDict[int, str], ) -> None: @@ -1349,7 +1352,7 @@ def __init__( self.engine_exception_q = engine_exception_q self.engine_status_dict: ThreadSafeDict[int, str] = engine_status_dict self.engine_identity_to_index: dict[bytes, int] = { - identity: i for i, identity in enumerate(client_cmd_registry) + identity: i for i, identity in client_cmd_registry.items() } # ensure handle_fault is executed sequentially self._task_queue: asyncio.Queue = asyncio.Queue() @@ -1445,7 +1448,7 @@ async def _handle_fault_internal( if response is None: self.logger( - "EngineCoreGuard[%s] did not respond" + "EngineCoreSentinel[%s] did not respond" ' to command "%s" within timeout.', engine_index, instruction, @@ -1454,7 +1457,8 @@ async def _handle_fault_internal( all_success = False elif not response.get("success", False): self.logger( - 'EngineCoreGuard[%s] failed to execute command "%s" (reason: %s)', + "EngineCoreSentinel[%s] failed to execute " + 'command "%s" (reason: %s)', engine_index, instruction, response.get("reason", "unknown"), diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a9f684e2e977..d6fd3e7fd48a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -74,7 +74,7 @@ from vllm.v1.core.sched.output import SchedulerOutput -class WorkerGuard: +class WorkerSentinel: def __init__( self, vllm_config: VllmConfig, @@ -100,16 +100,16 @@ def __init__( bind=False, identity=identity, ) - self.worker_guard_dead = False + self.worker_sentinel_dead = False self.pause_event = pause_event self.communicator_aborted = False self.logger = self._make_worker_logger() threading.Thread( - target=self.run, daemon=True, name="WorkerGuardCmdReceiver" + target=self.run, daemon=True, name="WorkerSentinelCmdReceiver" ).start() def _make_worker_logger(self): - prefix = f"[WorkerGuard_dp{self.dp_rank}_pp{self.pp_rank}_tp{self.tp_rank}] " + prefix = f"[WorkerSentinel_dp{self.dp_rank}_pp{self.pp_rank}_tp{self.tp_rank}] " def log(msg, *args, level="info", **kwargs): """ @@ -123,7 +123,7 @@ def log(msg, *args, level="info", **kwargs): def run(self): """Run the message receiving loop and handle control commands""" torch.cuda.set_device(self.device) - while not self.worker_guard_dead: + while not self.worker_sentinel_dead: try: # Use blocking receive - will wait until a message arrives has_msg, _, cmd_str = recv_router_dealer_message(self.cmd_socket) @@ -149,7 +149,7 @@ def run(self): # Socket was closed, exit loop. 
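# A safe, self-contained reproduction of the shutdown path above: the
# receiver thread blocks in recv(), shutdown flips the dead flag and
# terminates the ZMQ context, and the blocked recv() raises
# zmq.ContextTerminated (a subclass of zmq.ZMQError) so the loop exits.
# Names and the PAIR/inproc transport are invented for this sketch; the
# worker uses its ROUTER/DEALER cmd socket and recv_router_dealer_message.
import threading
import time
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PAIR)
sock.bind("inproc://worker-cmd")
sentinel_dead = False  # stands in for worker_sentinel_dead

def run() -> None:
    while not sentinel_dead:
        try:
            sock.recv()  # blocking receive, like the command loop above
        except zmq.ContextTerminated:  # ETERM raised by ctx.term()
            sock.close()  # close the socket so ctx.term() can return
            break

t = threading.Thread(target=run, daemon=True, name="WorkerSentinelCmdReceiver")
t.start()
time.sleep(0.1)       # let the thread block in recv()
sentinel_dead = True  # flip the flag first, as shutdown() does
ctx.term()            # unblocks the recv() in the receiver thread
t.join(timeout=2)
assert not t.is_alive()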
self.logger("Command socket closed, stopping thread.", level="info") break - self.logger("Worker guard thread has stopped.") + self.logger("Worker sentinel thread has stopped.") def pause_by_signal(self): self._set_device_communicator_status(False) @@ -242,7 +242,7 @@ def _send_execution_result(self, success: bool, method_uuid: str): self.cmd_socket.send_multipart([b"", msg_bytes]) def shutdown(self): - self.worker_guard_dead = True + self.worker_sentinel_dead = True self.cmd_socket.close() self.zmq_ctx.term() @@ -264,7 +264,7 @@ def __init__( is_driver_worker=is_driver_worker, ) - self.worker_guard: WorkerGuard | None = None + self.worker_sentinel: WorkerSentinel | None = None if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils.import_utils import init_cached_hf_modules @@ -474,7 +474,7 @@ def clear_input_batch_callback(): for req_id in list(cached_req_ids): input_batch.remove_request(req_id) - self.worker_guard = WorkerGuard( + self.worker_sentinel = WorkerSentinel( self.vllm_config, self.model_runner.pause_event, init_distributed_env_callback, @@ -1080,8 +1080,8 @@ def save_tensorized_model( def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() - if self.worker_guard is not None: - self.worker_guard.shutdown() + if self.worker_sentinel is not None: + self.worker_sentinel.shutdown() def init_worker_distributed_environment( From fae0e756c7adf30b5ce037b6626920af5eadf006 Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Tue, 18 Nov 2025 15:59:25 +0800 Subject: [PATCH 6/6] reject requests when engine is in fault status Signed-off-by: fangyuchu --- vllm/engine/protocol.py | 2 +- vllm/entrypoints/api_server.py | 4 +-- vllm/entrypoints/openai/api_server.py | 49 ++++++++++++++++++--------- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/core_client.py | 9 +++-- 5 files changed, 44 insertions(+), 22 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 33dbcc496a45..06f756e24cb3 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -171,7 +171,7 @@ async def handle_fault( """send fault tolerance instruction to the engine""" raise NotImplementedError - async def exception_reporter(self): + async def get_fault_info(self): """report exception from engine_core""" raise NotImplementedError diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index ce81ceffb604..3a020b58b37b 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -95,8 +95,8 @@ async def process_fault_tolerance_instruction(request: Request) -> Response: async def get_fault_info() -> Response: """Health check.""" assert engine is not None - engine_exception_dict = await engine.exception_reporter() - return Response(json.dumps(engine_exception_dict), status_code=200) + engine_status_dict = await engine.get_fault_info() + return Response(json.dumps(engine_status_dict), status_code=200) @with_cancellation diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c296109a05ef..3c8882ccec5f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -263,6 +263,20 @@ async def validate_json_request(raw_request: Request): ) +async def check_engine_fault(raw_request: Request): + client = engine_client(raw_request) + assert hasattr(client, "engine_core") + core_client = client.engine_core + if ( + hasattr(core_client, "client_sentinel") + and 
core_client.client_sentinel.is_faulted.is_set() + ): + raise HTTPException( + status_code=503, + detail="Service is in faulted state, cannot process requests.", + ) + + router = APIRouter() @@ -395,7 +409,7 @@ async def get_server_load_metrics(request: Request): @router.post( "/tokenize", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -430,7 +444,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post( "/detokenize", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -505,7 +519,7 @@ async def _convert_stream_to_sse_events( @router.post( "/v1/responses", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ -598,7 +612,7 @@ async def cancel_responses(response_id: str, raw_request: Request): @router.post( "/v1/messages", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, @@ -654,7 +668,7 @@ def translate_error_response(response: ErrorResponse) -> JSONResponse: @router.post( "/v1/chat/completions", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ -695,7 +709,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re @router.post( "/v1/completions", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, @@ -741,7 +755,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post( "/v1/embeddings", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -784,7 +798,7 @@ async def create_embedding( @router.post( "/pooling", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -820,7 +834,10 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/classify", dependencies=[Depends(validate_json_request)]) +@router.post( + "/classify", + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], +) @with_cancellation @load_aware_call async def create_classify(request: ClassificationRequest, raw_request: 
Request): @@ -849,7 +866,7 @@ async def create_classify(request: ClassificationRequest, raw_request: Request): @router.post( "/score", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -882,7 +899,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): @router.post( "/v1/score", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -979,7 +996,7 @@ async def create_translations( @router.post( "/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1011,7 +1028,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): @router.post( "/v1/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1030,7 +1047,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): @router.post( "/v2/rerank", - dependencies=[Depends(validate_json_request)], + dependencies=[Depends(validate_json_request), Depends(check_engine_fault)], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, @@ -1303,8 +1320,8 @@ async def get_fault_info( raw_request: Request, ): client = engine_client(raw_request) - engine_exception_dict = await client.exception_reporter() - return JSONResponse(content=engine_exception_dict) + engine_status_dict = await client.get_fault_info() + return JSONResponse(content=engine_status_dict) # NOTE: Construct the TypeAdapters only once diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 9a51fe711ff3..4b4abc8f23fd 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -785,7 +785,7 @@ async def handle_fault( """send fault tolerance instruction to the engine""" return await self.engine_core.handle_fault(instruction, timeout, **kwargs) - async def exception_reporter(self): + async def get_fault_info(self): """report exception in engine core""" return await self.engine_core.fault_reporter() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index efad7fc120bc..91a597b3a8a9 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -6,6 +6,7 @@ import multiprocessing import queue import sys +import threading import time import uuid import weakref @@ -355,6 +356,7 @@ def __init__( fault_pub_addr: str, engine_status_dict: ThreadSafeDict[int, str], ): + self.is_faulted = threading.Event() self.engine_registry = engine_registry self.zmq_ctx = zmq.Context() self.fault_receiver_socket = make_zmq_socket( @@ -411,12 +413,15 @@ async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool: to handle system anomalies, ensuring stable operation or graceful degradation of the relevant components. 
""" - return await run_method( + result = await run_method( self.fault_handler, "handle_fault", args=(instruction, timeout), kwargs=kwargs, ) + if result: + self.is_faulted.clear() + return result def fault_receiver(self): """ @@ -443,7 +448,7 @@ def fault_receiver(self): self.fault_pub_socket.send_string( f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}" ) - + self.is_faulted.set() # Pause healthy engines on fault. # Pause can be invoked again during fault-tolerance handling, # so it's unnecessary to track whether all engines are currently