11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4- import asyncio
54import json
5+ import queue
66import threading
77import time
88from unittest .mock import AsyncMock
@@ -32,21 +32,20 @@ def create_test_thread_safe_dict(initial_data=None):
3232
3333
def create_client_guard(
    engine_exception_q: queue.Queue, engine_status_dict: ThreadSafeDict[int, str]
):
    """Build a ``ClientGuard`` wired to this module's test endpoints.

    Args:
        engine_exception_q: Standard-library ``queue.Queue`` handed to the
            guard as its engine-exception queue.
        engine_status_dict: Engine-index -> status-string mapping that is
            passed through to the guard unchanged.

    Returns:
        A ``ClientGuard`` constructed with the module-level
        ``FAULT_RECEIVER_ADDR``, ``CMD_ADDR`` and ``FAULT_PUB_ADDR``
        addresses and a single registered engine identity
        (``b"engine_identity"``).
    """
    # No ``engine_exception_q_lock`` is passed: the revision that switched the
    # queue type from asyncio.Queue to queue.Queue also removed that argument.
    return ClientGuard(
        fault_receiver_addr=FAULT_RECEIVER_ADDR,
        cmd_addr=CMD_ADDR,
        engine_registry=[b"engine_identity"],
        engine_exception_q=engine_exception_q,
        fault_pub_addr=FAULT_PUB_ADDR,
        engine_status_dict=engine_status_dict,
    )
4645
4746
4847def test_client_guard_initialization ():
49- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
48+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
5049 engine_status_dict = create_test_thread_safe_dict ({1 : "Healthy" })
5150 guard = create_client_guard (engine_exception_q , engine_status_dict )
5251
@@ -64,7 +63,7 @@ def test_client_guard_initialization():
6463
6564@pytest .mark .asyncio
6665async def test_handle_fault ():
67- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
66+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
6867 engine_status_dict = create_test_thread_safe_dict ({1 : "Healthy" })
6968 guard = create_client_guard (engine_exception_q , engine_status_dict )
7069
@@ -82,7 +81,7 @@ async def test_handle_fault():
8281
8382
8483def test_fault_receiver ():
85- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
84+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
8685 engine_status_dict = create_test_thread_safe_dict ({1 : "Healthy" })
8786 guard = create_client_guard (engine_exception_q , engine_status_dict )
8887
@@ -97,7 +96,7 @@ def send_test_message():
9796 socket .close ()
9897 ctx .term ()
9998
100- sender_thread = threading .Thread (target = send_test_message )
99+ sender_thread = threading .Thread (target = send_test_message , daemon = True )
101100 sender_thread .start ()
102101
103102 def check_published_message ():
@@ -114,7 +113,7 @@ def check_published_message():
114113 assert prefix == FAULT_PUB_TOPIC
115114 assert json .loads (data ) == {"1" : "Dead" }
116115
117- check_thread = threading .Thread (target = check_published_message )
116+ check_thread = threading .Thread (target = check_published_message , daemon = True )
118117 check_thread .start ()
119118
120119 time .sleep (0.1 )
@@ -130,7 +129,7 @@ def check_published_message():
130129
131130
132131def test_fault_receiver_unhealthy ():
133- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
132+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
134133 engine_status_dict = create_test_thread_safe_dict ({1 : "Healthy" })
135134 guard = create_client_guard (engine_exception_q , engine_status_dict )
136135
@@ -145,7 +144,7 @@ def send_unhealthy_message():
145144 socket .close ()
146145 ctx .term ()
147146
148- threading .Thread (target = send_unhealthy_message ).start ()
147+ threading .Thread (target = send_unhealthy_message , daemon = True ).start ()
149148 time .sleep (0.1 )
150149
151150 assert engine_status_dict [1 ] == "Unhealthy"
@@ -154,7 +153,7 @@ def send_unhealthy_message():
154153
155154
156155def test_shutdown_guard ():
157- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
156+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
158157 engine_status_dict = create_test_thread_safe_dict ({1 : "Healthy" })
159158 guard = create_client_guard (engine_exception_q , engine_status_dict )
160159
@@ -181,7 +180,7 @@ def test_shutdown_guard():
181180
182181@pytest .mark .asyncio
183182async def test_handle_fault_async ():
184- engine_exception_q : asyncio .Queue [FaultInfo ] = asyncio .Queue ()
183+ engine_exception_q : queue .Queue [FaultInfo ] = queue .Queue ()
185184 engine_status_dict = create_test_thread_safe_dict ({0 : "Unhealthy" })
186185 guard = create_client_guard (engine_exception_q , engine_status_dict )
187186
@@ -190,6 +189,7 @@ async def test_handle_fault_async():
190189 cmd_socket = ctx .socket (zmq .DEALER )
191190 cmd_socket .setsockopt (zmq .IDENTITY , b"engine_identity" )
192191 cmd_socket .connect (CMD_ADDR )
192+ time .sleep (0.1 )
193193
194194 uuid = None
195195
@@ -210,13 +210,14 @@ def response_cmd(cmd_socket):
210210 execute_result = {"engine_index" : 0 , "success" : True , "method_uuid" : uuid }
211211 cmd_socket .send_multipart ([b"" , json .dumps (execute_result ).encode ("utf-8" )])
212212
213- threading .Thread (target = receive_cmd , args = (cmd_socket ,)).start ()
214- threading .Thread (target = response_cmd , args = (cmd_socket ,)).start ()
213+ threading .Thread (target = receive_cmd , args = (cmd_socket ,), daemon = True ).start ()
214+ threading .Thread (target = response_cmd , args = (cmd_socket ,), daemon = True ).start ()
215215
216216 result = await guard .handle_fault ("retry" , 3 )
217217
218218 assert result is True
219219 assert engine_status_dict [0 ] == "Healthy"
220220
221221 cmd_socket .close ()
222+ ctx .term ()
222223 guard .shutdown_guard ()
0 commit comments