diff --git a/lmdeploy/pytorch/engine/mp_engine/ray_engine.py b/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
index 0d133abfad..cfc7a53a37 100644
--- a/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
+++ b/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+from collections import defaultdict
 from typing import Dict
 
 import ray
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-from lmdeploy.messages import PytorchEngineConfig
+from lmdeploy.messages import EngineOutput, PytorchEngineConfig
 from lmdeploy.pytorch import envs as _envs
 from lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs
 from lmdeploy.utils import get_logger
@@ -35,6 +36,7 @@ def __init__(self,
         self._stream_id = 0
         self._stream_aiter = dict()
         self._stream_task = dict()
+        self._engine_output_offset = defaultdict(int)
 
     async def _stream_task_wrapper(self, stream_id: int, func: str, *args, **kwargs):
         """Create a stream task."""
@@ -67,6 +69,22 @@ async def get_stream_task_result(self, stream_id: int):
         result, stopped = self._stream_aiter[stream_id][1]
         event.clear()
 
+        def __return_engine_output_incrementally(result):
+            if not isinstance(result, EngineOutput):
+                return
+            old_offset = self._engine_output_offset[stream_id]
+            new_offset = len(result.token_ids)
+            if old_offset:
+                if result.token_ids:
+                    result.token_ids = result.token_ids[old_offset:]
+                if result.logprobs:
+                    result.logprobs = result.logprobs[old_offset:]
+            self._engine_output_offset[stream_id] = new_offset
+            if stopped:
+                self._engine_output_offset.pop(stream_id, None)
+
+        __return_engine_output_incrementally(result)
+
         if stopped:
             self._stream_aiter.pop(stream_id, None)
             self._stream_task.pop(stream_id, None)
@@ -92,6 +110,7 @@ def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, *
         self.worker = self._create_worker(model_path, engine_config, log_level=logger.level, **kwargs)
 
         super().__init__()
+        self._engine_output = defaultdict(lambda: EngineOutput(status=None, token_ids=[], num_token=0, logprobs=[]))
 
     def _init_ray(self, engine_config: PytorchEngineConfig = None):
         """Initialize Ray."""
@@ -142,9 +161,22 @@ async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
         # ray generator would try cache every result, which is too verbose.
         stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)
 
+        def __merge_engine_output(result):
+            if not isinstance(result, EngineOutput):
+                return
+
+            output = self._engine_output[stream_id]
+            output.token_ids.extend(result.token_ids or [])
+            output.logprobs.extend(result.logprobs or [])
+            result.token_ids = output.token_ids or []
+            result.logprobs = output.logprobs or None
+            if stopped:
+                self._engine_output.pop(stream_id, None)
+
         stopped = False
         while not stopped:
             result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)
+            __merge_engine_output(result)
             yield result
 
     def close(self) -> None:
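
Taken together, the two hooks implement a delta encoding for streamed EngineOutput: before each RPC reply, the worker strips the prefix of token_ids/logprobs it has already returned, and the client side re-accumulates the deltas so callers still observe cumulative outputs. A minimal self-contained sketch of that round trip, under the assumption that logprobs run parallel to token_ids, follows; Output, make_delta, and merge_delta are illustrative names for this sketch, not lmdeploy APIs.

# Sketch of the worker-side delta / client-side merge round trip.
# Output, make_delta, and merge_delta are hypothetical names, not lmdeploy APIs.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Output:
    token_ids: List[int] = field(default_factory=list)
    logprobs: Optional[List[float]] = None


def make_delta(cumulative: Output, offset: int) -> Tuple[Output, int]:
    """Worker side: strip the already-sent prefix, return the delta and new offset."""
    new_offset = len(cumulative.token_ids)
    delta = Output(
        token_ids=cumulative.token_ids[offset:],
        logprobs=cumulative.logprobs[offset:] if cumulative.logprobs else None,
    )
    return delta, new_offset


def merge_delta(accumulated: Output, delta: Output) -> Output:
    """Client side: extend the cached output so callers see cumulative results."""
    accumulated.token_ids.extend(delta.token_ids or [])
    if delta.logprobs:
        accumulated.logprobs = (accumulated.logprobs or []) + delta.logprobs
    return accumulated


if __name__ == '__main__':
    # Simulate two polls of one stream whose worker-side output grows over time.
    worker_view = Output(token_ids=[1, 2, 3])
    offset = 0
    client_view = Output()

    delta, offset = make_delta(worker_view, offset)  # first poll sends [1, 2, 3]
    merge_delta(client_view, delta)

    worker_view.token_ids += [4, 5]
    delta, offset = make_delta(worker_view, offset)  # second poll sends only [4, 5]
    merge_delta(client_view, delta)

    assert client_view.token_ids == [1, 2, 3, 4, 5]

The design choice this illustrates: Ray serializes every RPC result, so returning the full cumulative token list on each poll makes payloads grow linearly with sequence length; shipping only the per-poll suffix keeps each message small, and the per-stream offset and cache are dropped once the stream reports stopped.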