Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions tests/v1/engine/test_client_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import queue
import threading
import time
from unittest.mock import AsyncMock

import pytest
import zmq

from vllm.utils.collection_utils import ThreadSafeDict
from vllm.v1.engine.core_client import ClientSentinel
from vllm.v1.engine.utils import FaultHandler, FaultInfo

FAULT_RECEIVER_ADDR = "tcp://127.0.0.1:8844"
CMD_ADDR = "tcp://127.0.0.1:8845"
FAULT_PUB_ADDR = "tcp://127.0.0.1:8846"
FAULT_PUB_TOPIC = "vllm_fault"


def create_test_thread_safe_dict(initial_data=None):
if initial_data is None:
initial_data = {1: "Healthy"}

tsd = ThreadSafeDict()
if initial_data:
for k, v in initial_data.items():
tsd[k] = v
return tsd


def create_client_sentinel(
engine_exception_q: queue.Queue, engine_status_dict: ThreadSafeDict[int, str]
):
return ClientSentinel(
fault_receiver_addr=FAULT_RECEIVER_ADDR,
cmd_addr=CMD_ADDR,
engine_registry={0: b"engine_identity"},
engine_exception_q=engine_exception_q,
fault_pub_addr=FAULT_PUB_ADDR,
engine_status_dict=engine_status_dict,
)


def test_client_sentinel_initialization():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({1: "Healthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

assert sentinel.engine_registry[0] == b"engine_identity"
assert not sentinel.client_sentinel_dead
assert isinstance(sentinel.fault_handler, FaultHandler)
assert sentinel.engine_exception_q is engine_exception_q

assert sentinel.fault_receiver_socket.type == zmq.ROUTER
assert sentinel.cmd_socket.type == zmq.ROUTER
assert sentinel.fault_pub_socket.type == zmq.PUB

sentinel.shutdown_sentinel()


@pytest.mark.asyncio
async def test_handle_fault():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({1: "Healthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

engine_exception_q.put_nowait(
FaultInfo(engine_id="1", message="test exception", type="test")
)

sentinel.fault_handler.handle_fault = AsyncMock(return_value=True)

result = await sentinel.handle_fault("pause", 5)
assert result is True
sentinel.fault_handler.handle_fault.assert_awaited_once_with("pause", 5)

sentinel.shutdown_sentinel()


def test_fault_receiver():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({1: "Healthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

def send_test_message():
ctx = zmq.Context()
socket = ctx.socket(zmq.DEALER)
socket.setsockopt(zmq.IDENTITY, b"test_sender")
socket.connect(FAULT_RECEIVER_ADDR)

test_fault = FaultInfo(engine_id="1", type="dead", message="test error")
socket.send_multipart([b"", test_fault.serialize().encode("utf-8")])
socket.close()
ctx.term()

sender_thread = threading.Thread(target=send_test_message, daemon=True)
sender_thread.start()

def check_published_message():
ctx = zmq.Context()
sub_socket = ctx.socket(zmq.SUB)
sub_socket.connect(FAULT_PUB_ADDR)
sub_socket.setsockopt_string(zmq.SUBSCRIBE, FAULT_PUB_TOPIC)

message = sub_socket.recv_string()
sub_socket.close()
ctx.term()

prefix, data = message.split("|", 1)
assert prefix == FAULT_PUB_TOPIC
assert json.loads(data) == {"1": "Dead"}

check_thread = threading.Thread(target=check_published_message, daemon=True)
check_thread.start()

time.sleep(0.1)

assert not engine_exception_q.empty()
received_fault = engine_exception_q.get_nowait()
assert received_fault.engine_id == "1"
assert received_fault.type == "dead"

assert engine_status_dict[1] == "Dead"

sentinel.shutdown_sentinel()


def test_fault_receiver_unhealthy():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({1: "Healthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

def send_unhealthy_message():
ctx = zmq.Context()
socket = ctx.socket(zmq.DEALER)
socket.setsockopt(zmq.IDENTITY, b"engine_identity")
socket.connect(FAULT_RECEIVER_ADDR)

test_fault = FaultInfo(engine_id="1", type="error", message="test error")
socket.send_multipart([b"", test_fault.serialize().encode()])
socket.close()
ctx.term()

threading.Thread(target=send_unhealthy_message, daemon=True).start()
time.sleep(0.1)

assert engine_status_dict[1] == "Unhealthy"

sentinel.shutdown_sentinel()


def test_shutdown_sentinel():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({1: "Healthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

original_fault_sock = sentinel.fault_receiver_socket
original_cmd_sock = sentinel.cmd_socket
original_pub_sock = sentinel.fault_pub_socket
original_ctx = sentinel.zmq_ctx

sentinel.shutdown_sentinel()

assert sentinel.client_sentinel_dead is True

with pytest.raises(zmq.ZMQError):
original_fault_sock.recv()

with pytest.raises(zmq.ZMQError):
original_cmd_sock.recv()

with pytest.raises(zmq.ZMQError):
original_pub_sock.send(b"test")

assert original_ctx.closed


@pytest.mark.asyncio
async def test_handle_fault_async():
engine_exception_q: queue.Queue[FaultInfo] = queue.Queue()
engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"})
sentinel = create_client_sentinel(engine_exception_q, engine_status_dict)

time.sleep(0.1)
ctx = zmq.Context().instance()
cmd_socket = ctx.socket(zmq.DEALER)
cmd_socket.setsockopt(zmq.IDENTITY, b"engine_identity")
cmd_socket.connect(CMD_ADDR)
time.sleep(0.1)

uuid = None

def receive_cmd(cmd_socket):
nonlocal uuid
time.sleep(0.1)

identity, msg = cmd_socket.recv_multipart()
cmd_dict = json.loads(msg.decode("utf-8"))
assert cmd_dict["method"] == "retry"
assert cmd_dict["timeout"] == 3
uuid = cmd_dict["method_uuid"]

def response_cmd(cmd_socket):
nonlocal uuid
while uuid is None:
time.sleep(0.1)
execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid}
cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")])

threading.Thread(target=receive_cmd, args=(cmd_socket,), daemon=True).start()
threading.Thread(target=response_cmd, args=(cmd_socket,), daemon=True).start()

result = await sentinel.handle_fault("retry", 3)

assert result is True
assert engine_status_dict[0] == "Healthy"

cmd_socket.close()
ctx.term()
sentinel.shutdown_sentinel()
130 changes: 130 additions & 0 deletions tests/v1/engine/test_engine_core_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import logging
import queue
import threading
import time

import pytest
import zmq

from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.engine.core import (
EngineCoreSentinel,
EngineLoopPausedError,
)
from vllm.v1.serial_utils import serialize_method_call

CLIENT_CMD_ADDR = "tcp://127.0.0.1:8844"
WORKER_CMD_ADDR = "tcp://127.0.0.1:8845"
FAULT_REPORT_ADDR = "tcp://127.0.0.1:8846"
SENTINEL_IDENTITY = b"engine_sentinel_0"


def create_engine_core_sentinel(
fault_signal_q: queue.Queue, busy_loop_active: threading.Event
):
return EngineCoreSentinel(
engine_index=0,
fault_signal_q=fault_signal_q,
cmd_q=queue.Queue(),
busy_loop_active=busy_loop_active,
engine_input_q=queue.Queue(),
client_cmd_addr=CLIENT_CMD_ADDR,
worker_cmd_addr=WORKER_CMD_ADDR,
fault_report_addr=FAULT_REPORT_ADDR,
sentinel_identity=SENTINEL_IDENTITY,
tp_size=1,
pp_size=1,
dp_size=1,
)


def test_engine_core_sentinel_initialization():
fault_signal_q: queue.Queue = queue.Queue()
busy_loop_active = threading.Event()

sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active)

assert sentinel.engine_index == 0
assert sentinel.tp_size == 1
assert sentinel.pp_size == 1
assert not sentinel.communicator_aborted
assert sentinel.engine_running is True
assert sentinel.daemon is True

assert sentinel.fault_report_socket.type == zmq.DEALER
assert sentinel.client_cmd_socket.type == zmq.DEALER
assert sentinel.worker_cmd_socket.type == zmq.ROUTER

sentinel.shutdown()


@pytest.mark.parametrize("instruction", ["pause", "retry"])
def test_run_handle_instruction(instruction):
fault_signal_q: queue.Queue = queue.Queue()
busy_loop_active = threading.Event()

client_socket = make_zmq_socket(
ctx=zmq.Context(), path=CLIENT_CMD_ADDR, socket_type=zmq.ROUTER, bind=True
)

time.sleep(0.1)

sentinel = create_engine_core_sentinel(fault_signal_q, busy_loop_active)
time.sleep(0.1)

ctx = zmq.Context()
worker_cmd_socket = ctx.socket(zmq.DEALER)
worker_cmd_socket.setsockopt(zmq.IDENTITY, b"0_0")
worker_cmd_socket.connect(WORKER_CMD_ADDR)

def mock_worker_receiver(cmd_socket):
time.sleep(0.1)
logging.info("start worker")
identity, msg = cmd_socket.recv_multipart()
logging.info(identity)
cmd_dict = json.loads(msg.decode("utf-8"))
assert (
cmd_dict["method"] == "pause_by_signal"
if instruction == "pause"
else "retry"
)
response_dict = {"success": True, "method_uuid": cmd_dict["method_uuid"]}
logging.info(identity)
cmd_socket.send_multipart([b"", json.dumps(response_dict).encode("utf-8")])

threading.Thread(target=sentinel.run, daemon=True).start()
time.sleep(0.1)

param = {"timeout": 3}
if instruction == "pause":
param["soft_pause"] = True
elif instruction == "retry":
param["new_stateless_dp_group_port"] = 23456
serial_instruction = serialize_method_call(instruction, **param)
client_socket.send_multipart(
[SENTINEL_IDENTITY, b"", serial_instruction.encode("utf-8")]
)
if instruction == "pause":
fault_signal_q.put(EngineLoopPausedError(Exception("test error")))
elif instruction == "retry":
busy_loop_active.set()

threading.Thread(
target=mock_worker_receiver, args=(worker_cmd_socket,), daemon=True
).start()

time.sleep(0.1)
identity, _, msg = client_socket.recv_multipart()
result_dict = json.loads(msg.decode("utf-8"))
assert result_dict["engine_index"] == 0
assert result_dict["success"]

time.sleep(0.1)

client_socket.close()
worker_cmd_socket.close()
sentinel.shutdown()
3 changes: 3 additions & 0 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.fault_tolerance import FaultToleranceConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
Expand Down Expand Up @@ -86,6 +87,8 @@
"SpeechToTextConfig",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.fault_tolerance
"FaultToleranceConfig",
# From vllm.config.utils
"ConfigType",
"SupportsMetricsInfo",
Expand Down
Loading