|
32 | 32 | logger = get_logger('lmdeploy') |
33 | 33 |
|
34 | 34 |
|
| 35 | +def _merge_message_content(msg: Dict) -> Dict: |
| 36 | + """Merge multimodal content blocks and ensure content field exists. |
| 37 | +
|
| 38 | + This function normalizes message content to match vLLM's behavior: |
| 39 | + 1. Missing content field -> add content=None |
| 40 | + 2. None content -> keep as None |
| 41 | + 3. String content -> return as-is |
| 42 | + 4. List content (multimodal) -> merge all text blocks with newline separator |
| 43 | +
|
| 44 | + Args: |
| 45 | + msg: A message dict with 'role' and optionally 'content' field |
| 46 | +
|
| 47 | + Returns: |
| 48 | + A message dict with 'content' field guaranteed to exist |
| 49 | +
|
| 50 | + Note: |
| 51 | + This implementation is based on vLLM's content processing logic. |
| 52 | + vLLM uses "\n".join() to merge multiple text blocks from multimodal content. |
| 53 | +
|
| 54 | + References: |
| 55 | + - vLLM content normalization: |
| 56 | + https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/chat_utils.py |
| 57 | + See _parse_chat_message_content() and _parse_chat_message_content_parts() |
| 58 | + - vLLM text merging logic: |
| 59 | + text_prompt = "\n".join(texts) |
| 60 | + """ |
| 61 | + # If content is missing, add it with None value (e.g., assistant with tool_calls only) |
| 62 | + if 'content' not in msg: |
| 63 | + result = dict(msg) |
| 64 | + result['content'] = None |
| 65 | + return result |
| 66 | + |
| 67 | + # If content is None, keep it as None |
| 68 | + if msg['content'] is None: |
| 69 | + return msg |
| 70 | + |
| 71 | + # If content is already a string, return as-is |
| 72 | + if isinstance(msg['content'], str): |
| 73 | + return msg |
| 74 | + |
| 75 | + # If content is a list, merge all text blocks into a single string |
| 76 | + # This matches vLLM's behavior: text_prompt = "\n".join(texts) |
| 77 | + content_parts = [] |
| 78 | + for block in msg['content']: |
| 79 | + if isinstance(block, dict) and block.get('type') == 'text': |
| 80 | + content_parts.append(block.get('text', '')) |
| 81 | + merged_content = '\n'.join(content_parts) |
| 82 | + |
| 83 | + # Preserve all other fields in the message (e.g., tool_calls) |
| 84 | + result = dict(msg) |
| 85 | + result['content'] = merged_content |
| 86 | + return result |
| 87 | + |
| 88 | + |
35 | 89 | @dataclasses.dataclass |
36 | 90 | class GenOut: |
37 | 91 | """Pack all response information together.""" |
@@ -609,11 +663,9 @@ async def _get_prompt_input(self, |
609 | 663 | # Change multimodal data to openai text messages, i.e., |
610 | 664 | # [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}] -> |
611 | 665 | # [{'role': 'user', 'content': 'hi'}] |
612 | | - if isinstance(prompt, list) and any(isinstance(msg['content'], list) for msg in prompt): |
613 | | - prompt = [ |
614 | | - msg if isinstance(msg['content'], str) else dict(role=msg['role'], content=msg['content'][0]['text']) |
615 | | - for msg in prompt |
616 | | - ] |
| 666 | + # Also ensure all messages have 'content' field (set to None if missing, e.g., assistant with tool_calls) |
| 667 | + if isinstance(prompt, list): |
| 668 | + prompt = [_merge_message_content(msg) for msg in prompt] |
617 | 669 | if do_preprocess: |
618 | 670 | # use adapter's chat template if possible |
619 | 671 | chat_template = self.chat_template |
|
0 commit comments