Skip to content

Commit 3b203d6

Browse files
committed
Fix DT and zmq socket closing issues, updated names per feedback and reinitialize dp_group with new port
Signed-off-by: fangyuchu <fangyuchu@qq.com>
1 parent e2c1f17 commit 3b203d6

File tree

9 files changed

+129
-122
lines changed

9 files changed

+129
-122
lines changed

tests/v1/engine/test_client_guard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_shutdown_guard():
183183
@pytest.mark.asyncio
184184
async def test_handle_fault_async():
185185
engine_exception_q: asyncio.Queue[FaultInfo] = asyncio.Queue()
186-
engine_status_dict = create_test_thread_safe_dict({1: "Unhealthy"})
186+
engine_status_dict = create_test_thread_safe_dict({0: "Unhealthy"})
187187
guard = create_client_guard(engine_exception_q, engine_status_dict)
188188

189189
time.sleep(0.1)
@@ -208,7 +208,7 @@ def response_cmd(cmd_socket):
208208
nonlocal uuid
209209
while uuid is None:
210210
time.sleep(0.1)
211-
execute_result = {"engine_index": 1, "success": True, "method_uuid": uuid}
211+
execute_result = {"engine_index": 0, "success": True, "method_uuid": uuid}
212212
cmd_socket.send_multipart([b"", json.dumps(execute_result).encode("utf-8")])
213213

214214
threading.Thread(target=receive_cmd, args=(cmd_socket,)).start()
@@ -217,6 +217,6 @@ def response_cmd(cmd_socket):
217217
result = await guard.handle_fault("retry", 3)
218218

219219
assert result is True
220-
assert engine_status_dict[1] == "Healthy"
220+
assert engine_status_dict[0] == "Healthy"
221221

222222
guard.shutdown_guard()

tests/v1/engine/test_engine_core_guard.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def create_engine_core_guard(
3838
guard_identity=GUARD_IDENTITY,
3939
tp_size=1,
4040
pp_size=1,
41+
dp_size=1,
4142
)
4243

4344

@@ -101,6 +102,8 @@ def mock_worker_receiver(cmd_socket):
101102
param = {"timeout": 3}
102103
if instruction == "pause":
103104
param["soft_pause"] = True
105+
elif instruction == "retry":
106+
param["new_stateless_dp_group_port"] = 23456
104107
serial_instruction = serialize_method_call(instruction, **param)
105108
client_socket.send_multipart(
106109
[GUARD_IDENTITY, b"", serial_instruction.encode("utf-8")]

vllm/config/parallel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def get_next_dp_init_port(self) -> int:
339339

340340
def stateless_init_dp_group(
341341
self,
342-
gloo_comm_timeout: int = 30,
343-
enable_fault_tolerance: bool = False,
342+
gloo_comm_timeout: int | None = None,
343+
dp_init_port: int | None = None,
344344
) -> ProcessGroup:
345345
# NOTE: In high-concurrency scenarios multiple processes
346346
# can pick the same (currently free) port through a race
@@ -357,23 +357,25 @@ def stateless_init_dp_group(
357357

358358
max_retries = 5
359359
last_exc: Exception | None = None
360+
if dp_init_port is None:
361+
dp_init_port = self.get_next_dp_init_port()
360362
for _ in range(max_retries):
361363
try:
362364
# use gloo since the engine process might not have cuda device
363365
return stateless_init_torch_distributed_process_group(
364366
self.data_parallel_master_ip,
365-
self.get_next_dp_init_port(),
367+
dp_init_port,
366368
self.data_parallel_rank,
367369
self.data_parallel_size,
368370
backend=current_platform.dist_backend,
369371
gloo_comm_timeout=gloo_comm_timeout,
370-
enable_fault_tolerance=enable_fault_tolerance,
371372
)
372373
except DistNetworkError as e:
373374
# We only want to retry when the root cause is EADDRINUSE.
374375
if "EADDRINUSE" in str(e):
375376
logger.warning("Address already in use. Retrying with a new port.")
376377
last_exc = e
378+
dp_init_port = self.get_next_dp_init_port()
377379
continue # try again with a new port
378380
raise e
379381

vllm/distributed/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,7 @@ def stateless_init_torch_distributed_process_group(
463463
rank: int,
464464
world_size: int,
465465
backend: str,
466-
gloo_comm_timeout: int,
467-
enable_fault_tolerance: bool = False,
466+
gloo_comm_timeout: int | None,
468467
) -> ProcessGroup:
469468
"""
470469
A replacement for `torch.distributed.init_process_group` that does not
@@ -499,10 +498,11 @@ def stateless_init_torch_distributed_process_group(
499498
"""
500499
init_method = get_tcp_uri(host, port)
501500
backend = Backend(backend) # it is basically string
502-
if enable_fault_tolerance:
503-
timeout = timedelta(seconds=gloo_comm_timeout)
504-
else:
501+
502+
if gloo_comm_timeout is None:
505503
timeout = _get_default_timeout(backend)
504+
else:
505+
timeout = timedelta(seconds=gloo_comm_timeout)
506506

507507
store, rank, world_size = next(
508508
rendezvous(init_method, rank, world_size, timeout=timeout)

vllm/v1/core/sched/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def get_grammar_bitmask(
6868
def preempt_request(
6969
self,
7070
scheduled_timestamp: float | None = None,
71-
preempted_req: Request | None = None,
72-
) -> Request:
71+
preempted_req: Optional["Request"] = None,
72+
) -> "Request":
7373
"""
7474
Preempt a running request and move it back to the waiting queue.
7575

vllm/v1/engine/core.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@
4545
from vllm.v1.core.sched.interface import SchedulerInterface
4646
from vllm.v1.core.sched.output import SchedulerOutput
4747
from vllm.v1.engine import (
48-
EngineCoreOutput,
4948
EngineCoreOutputs,
5049
EngineCoreRequest,
5150
EngineCoreRequestType,
52-
FinishReason,
5351
ReconfigureDistributedRequest,
5452
ReconfigureRankType,
5553
UtilityOutput,
@@ -179,13 +177,16 @@ def run(self) -> None:
179177
self.engine_running = False
180178
except queue.Empty:
181179
pass
182-
183-
if self.client_cmd_socket.closed:
184-
self.logger("Client socket closed", level="info")
180+
try:
181+
has_msg, _, cmd_str = recv_router_dealer_message(
182+
self.client_cmd_socket,
183+
use_poller=True,
184+
poll_timeout=poll_timeout_ms,
185+
)
186+
except zmq.ZMQError:
187+
self.logger("Socket closed, terminating EngineCoreGuard", level="info")
185188
break
186-
has_msg, _, cmd_str = recv_router_dealer_message(
187-
self.client_cmd_socket, use_poller=True, poll_timeout=poll_timeout_ms
188-
)
189+
189190
if has_msg:
190191
self.logger("Received cmd: %s", cmd_str, level="info")
191192
self._execute_cmd(cmd_str)
@@ -204,7 +205,7 @@ def _execute_worker_method(self, method_name, timeout: int = 5, **kwargs) -> boo
204205
identities = set()
205206
for tp_rank in range(self.tp_size):
206207
for pp_rank in range(self.pp_size):
207-
identity = f"{tp_rank}_{pp_rank}".encode()
208+
identity = f"{pp_rank}_{tp_rank}".encode()
208209
identities.add(identity)
209210

210211
method_uuid = broadcast_instruction(
@@ -286,26 +287,30 @@ def pause(self, timeout: int = 1, soft_pause: bool = True) -> bool:
286287
success = True
287288
if not soft_pause:
288289
# abort the communicators
289-
self._stop_worker_execution(soft_pause=False, timeout=timeout)
290+
success = self._stop_worker_execution(soft_pause=False, timeout=timeout)
290291
return success
291292

292-
def retry(self, timeout: int = 1):
293+
def retry(self, new_stateless_dp_group_port: int, timeout: int = 1):
293294
"""
294295
Handle the retry instruction from the ClientGuard.
295296
This instruction tells the EngineCore to continue its busy loop
296297
after being suspended due to an exception.
297298
"""
298299
start_time = time.monotonic()
299300

300-
success = self._execute_worker_method("restart_worker", timeout=timeout)
301+
success = self._execute_worker_method("restore_worker", timeout=timeout)
301302
if not success:
302303
return success
303304

304305
if self.dp_size > 1:
305306
# If the Gloo communication times out
306307
# the data parallel group (dp_group) needs to be reinitialized
307308
command = "reinit_dp_group_on_fault_tolerance"
308-
self.cmd_q.put(serialize_method_call(command))
309+
self.cmd_q.put(
310+
serialize_method_call(
311+
command, new_stateless_dp_group_port=new_stateless_dp_group_port
312+
)
313+
)
309314
else:
310315
self.cmd_q.put(None)
311316

@@ -1473,21 +1478,6 @@ def process_output_sockets(
14731478
# Limit the number of buffers to reuse.
14741479
reuse_buffers.append(buffer)
14751480

1476-
def engine_finish_requests(self):
1477-
assert isinstance(self.scheduler, V1Scheduler)
1478-
engine_finish_outputs = EngineCoreOutputs()
1479-
engine_finish_outputs.engine_index = self.engine_index
1480-
for request_id in list(self.scheduler.requests.keys()):
1481-
self.scheduler.finish_requests(request_id, RequestStatus.FINISHED_ABORTED)
1482-
engine_finish_outputs.outputs.append(
1483-
EngineCoreOutput(
1484-
request_id=request_id,
1485-
finish_reason=FinishReason.ABORT,
1486-
new_token_ids=[],
1487-
)
1488-
)
1489-
self.output_queue.put((0, engine_finish_outputs))
1490-
14911481
def shutdown(self):
14921482
super().shutdown()
14931483
if self.vllm_config.fault_tolerance_config.enable_fault_tolerance:
@@ -1549,7 +1539,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
15491539
self.dp_rank = dp_rank
15501540
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group(
15511541
vllm_config.fault_tolerance_config.gloo_comm_timeout,
1552-
vllm_config.fault_tolerance_config.enable_fault_tolerance,
15531542
)
15541543

15551544
def shutdown(self):
@@ -1658,12 +1647,13 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
16581647

16591648
return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)
16601649

1661-
def reinit_dp_group_on_fault_tolerance(self):
1650+
def reinit_dp_group_on_fault_tolerance(self, new_stateless_dp_group_port: int):
16621651
stateless_destroy_torch_distributed_process_group(self.dp_group)
16631652
self.dp_group = self.vllm_config.parallel_config.stateless_init_dp_group(
16641653
self.vllm_config.fault_tolerance_config.gloo_comm_timeout,
1665-
self.vllm_config.fault_tolerance_config.enable_fault_tolerance,
1654+
new_stateless_dp_group_port,
16661655
)
1656+
self.step_counter = 0
16671657

16681658
def reinitialize_distributed(
16691659
self, reconfig_request: ReconfigureDistributedRequest

vllm/v1/engine/core_client.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -431,31 +431,35 @@ def fault_receiver(self):
431431
engine_core component. It is designed to run continuously to ensure no critical
432432
error information from the engine core is missed.
433433
"""
434-
while True:
435-
_, sender_identity, message = recv_router_dealer_message(
436-
self.fault_receiver_socket
437-
)
438-
if self.client_guard_dead:
439-
self.logger("client guard dead, stop receiving fault")
440-
break
441-
assert message is not None, (
442-
"message should not be None at fault tolerance scenario"
443-
)
434+
while not self.client_guard_dead:
435+
try:
436+
_, sender_identity, message = recv_router_dealer_message(
437+
self.fault_receiver_socket
438+
)
439+
assert message is not None, (
440+
"message should not be None at fault tolerance scenario"
441+
)
444442

445-
fault_info = FaultInfo.from_json(message)
446-
self.engine_exception_q.put_nowait(fault_info)
447-
engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy"
448-
self.engine_status_dict[int(fault_info.engine_id)] = engine_status
449-
self.fault_pub_socket.send_string(
450-
f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}"
451-
)
452-
# TODO Asynchronous issuance of pause commands and design of engine
453-
# core status
454-
# Pause healthy engines on fault.
455-
# Pause will be invoked again during fault-tolerance handling,
456-
# so it's unnecessary to track whether all engines are currently
457-
# paused.
458-
self.fault_handler.submit_fault("pause", 5, soft_pause=False)
443+
fault_info = FaultInfo.from_json(message)
444+
self.engine_exception_q.put_nowait(fault_info)
445+
engine_status = "Dead" if "dead" in fault_info.type else "Unhealthy"
446+
self.engine_status_dict[int(fault_info.engine_id)] = engine_status
447+
self.fault_pub_socket.send_string(
448+
f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}"
449+
)
450+
451+
# Pause healthy engines on fault.
452+
# Pause will be invoked again during fault-tolerance handling,
453+
# so it's unnecessary to track whether all engines are currently
454+
# paused.
455+
self.fault_handler.submit_fault("pause", 5, soft_pause=False)
456+
except zmq.ZMQError:
457+
# Socket was closed during polling, exit loop.
458+
self.logger(
459+
"Fault receiver socket closed, stopping thread.", level="info"
460+
)
461+
break
462+
self.logger("Fault receiver thread has stopped.")
459463

460464
def shutdown_guard(self):
461465
self.client_guard_dead = True

vllm/v1/engine/utils.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from vllm.ray.ray_env import get_env_vars_to_copy
2828
from vllm.utils.collection_utils import ThreadSafeDict
2929
from vllm.utils.network_utils import (
30+
get_open_port,
3031
get_open_zmq_ipc_path,
3132
make_zmq_socket,
3233
recv_router_dealer_message,
@@ -36,7 +37,7 @@
3637
from vllm.v1.engine.coordinator import DPCoordinator
3738
from vllm.v1.engine.exceptions import FaultInfo
3839
from vllm.v1.executor import Executor
39-
from vllm.v1.serial_utils import serialize_method_call
40+
from vllm.v1.serial_utils import run_method, serialize_method_call
4041
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
4142

4243
if TYPE_CHECKING:
@@ -1399,40 +1400,44 @@ async def _dispatcher(self):
13991400
if fut:
14001401
fut.set_exception(e)
14011402

1402-
async def _handle_fault_internal(
1403-
self, instruction: str, timeout: int, **kwargs
1404-
) -> bool:
1405-
if instruction == "retry" and "Dead" in self.engine_status_dict.values():
1403+
def retry(self, **kwargs):
1404+
if "Dead" in self.engine_status_dict.values():
14061405
self.logger(
14071406
"engine_core dead unexpectedly, retry is impossible,"
14081407
"shutdown will be performed",
14091408
level="info",
14101409
)
1411-
return False
1410+
return False, set(), kwargs
1411+
1412+
target_engines = set(self.engine_identity_to_index.keys())
1413+
kwargs["new_stateless_dp_group_port"] = get_open_port()
1414+
return True, target_engines, kwargs
1415+
1416+
def pause(self, **kwargs):
1417+
self.logger(
1418+
"Pause operation is best-effort only. Due to the complexity of "
1419+
"collective communications (e.g., timing dependencies and "
1420+
"synchronization barriers), pausing may not always succeed. If "
1421+
"the process remains unresponsive or collective operations "
1422+
"cannot be interrupted, consider shutting down and restarting "
1423+
"the instance.",
1424+
level="warning",
1425+
)
14121426

1413-
if instruction == "pause":
1414-
logger.warning(
1415-
"Pause operation is best-effort only. Due to the complexity of "
1416-
"collective communications (e.g., timing dependencies and "
1417-
"synchronization barriers), pausing may not always succeed. If "
1418-
"the process remains unresponsive or collective operations "
1419-
"cannot be interrupted, consider shutting down and restarting "
1420-
"the instance."
1421-
)
1427+
alive_engines = {
1428+
identity
1429+
for identity, index in self.engine_identity_to_index.items()
1430+
if self.engine_status_dict.get(index) != "Dead"
1431+
}
1432+
return True, alive_engines, kwargs
14221433

1423-
dead_engine_indices = {
1424-
index
1425-
for index, status in self.engine_status_dict.items()
1426-
if status == "Dead"
1427-
}
1428-
1429-
target_engines = {
1430-
identity
1431-
for identity, index in self.engine_identity_to_index.items()
1432-
if index not in dead_engine_indices
1433-
}
1434-
else:
1435-
target_engines = set(self.engine_identity_to_index.keys())
1434+
async def _handle_fault_internal(
1435+
self, instruction: str, timeout: int, **kwargs
1436+
) -> bool:
1437+
success, target_engines, kwargs = run_method(self, instruction, (), kwargs)
1438+
1439+
if not success:
1440+
return False
14361441

14371442
if timeout is not None:
14381443
kwargs["timeout"] = timeout

0 commit comments

Comments
 (0)