34 changes: 33 additions & 1 deletion lmdeploy/pytorch/engine/mp_engine/ray_engine.py
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+from collections import defaultdict
 from typing import Dict
 
 import ray
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-from lmdeploy.messages import PytorchEngineConfig
+from lmdeploy.messages import EngineOutput, PytorchEngineConfig
 from lmdeploy.pytorch import envs as _envs
 from lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs
 from lmdeploy.utils import get_logger
@@ -35,6 +36,7 @@ def __init__(self,
         self._stream_id = 0
         self._stream_aiter = dict()
         self._stream_task = dict()
+        self._engine_output_offset = defaultdict(int)
 
     async def _stream_task_wrapper(self, stream_id: int, func: str, *args, **kwargs):
         """Create a stream task."""
@@ -67,6 +69,22 @@ async def get_stream_task_result(self, stream_id: int):
         result, stopped = self._stream_aiter[stream_id][1]
         event.clear()
 
+        def __return_engine_output_incrementally(result):
+            if not isinstance(result, EngineOutput):
+                return
+            old_offset = self._engine_output_offset[stream_id]
+            new_offset = len(result.token_ids)
+            if old_offset:
+                if result.token_ids:
+                    result.token_ids = result.token_ids[old_offset:]
+                if result.logprobs:
+                    result.logprobs = result.logprobs[old_offset:]
+            self._engine_output_offset[stream_id] = new_offset
+            if stopped:
+                self._engine_output_offset.pop(stream_id, None)
+
+        __return_engine_output_incrementally(result)
+
         if stopped:
             self._stream_aiter.pop(stream_id, None)
             self._stream_task.pop(stream_id, None)
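This helper converts the engine's cumulative `EngineOutput` into a delta before it crosses the Ray RPC boundary: it remembers how many tokens were already shipped for the stream, slices them off, and drops the bookkeeping once the stream stops. A minimal sketch of the same pattern, assuming a simplified `Output` stand-in and a hypothetical `to_delta` helper (neither is the lmdeploy API):

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

@dataclass
class Output:  # stand-in for lmdeploy's EngineOutput
    token_ids: List[int] = field(default_factory=list)

_offsets = defaultdict(int)  # tokens already shipped, per stream

def to_delta(stream_id: int, cumulative: Output, stopped: bool) -> Output:
    """Return only the tokens that have not been shipped yet."""
    sent = _offsets[stream_id]
    delta = Output(token_ids=cumulative.token_ids[sent:])
    _offsets[stream_id] = len(cumulative.token_ids)
    if stopped:
        _offsets.pop(stream_id, None)  # drop bookkeeping once the stream ends
    return delta

# cumulative snapshots [1, 2] then [1, 2, 3] become deltas [1, 2] then [3]
assert to_delta(0, Output([1, 2]), stopped=False).token_ids == [1, 2]
assert to_delta(0, Output([1, 2, 3]), stopped=True).token_ids == [3]
```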
@@ -92,6 +110,7 @@ def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs):
 
         self.worker = self._create_worker(model_path, engine_config, log_level=logger.level, **kwargs)
         super().__init__()
+        self._engine_output = defaultdict(lambda: EngineOutput(status=None, token_ids=[], num_token=0, logprobs=[]))
 
     def _init_ray(self, engine_config: PytorchEngineConfig = None):
         """Initialize Ray."""
@@ -142,9 +161,22 @@ async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
         # the ray generator would try to cache every result, which is too verbose.
         stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)
 
+        def __merge_engine_output(result):
+            if not isinstance(result, EngineOutput):
+                return
+
+            output = self._engine_output[stream_id]
+            output.token_ids.extend(result.token_ids or [])
+            output.logprobs.extend(result.logprobs or [])
+            result.token_ids = output.token_ids or []
+            result.logprobs = output.logprobs or None
+            if stopped:
+                self._engine_output.pop(stream_id, None)
+
         stopped = False
         while not stopped:
             result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)
+            __merge_engine_output(result)
             yield result
 
     def close(self) -> None:
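Taken together, the worker ships per-call deltas and `__merge_engine_output` restores the cumulative view on the client, so callers of `_collective_rpc_streaming_async` still see growing token lists while each Ray message carries only the new tokens. A round-trip sketch of that reassembly, using an illustrative `Output` and `merge` (not the lmdeploy API):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Output:  # stand-in for lmdeploy's EngineOutput
    token_ids: List[int] = field(default_factory=list)

def merge(accumulated: Output, delta: Output) -> Output:
    """Fold a delta into the running view and return a cumulative snapshot."""
    accumulated.token_ids.extend(delta.token_ids)
    # copy so earlier snapshots are not mutated by later merges
    return Output(token_ids=list(accumulated.token_ids))

acc = Output()
deltas = [Output([1, 2]), Output([3]), Output([4, 5])]
views = [merge(acc, d) for d in deltas]
assert [v.token_ids for v in views] == [[1, 2], [1, 2, 3], [1, 2, 3, 4, 5]]
```

One difference from the sketch: the PR assigns the accumulator list to `result.token_ids` directly rather than copying, which avoids a copy per chunk but means later merges mutate previously yielded results if a consumer holds onto them.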