Skip to content

Commit 7194000

Browse files
committed
Fix DT and zmq socket closing issues, update names per feedback, and reinitialize dp_group with a new port
Signed-off-by: fangyuchu <fangyuchu@qq.com>
1 parent e875ad3 commit 7194000

File tree

9 files changed

+129
-122
lines changed

9 files changed

+129
-122
lines changed

tests/v1/engine/test_client_guard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_shutdown_guard():
183183
@pytest.mark.asyncio
184184
async def test_handle_fault_async():
185185
engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue()
186-
engine_status_dict = create_test_thread_safe_dict({1: "Unhealthy"})
186+
engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"})
187187
guard = create_client_guard(engine_exception_q, engine_status_dict)
188188

189189
time.sleep(0.1)
@@ -208,7 +208,7 @@ def response_cmd(cmd_socket):
208208
nonlocal uuid
209209
while uuid is None:
210210
time.sleep(0.1)
211-
execute_result = {"engine_index": 1, "success": True, "method_uuid": uuid}
211+
execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid}
212212
cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")])
213213

214214
threading.Thread(target=receive_cmd, args=(cmd_socket,)).start()
@@ -217,6 +217,6 @@ def response_cmd(cmd_socket):
217217
result = await guard.handle_fault("retry", 3)
218218

219219
assert result is True
220-
assert engine_status_dict[1] == "Healthy"
220+
assert engine_status_dict[0] == "Healthy"
221221

222222
guard.shutdown_guard()

tests/v1/engine/test_engine_core_guard.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def create_engine_core_guard(
3838
guard_identity=GUARD_IDENTITY,
3939
tp_size=1,
4040
pp_size=1,
41+
dp_size=1,
4142
)
4243

4344

@@ -101,6 +102,8 @@ def mock_worker_receiver(cmd_socket):
101102
param = {"timeout": 3}
102103
if instruction == "pause":
103104
param["soft_pause"] = True
105+
elif instruction == "retry":
106+
param["new_stateless_dp_group_port"] = 23456
104107
serial_instruction = serialize_method_call(instruction, **param)
105108
client_socket.send_multipart(
106109
[GUARD_IDENTITY, b"", serial_instruction.encode("utf-8")]

vllm/config/parallel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,8 @@ def get_next_dp_init_port(self) -> int:
327327

328328
def stateless_init_dp_group(
329329
self,
330-
gloo_comm_timeout: int = 30,
331-
enable_fault_tolerance: bool = False,
330+
gloo_comm_timeout: int | None = None,
331+
dp_init_port: int | None = None,
332332
) -> ProcessGroup:
333333
# NOTE: In high-concurrency scenarios multiple processes
334334
# can pick the same (currently free) port through a race
@@ -345,23 +345,25 @@ def stateless_init_dp_group(
345345

346346
max_retries = 5
347347
last_exc: Exception | None = None
348+
if dp_init_port is None:
349+
dp_init_port = self.get_next_dp_init_port()
348350
for _ in range(max_retries):
349351
try:
350352
# use gloo since the engine process might not have cuda device
351353
return stateless_init_torch_distributed_process_group(
352354
self.data_parallel_master_ip,
353-
self.get_next_dp_init_port(),
355+
dp_init_port,
354356
self.data_parallel_rank,
355357
self.data_parallel_size,
356358
backend=current_platform.dist_backend,
357359
gloo_comm_timeout=gloo_comm_timeout,
358-
enable_fault_tolerance=enable_fault_tolerance,
359360
)
360361
except DistNetworkError as e:
361362
# We only want to retry when the root cause is EADDRINUSE.
362363
if "EADDRINUSE" in str(e):
363364
logger.warning("Address already in use. Retrying with a new port.")
364365
last_exc = e
366+
dp_init_port = self.get_next_dp_init_port()
365367
continue # try again with a new port
366368
raise e
367369

vllm/distributed/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,7 @@ def stateless_init_torch_distributed_process_group(
463463
rank: int,
464464
world_size: int,
465465
backend: str,
466-
gloo_comm_timeout: int,
467-
enable_fault_tolerance: bool = False,
466+
gloo_comm_timeout: int | None,
468467
) -> ProcessGroup:
469468
"""
470469
A replacement for `torch.distributed.init_process_group` that does not
@@ -499,10 +498,11 @@ def stateless_init_torch_distributed_process_group(
499498
"""
500499
init_method = get_tcp_uri(host, port)
501500
backend = Backend(backend) # it is basically string
502-
if enable_fault_tolerance:
503-
timeout = timedelta(seconds=gloo_comm_timeout)
504-
else:
501+
502+
if gloo_comm_timeout is None:
505503
timeout = _get_default_timeout(backend)
504+
else:
505+
timeout = timedelta(seconds=gloo_comm_timeout)
506506

507507
store, rank, world_size = next(
508508
rendezvous(init_method, rank, world_size, timeout=timeout)

vllm/v1/core/sched/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def get_grammar_bitmask(
6868
def preempt_request(
6969
self,
7070
scheduled_timestamp: float | None = None,
71-
preempted_req: Request | None = None,
72-
) -> Request:
71+
preempted_req: Optional["Request"] = None,
72+
) -> "Request":
7373
"""
7474
Preempt a running request and move it back to the waiting queue.
7575

vllm/v1/engine/core.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@
4545
from vllm.v1.core.sched.interface import SchedulerInterface
4646
from vllm.v1.core.sched.output import SchedulerOutput
4747
from vllm.v1.engine import (
48-
EngineCoreOutput,
4948
EngineCoreOutputs,
5049
EngineCoreRequest,
5150
EngineCoreRequestType,
52-
FinishReason,
5351
ReconfigureDistributedRequest,
5452
ReconfigureRankType,
5553
UtilityOutput,
@@ -180,13 +178,16 @@ def run(self) -> None:
180178
self.engine_running = False
181179
except queue.Empty:
182180
pass
183-
184-
if self.client_cmd_socket.closed:
185-
self.logger("Client socket closed", level="info")
181+
try:
182+
has_msg, _, cmd_str = recv_router_dealer_message(
183+
self.client_cmd_socket,
184+
use_poller=True,
185+
poll_timeout=poll_timeout_ms,
186+
)
187+
except zmq.ZMQError:
188+
self.logger("Socket closed, terminating EngineCoreGuard", level="info")
186189
break
187-
has_msg, _, cmd_str = recv_router_dealer_message(
188-
self.client_cmd_socket, use_poller=True, poll_timeout=poll_timeout_ms
189-
)
190+
190191
if has_msg:
191192
self.logger("Received cmd: %s", cmd_str, level="info")
192193
self._execute_cmd(cmd_str)
@@ -205,7 +206,7 @@ def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> boo
205206
identities = set()
206207
for tp_rank in range(self.tp_size):
207208
for pp_rank in range(self.pp_size):
208-
identity = f"{tp_rank}_{pp_rank}".encode()
209+
identity = f"{pp_rank}_{tp_rank}".encode()
209210
identities.add(identity)
210211

211212
method_uuid = broadcast_instruction(
@@ -287,26 +288,30 @@ def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool:
287288
success = True
288289
if not soft_pause:
289290
# abort the communicators
290-
self._stop_worker_execution(soft_pause=False, timeout=timeout)
291+
success = self._stop_worker_execution(soft_pause=False, timeout=timeout)
291292
return success
292293

293-
def retry(self, timeout: int = 1):
294+
def retry(self, new_stateless_dp_group_port: int, timeout: int = 1):
294295
"""
295296
Handle the retry instruction from the ClientGuard.
296297
This instruction tells the EngineCore to continue its busy loop
297298
after being suspended due to an exception.
298299
"""
299300
start_time = time.monotonic()
300301

301-
success = self._execute_worker_method("restart_worker", timeout=timeout)
302+
success = self._execute_worker_method("restore_worker", timeout=timeout)
302303
if not success:
303304
return success
304305

305306
if self.dp_size > 1:
306307
# If the Gloo communication times out
307308
# the data parallel group (dp_group) needs to be reinitialized
308309
command = "reinit_dp_group_on_fault_tolerance"
309-
self.cmd_q.put(serialize_method_call(command))
310+
self.cmd_q.put(
311+
serialize_method_call(
312+
command, new_stateless_dp_group_port=new_stateless_dp_group_port
313+
)
314+
)
310315
else:
311316
self.cmd_q.put(None)
312317

@@ -1486,21 +1491,6 @@ def process_output_sockets(
14861491
# Limit the number of buffers to reuse.
14871492
reuse_buffers.append(buffer)
14881493

1489-
def engine_finish_requests(self):
1490-
assert isinstance(self.scheduler, V1Scheduler)
1491-
engine_finish_outputs = EngineCoreOutputs()
1492-
engine_finish_outputs.engine_index = self.engine_index
1493-
for request_id in list(self.scheduler.requests.keys()):
1494-
self.scheduler.finish_requests(request_id, RequestStatus.FINISHED_ABORTED)
1495-
engine_finish_outputs.outputs.append(
1496-
EngineCoreOutput(
1497-
request_id=request_id,
1498-
finish_reason=FinishReason.ABORT,
1499-
new_token_ids=[],
1500-
)
1501-
)
1502-
self.output_queue.put((0, engine_finish_outputs))
1503-
15041494
def shutdown(self):
15051495
super().shutdown()
15061496
if self.vllm_config.fault_tolerance_config.enable_fault_tolerance:
@@ -1562,7 +1552,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
15621552
self.dp_rank = dp_rank
15631553
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group(
15641554
vllm_config.fault_tolerance_config.gloo_comm_timeout,
1565-
vllm_config.fault_tolerance_config.enable_fault_tolerance,
15661555
)
15671556

15681557
def shutdown(self):
@@ -1671,12 +1660,13 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
16711660

16721661
return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)
16731662

1674-
def reinit_dp_group_on_fault_tolerance(self):
1663+
def reinit_dp_group_on_fault_tolerance(self, new_stateless_dp_group_port: int):
16751664
stateless_destroy_torch_distributed_process_group(self.dp_group)
16761665
self.dp_group = self.vllm_config.parallel_config.stateless_init_dp_group(
16771666
self.vllm_config.fault_tolerance_config.gloo_comm_timeout,
1678-
self.vllm_config.fault_tolerance_config.enable_fault_tolerance,
1667+
new_stateless_dp_group_port,
16791668
)
1669+
self.step_counter = 0
16801670

16811671
def reinitialize_distributed(
16821672
self, reconfig_request: ReconfigureDistributedRequest

vllm/v1/engine/core_client.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -431,31 +431,35 @@ def fault_receiver(self):
431431
engine_core component. It is designed to run continuously to ensure no critical
432432
error information from the engine core is missed.
433433
"""
434-
while True:
435-
_, sender_identity, message = recv_router_dealer_message(
436-
self.fault_receiver_socket
437-
)
438-
if self.client_guard_dead:
439-
self.logger("client guard dead, stop receiving fault")
440-
break
441-
assert message is not None, (
442-
"message should not be None at fault tolerance scenario"
443-
)
434+
while not self.client_guard_dead:
435+
try:
436+
_, sender_identity, message = recv_router_dealer_message(
437+
self.fault_receiver_socket
438+
)
439+
assert message is not None, (
440+
"message should not be None at fault tolerance scenario"
441+
)
444442

445-
fault_info = FaultInfo.from_json(message)
446-
self.engine_exception_q.put_nowait(fault_info)
447-
engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy"
448-
self.engine_status_dict[int(fault_info.engine_id)] = engine_status
449-
self.fault_pub_socket.send_string(
450-
f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}"
451-
)
452-
# TODO Asynchronous issuance of pause commands and design of engine
453-
# core status
454-
# Pause healthy engines on fault.
455-
# Pause will be invoked again during fault-tolerance handling,
456-
# so it's unnecessary to track whether all engines are currently
457-
# paused.
458-
self.fault_handler.submit_fault("pause", 5, soft_pause=False)
443+
fault_info = FaultInfo.from_json(message)
444+
self.engine_exception_q.put_nowait(fault_info)
445+
engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy"
446+
self.engine_status_dict[int(fault_info.engine_id)] = engine_status
447+
self.fault_pub_socket.send_string(
448+
f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}"
449+
)
450+
451+
# Pause healthy engines on fault.
452+
# Pause will be invoked again during fault-tolerance handling,
453+
# so it's unnecessary to track whether all engines are currently
454+
# paused.
455+
self.fault_handler.submit_fault("pause", 5, soft_pause=False)
456+
except zmq.ZMQError:
457+
# Socket was closed during polling, exit loop.
458+
self.logger(
459+
"Fault receiver socket closed, stopping thread.", level="info"
460+
)
461+
break
462+
self.logger("Fault receiver thread has stopped.")
459463

460464
def shutdown_guard(self):
461465
self.client_guard_dead = True

vllm/v1/engine/utils.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from vllm.ray.ray_env import get_env_vars_to_copy
2828
from vllm.utils.collection_utils import ThreadSafeDict
2929
from vllm.utils.network_utils import (
30+
get_open_port,
3031
get_open_zmq_ipc_path,
3132
make_zmq_socket,
3233
recv_router_dealer_message,
@@ -36,7 +37,7 @@
3637
from vllm.v1.engine.coordinator import DPCoordinator
3738
from vllm.v1.engine.exceptions import FaultInfo
3839
from vllm.v1.executor import Executor
39-
from vllm.v1.serial_utils import serialize_method_call
40+
from vllm.v1.serial_utils import run_method, serialize_method_call
4041
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
4142

4243
if TYPE_CHECKING:
@@ -1390,40 +1391,44 @@ async def _dispatcher(self):
13901391
if fut:
13911392
fut.set_exception(e)
13921393

1393-
async def _handle_fault_internal(
1394-
self, instruction: str, timeout: int, **kwargs
1395-
) -> bool:
1396-
if instruction == "retry" and "Dead" in self.engine_status_dict.values():
1394+
def retry(self, **kwargs):
1395+
if "Dead" in self.engine_status_dict.values():
13971396
self.logger(
13981397
"engine_core dead unexpectedly, retry is impossible,"
13991398
"shutdown will be performed",
14001399
level="info",
14011400
)
1402-
return False
1401+
return False, set(), kwargs
1402+
1403+
target_engines = set(self.engine_identity_to_index.keys())
1404+
kwargs["new_stateless_dp_group_port"] = get_open_port()
1405+
return True, target_engines, kwargs
1406+
1407+
def pause(self, **kwargs):
1408+
self.logger(
1409+
"Pause operation is best-effort only. Due to the complexity of "
1410+
"collective communications (e.g., timing dependencies and "
1411+
"synchronization barriers), pausing may not always succeed. If "
1412+
"the process remains unresponsive or collective operations "
1413+
"cannot be interrupted, consider shutting down and restarting "
1414+
"the instance.",
1415+
level="warning",
1416+
)
14031417

1404-
if instruction == "pause":
1405-
logger.warning(
1406-
"Pause operation is best-effort only. Due to the complexity of "
1407-
"collective communications (e.g., timing dependencies and "
1408-
"synchronization barriers), pausing may not always succeed. If "
1409-
"the process remains unresponsive or collective operations "
1410-
"cannot be interrupted, consider shutting down and restarting "
1411-
"the instance."
1412-
)
1418+
alive_engines = {
1419+
identity
1420+
for identity, index in self.engine_identity_to_index.items()
1421+
if self.engine_status_dict.get(index) != "Dead"
1422+
}
1423+
return True, alive_engines, kwargs
14131424

1414-
dead_engine_indices = {
1415-
index
1416-
for index, status in self.engine_status_dict.items()
1417-
if status == "Dead"
1418-
}
1419-
1420-
target_engines = {
1421-
identity
1422-
for identity, index in self.engine_identity_to_index.items()
1423-
if index not in dead_engine_indices
1424-
}
1425-
else:
1426-
target_engines = set(self.engine_identity_to_index.keys())
1425+
async def _handle_fault_internal(
1426+
self, instruction: str, timeout: int, **kwargs
1427+
) -> bool:
1428+
success, target_engines, kwargs = run_method(self, instruction, (), kwargs)
1429+
1430+
if not success:
1431+
return False
14271432

14281433
if timeout is not None:
14291434
kwargs["timeout"] = timeout

0 commit comments

Comments
 (0)