Skip to content

Commit fae0e75

Browse files
committed
Reject requests when the engine is in a faulted state
Signed-off-by: fangyuchu <fangyuchu@qq.com>
1 parent 9485d2b commit fae0e75

File tree

5 files changed

+44
-22
lines changed

5 files changed

+44
-22
lines changed

vllm/engine/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ async def handle_fault(
171171
"""send fault tolerance instruction to the engine"""
172172
raise NotImplementedError
173173

async def get_fault_info(self):
    """Report fault/status information gathered from the engine core.

    Abstract: concrete engine clients must override this to return a
    JSON-serializable mapping describing each engine's fault state.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
177177

vllm/entrypoints/api_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ async def process_fault_tolerance_instruction(request: Request) -> Response:
9595
async def get_fault_info() -> Response:
    """Return the engine's fault/status information as a JSON response.

    Queries the engine client for its per-engine status dictionary and
    serializes it; always responds with HTTP 200 (the payload itself
    conveys whether any engine is faulted).
    """
    # NOTE: previous docstring said "Health check." — this endpoint reports
    # fault info, not liveness.
    assert engine is not None
    engine_status_dict = await engine.get_fault_info()
    return Response(json.dumps(engine_status_dict), status_code=200)
100100

101101

102102
@with_cancellation

vllm/entrypoints/openai/api_server.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,20 @@ async def validate_json_request(raw_request: Request):
263263
)
264264

265265

266+
async def check_engine_fault(raw_request: Request):
    """FastAPI dependency that rejects requests while the engine is faulted.

    Inspects the engine core client's sentinel; if its `is_faulted` event is
    set, the request is refused with HTTP 503 before reaching the handler.
    """
    client = engine_client(raw_request)
    assert hasattr(client, "engine_core")
    core_client = client.engine_core
    has_sentinel = hasattr(core_client, "client_sentinel")
    if has_sentinel and core_client.client_sentinel.is_faulted.is_set():
        raise HTTPException(
            status_code=503,
            detail="Service is in faulted state, cannot process requests.",
        )
279+
266280
router = APIRouter()
267281

268282

@@ -395,7 +409,7 @@ async def get_server_load_metrics(request: Request):
395409

396410
@router.post(
397411
"/tokenize",
398-
dependencies=[Depends(validate_json_request)],
412+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
399413
responses={
400414
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
401415
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
@@ -430,7 +444,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
430444

431445
@router.post(
432446
"/detokenize",
433-
dependencies=[Depends(validate_json_request)],
447+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
434448
responses={
435449
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
436450
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
@@ -505,7 +519,7 @@ async def _convert_stream_to_sse_events(
505519

506520
@router.post(
507521
"/v1/responses",
508-
dependencies=[Depends(validate_json_request)],
522+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
509523
responses={
510524
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
511525
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
@@ -598,7 +612,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
598612

599613
@router.post(
600614
"/v1/messages",
601-
dependencies=[Depends(validate_json_request)],
615+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
602616
responses={
603617
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
604618
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
@@ -654,7 +668,7 @@ def translate_error_response(response: ErrorResponse) -> JSONResponse:
654668

655669
@router.post(
656670
"/v1/chat/completions",
657-
dependencies=[Depends(validate_json_request)],
671+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
658672
responses={
659673
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
660674
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
@@ -695,7 +709,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
695709

696710
@router.post(
697711
"/v1/completions",
698-
dependencies=[Depends(validate_json_request)],
712+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
699713
responses={
700714
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
701715
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
@@ -741,7 +755,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
741755

742756
@router.post(
743757
"/v1/embeddings",
744-
dependencies=[Depends(validate_json_request)],
758+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
745759
responses={
746760
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
747761
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -784,7 +798,7 @@ async def create_embedding(
784798

785799
@router.post(
786800
"/pooling",
787-
dependencies=[Depends(validate_json_request)],
801+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
788802
responses={
789803
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
790804
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -820,7 +834,10 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
820834
assert_never(generator)
821835

822836

823-
@router.post("/classify", dependencies=[Depends(validate_json_request)])
837+
@router.post(
838+
"/classify",
839+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
840+
)
824841
@with_cancellation
825842
@load_aware_call
826843
async def create_classify(request: ClassificationRequest, raw_request: Request):
@@ -849,7 +866,7 @@ async def create_classify(request: ClassificationRequest, raw_request: Request):
849866

850867
@router.post(
851868
"/score",
852-
dependencies=[Depends(validate_json_request)],
869+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
853870
responses={
854871
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
855872
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -882,7 +899,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
882899

883900
@router.post(
884901
"/v1/score",
885-
dependencies=[Depends(validate_json_request)],
902+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
886903
responses={
887904
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
888905
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -979,7 +996,7 @@ async def create_translations(
979996

980997
@router.post(
981998
"/rerank",
982-
dependencies=[Depends(validate_json_request)],
999+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
9831000
responses={
9841001
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
9851002
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -1011,7 +1028,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
10111028

10121029
@router.post(
10131030
"/v1/rerank",
1014-
dependencies=[Depends(validate_json_request)],
1031+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
10151032
responses={
10161033
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
10171034
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
@@ -1030,7 +1047,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
10301047

10311048
@router.post(
10321049
"/v2/rerank",
1033-
dependencies=[Depends(validate_json_request)],
1050+
dependencies=[Depends(validate_json_request), Depends(check_engine_fault)],
10341051
responses={
10351052
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
10361053
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
async def get_fault_info(
    raw_request: Request,
):
    """Return the engine's fault/status information as a JSON response.

    Fetches the per-engine status dictionary from the engine client and
    returns it directly; clients inspect the payload to determine fault
    state.
    """
    client = engine_client(raw_request)
    engine_status_dict = await client.get_fault_info()
    return JSONResponse(content=engine_status_dict)
13081325

13091326

13101327
# NOTE: Construct the TypeAdapters only once

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ async def handle_fault(
785785
"""send fault tolerance instruction to the engine"""
786786
return await self.engine_core.handle_fault(instruction, timeout, **kwargs)
787787

788-
async def get_fault_info(self):
    """Fetch fault/status information reported by the engine core."""
    # Delegates straight to the core client's fault reporter.
    reporter = self.engine_core.fault_reporter
    return await reporter()
791791

vllm/v1/engine/core_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import multiprocessing
77
import queue
88
import sys
9+
import threading
910
import time
1011
import uuid
1112
import weakref
@@ -355,6 +356,7 @@ def __init__(
355356
fault_pub_addr: str,
356357
engine_status_dict: ThreadSafeDict[int, str],
357358
):
359+
self.is_faulted = threading.Event()
358360
self.engine_registry = engine_registry
359361
self.zmq_ctx = zmq.Context()
360362
self.fault_receiver_socket = make_zmq_socket(
@@ -411,12 +413,15 @@ async def handle_fault(self, instruction: str, timeout: int, **kwargs) -> bool:
411413
to handle system anomalies, ensuring stable operation or graceful degradation
412414
of the relevant components.
413415
"""
414-
return await run_method(
416+
result = await run_method(
415417
self.fault_handler,
416418
"handle_fault",
417419
args=(instruction, timeout),
418420
kwargs=kwargs,
419421
)
422+
if result:
423+
self.is_faulted.clear()
424+
return result
420425

421426
def fault_receiver(self):
422427
"""
@@ -443,7 +448,7 @@ def fault_receiver(self):
443448
self.fault_pub_socket.send_string(
444449
f"vllm_fault|{json.dumps(self.engine_status_dict.to_dict())}"
445450
)
446-
451+
self.is_faulted.set()
447452
# Pause healthy engines on fault.
448453
# Pause can be invoked again during fault-tolerance handling,
449454
# so it's unnecessary to track whether all engines are currently

0 commit comments

Comments
 (0)