12 changes: 11 additions & 1 deletion lmdeploy/serve/openai/api_server.py
@@ -1152,7 +1152,17 @@ async def startup_event():
         engine_config = VariableInterface.async_engine.engine.engine_config
         engine_role = engine_config.role.value if hasattr(engine_config, 'role') else 1
         url = f'{VariableInterface.proxy_url}/nodes/add'
-        data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}}
+        rank = os.environ.get('RANK', -1)
+        if rank == -1:
+            rank = getattr(VariableInterface.async_engine.backend_config, 'dp_rank', -1)
+        data = {
+            'url': VariableInterface.api_server_url,
+            'status': {
+                'models': get_model_list(),
+                'role': engine_role,
+                'rank': rank
+            }
+        }
         headers = {'accept': 'application/json', 'Content-Type': 'application/json'}
         response = requests.post(url, headers=headers, json=data)

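For reference, a minimal sketch of the registration call this hunk produces, assuming a proxy at http://0.0.0.0:8000, an api_server at http://0.0.0.0:23333, and illustrative model/role/rank values (none of these concrete values come from the diff):

```python
# Sketch of the /nodes/add registration with the new 'rank' field.
# URLs, model name, role, and rank below are illustrative placeholders.
import os
import requests

proxy_url = 'http://0.0.0.0:8000'        # proxy address (assumed)
api_server_url = 'http://0.0.0.0:23333'  # this api_server's address (assumed)

# Rank resolution mirrors the hunk above: RANK env var first, dp_rank as fallback.
rank = int(os.environ.get('RANK', -1))
if rank == -1:
    rank = 3  # stand-in for getattr(backend_config, 'dp_rank', -1)

data = {
    'url': api_server_url,
    'status': {
        'models': ['internlm2-chat-7b'],  # stand-in for get_model_list()
        'role': 1,                        # EngineRole value reported by the engine
        'rank': rank,
    },
}
headers = {'accept': 'application/json', 'Content-Type': 'application/json'}
response = requests.post(f'{proxy_url}/nodes/add', headers=headers, json=data)
print(response.status_code)
```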
35 changes: 24 additions & 11 deletions lmdeploy/serve/proxy/proxy.py
@@ -42,6 +42,7 @@ class Status(BaseModel):
     unfinished: int = 0
     latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]])
     speed: Optional[int] = Field(default=None, examples=[None])
+    rank: int = -1


 class Node(BaseModel):
@@ -165,7 +166,7 @@ def add(self, node_url: str, status: Optional[Status] = None):
             status.models = client.available_models
             self.nodes[node_url] = status
         except requests.exceptions.RequestException as e:  # noqa
-            logger.error(f'exception happened when adding node {node_url}, {e}')
+            logger.error(f'exception happened when adding node url={node_url} rank={status.rank}, {e}')
             return self.handle_api_timeout(node_url)
         self.update_config_file()

@@ -227,8 +228,9 @@ def remove_stale_nodes_by_expiration(self):
             except:  # noqa
                 to_be_deleted.append(node_url)
         for node_url in to_be_deleted:
+            status = self.nodes[node_url]
             self.remove(node_url)
-            logger.info(f'Removed node_url: {node_url} '
+            logger.info(f'Removed node_url: {node_url} rank={status.rank} unfinished={status.unfinished} '
                         'due to heart beat expiration')

     @property
@@ -336,7 +338,8 @@ def handle_unavailable_model(self, model_name):

     def handle_api_timeout(self, node_url):
         """Handle the api time out."""
-        logger.warning(f'api timeout: {node_url}')
+        status = node_manager.nodes.get(node_url, Status())
+        logger.warning(f'api timeout: url={node_url} rank={status.rank} unfinished={status.unfinished}')
         ret = {
             'error_code': ErrorCodes.API_TIMEOUT.value,
             'text': err_msg[ErrorCodes.API_TIMEOUT],
@@ -586,8 +589,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
     node_url = node_manager.get_node_url(request.model)
     if not node_url:
         return node_manager.handle_unavailable_model(request.model)
-
-    logger.info(f'A request is dispatched to {node_url}')
+    cur_status = node_manager.nodes.get(node_url, Status())
+    logger.info(
+        f'A request is dispatched to rank={cur_status.rank} url={node_url} unfinished={cur_status.unfinished}')
     request_dict = request.model_dump()
     start = node_manager.pre_call(node_url)
     if request.stream is True:
@@ -611,7 +615,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
         if not p_url:
             return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Prefill request is dispatched to {p_url}')
+        p_status = node_manager.nodes.get(p_url, Status())
+        logger.info(
+            f'A Prefill request is dispatched to rank={p_status.rank} url={p_url} unfinished={p_status.unfinished}')

         start = node_manager.pre_call(p_url)
         prefill_info = json.loads(await node_manager.generate(prefill_request_dict,
@@ -624,7 +630,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
         if not d_url:
             return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Decode request is dispatched to {d_url}')
+        d_status = node_manager.nodes.get(d_url, Status())
+        logger.info(
+            f'A Decode request is dispatched to rank={d_status.rank} url={d_url} unfinished={d_status.unfinished}')

         if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
             await node_manager.pd_connection_pool.connect(
@@ -710,8 +718,9 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
     node_url = node_manager.get_node_url(request.model)
     if not node_url:
         return node_manager.handle_unavailable_model(request.model)
-
-    logger.info(f'A request is dispatched to {node_url}')
+    cur_status = node_manager.nodes.get(node_url, Status())
+    logger.info(
+        f'A request is dispatched to rank={cur_status.rank} url={node_url} unfinished={cur_status.unfinished}')
     request_dict = request.model_dump()
     start = node_manager.pre_call(node_url)
     if request.stream is True:
@@ -735,7 +744,9 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
         p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
         if not p_url:
             return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Prefill request is dispatched to {p_url}')
+        p_status = node_manager.nodes.get(p_url, Status())
+        logger.info(
+            f'A Prefill request is dispatched to rank={p_status.rank} url={p_url} unfinished={p_status.unfinished}')

         start = node_manager.pre_call(p_url)
         prefill_info = json.loads(await node_manager.generate(prefill_request_dict,
@@ -748,7 +759,9 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
         d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
         if not d_url:
             return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Decode request is dispatched to {d_url}')
+        d_status = node_manager.nodes.get(d_url, Status())
+        logger.info(
+            f'A Decode request is dispatched to rank={d_status.rank} url={d_url} unfinished={d_status.unfinished}')

         if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
             await node_manager.pd_connection_pool.connect(
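On the proxy side, the new `rank` field rides on the existing pydantic `Status` model, so nodes that register without it keep the default of -1, and the dispatch and timeout logs use `nodes.get(node_url, Status())` so a node that is missing from the table still yields printable defaults. A small self-contained sketch of that behaviour, with the model trimmed to the fields the new log lines use (everything else here is assumed, not copied from proxy.py):

```python
# Standalone sketch of the proxy-side defaults; Status is trimmed to the fields
# the new log lines use, not the full class from lmdeploy/serve/proxy/proxy.py.
from pydantic import BaseModel


class Status(BaseModel):
    models: list = []
    unfinished: int = 0
    rank: int = -1  # new field from this diff; -1 means "rank unknown"


nodes = {}

# A node that registered with a rank (payload shape from api_server.py above).
nodes['http://0.0.0.0:23333'] = Status(models=['internlm2-chat-7b'], rank=0)

# An older api_server that sends no 'rank' keeps the default of -1.
nodes['http://0.0.0.0:23334'] = Status(models=['internlm2-chat-7b'])

# The lookup pattern used in the new log lines: .get() with a fresh Status(),
# so even a node that was already removed prints rank=-1 / unfinished=0.
for url in ('http://0.0.0.0:23333', 'http://0.0.0.0:23334', 'http://0.0.0.0:9999'):
    status = nodes.get(url, Status())
    print(f'url={url} rank={status.rank} unfinished={status.unfinished}')
```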