diff --git a/vlmeval/api/__init__.py b/vlmeval/api/__init__.py
index 10f27b195..a05dd59c4 100644
--- a/vlmeval/api/__init__.py
+++ b/vlmeval/api/__init__.py
@@ -13,7 +13,7 @@
 from .bailingmm import bailingMMAPI
 from .bluelm_api import BlueLMWrapper, BlueLM_API
 from .jt_vl_chat import JTVLChatAPI
-from .jt_vl_chat_mini import JTVLChatAPI_Mini
+from .jt_vl_chat_mini import JTVLChatAPI_Mini, JTVLChatAPI_2B
 from .taiyi import TaiyiAPI
 from .lmdeploy import LMDeployAPI
 from .taichu import TaichuVLAPI, TaichuVLRAPI
@@ -25,7 +25,7 @@
     'OpenAIWrapper', 'HFChatModel', 'GeminiWrapper', 'GPT4V', 'Gemini',
     'QwenVLWrapper', 'QwenVLAPI', 'QwenAPI', 'Claude3V', 'Claude_Wrapper',
     'Reka', 'GLMVisionAPI', 'CWWrapper', 'SenseChatVisionAPI', 'HunyuanVision',
-    'Qwen2VLAPI', 'BlueLMWrapper', 'BlueLM_API', 'JTVLChatAPI', 'JTVLChatAPI_Mini',
+    'Qwen2VLAPI', 'BlueLMWrapper', 'BlueLM_API', 'JTVLChatAPI', 'JTVLChatAPI_Mini', 'JTVLChatAPI_2B',
     'bailingMMAPI', 'TaiyiAPI', 'TeleMMAPI', 'SiliconFlowAPI', 'LMDeployAPI',
     'TaichuVLAPI', 'TaichuVLRAPI', 'DoubaoVL', "MUGUAPI", 'KimiVLAPIWrapper', 'KimiVLAPI'
 ]
diff --git a/vlmeval/api/jt_vl_chat_mini.py b/vlmeval/api/jt_vl_chat_mini.py
index 7e14546d5..a7a0935b0 100644
--- a/vlmeval/api/jt_vl_chat_mini.py
+++ b/vlmeval/api/jt_vl_chat_mini.py
@@ -1,11 +1,16 @@
+import pandas as pd
+import requests
+import json
+import os
+import base64
+from vlmeval.smp import *
 from vlmeval.api.base import BaseAPI
 from vlmeval.dataset import DATASET_TYPE
 from vlmeval.dataset import img_root_map
-from vlmeval.smp import *
-
-# JT-VL-Chat-mini
-API_ENDPOINT = "https://hl.jiutian.10086.cn/kunlun/ingress/api/hl-4a9c15/7b11a3451e1a4612a6661c3e22235df6/ai-e7d64e71b61e421b8a41c0d43424d73a/service-cd5a067331e34cfea25f5a9d7960ffe8/infer"  # noqa: E501
-APP_CODE = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI1YjYyODY0ZjZmMWI0Yzg4YWE2ZDk1NzBhNDY1MWI3OSIsImlzcyI6ImFwaS1hdXRoLWtleSIsImV4cCI6NDg5ODU1NzAwN30.jLUbctPdJ74VC3Vwlr0nB4x9N2QxWtSGYE0vWsceZN-agDecVnH8pH8Q5SoCQ3-SBYx5jDx-UOg3kkoMqY9CdxiALauKU_UZ56CV2NKCcHUVeJIgNvfQJMb0z6yCCbSe80e1T8FxrxQXDvubyWtl4pTAhixYaEUqNG8rjUrDuA-vRgZ1e7HilBmU487OI76D9LUnU-zEdMWhzsCkh_Yy3M1Ur4PsKgMFi5QSmMuGSUGJjkpJHiGNx1QcevBLQSOCL2jvg15ifB2n2dD6zb8iPXFkfQTtmvbZofxWACSvkri-x9V3gFWg7DODwKUZsyyogPzRJVbmxDGruMsgiiCsPg"  # noqa: E501
+API_ENDPOINT = "https://hl.jiutian.10086.cn/kunlun/ingress/api/hl-4a9c15/7b11a3451e1a4612a6661c3e22235df6/ai-e7d64e71b61e421b8a41c0d43424d73a/service-cd5a067331e34cfea25f5a9d7960ffe8//v1/chat/completions"
+API_ENDPOINT_2B = 'https://hl.jiutian.10086.cn/kunlun/ingress/api/hl-4a9c15/7b11a3451e1a4612a6661c3e22235df6/ai-3101961c75eb47ad8cc8d8ebb62fb8b0/service-c0bf9bac00824ace8639c0e8e0a4a5da/v1/chat/completions'
+APP_CODE = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI1YjYyODY0ZjZmMWI0Yzg4YWE2ZDk1NzBhNDY1MWI3OSIsImlzcyI6ImFwaS1hdXRoLWtleSIsImV4cCI6NDg5ODU1NzAwN30.jLUbctPdJ74VC3Vwlr0nB4x9N2QxWtSGYE0vWsceZN-agDecVnH8pH8Q5SoCQ3-SBYx5jDx-UOg3kkoMqY9CdxiALauKU_UZ56CV2NKCcHUVeJIgNvfQJMb0z6yCCbSe80e1T8FxrxQXDvubyWtl4pTAhixYaEUqNG8rjUrDuA-vRgZ1e7HilBmU487OI76D9LUnU-zEdMWhzsCkh_Yy3M1Ur4PsKgMFi5QSmMuGSUGJjkpJHiGNx1QcevBLQSOCL2jvg15ifB2n2dD6zb8iPXFkfQTtmvbZofxWACSvkri-x9V3gFWg7DODwKUZsyyogPzRJVbmxDGruMsgiiCsPg"


 class JTVLChatWrapper(BaseAPI):
@@ -13,10 +18,11 @@ class JTVLChatWrapper(BaseAPI):
     INTERLEAVE = False

     def __init__(self,
-                 model: str = 'jt-vl-chat',
+                 model: str = 'jt-vl-chat-mini',
                  retry: int = 5,
-                 api_base: str = API_ENDPOINT,
-                 key: str = APP_CODE,
+                 wait: int = 5,
+                 api_base: str = '',
+                 app_code: str = '',
                  verbose: bool = True,
                  system_prompt: str = None,
                  temperature: float = 0.7,
@@ -27,16 +33,13 @@ def __init__(self,
         self.temperature = temperature
         self.max_tokens = max_tokens
-        self.api_base = api_base
-
-        if key is None:
-            key = os.environ.get('JTVLChat_API_KEY', None)
-        assert key is not None, (
-            'Please set the API Key (also called app_code, obtain it here: https://github.com/jiutiancv/JT-VL-Chat)'
-        )
+        if model == 'jt-vl-chat-mini':
+            self.api_base = API_ENDPOINT
+        else:
+            self.api_base = API_ENDPOINT_2B
+        self.app_code = APP_CODE

-        self.key = key
-        super().__init__(retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
+        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

     def dump_image(self, line, dataset):
         """Dump the image(s) of the input line to the corresponding dataset folder.
@@ -155,37 +158,36 @@ def message_to_promptimg(self, message, dataset=None):
             image = [x['value'] for x in message if x['type'] == 'image'][0]
         return prompt, image

-    def get_send_data(self, prompt, image_path, temperature, max_tokens):
+    def get_send_data(self, prompt, image_path, temperature, max_tokens, stream=False):
         image = ''
         with open(image_path, 'rb') as f:
             image = str(base64.b64encode(f.read()), 'utf-8')
-        # send_data = {
-        #     "messages": [
-        #         {
-        #             "role": "user",
-        #             "content": prompt
-        #         }
-        #     ],
-        #     "image_base64": image,
-        #     "max_tokens": max_tokens,
-        #     "temperature": temperature
-        # }
-        send_data = {"prompt": prompt, "image_base64": image}
+        send_data = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            "image_base64": image,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "do_sample": False,
+            "stream": stream
+        }
         return send_data

-    def get_send_data_no_image(self, prompt, temperature, max_tokens):
-        # send_data = {
-        #     "messages": [
-        #         {
-        #             "role": "user",
-        #             "content": prompt
-        #         }
-        #     ],
-        #     "max_tokens": max_tokens,
-        #     "temperature": temperature
-        # }
+    def get_send_data_no_image(self, prompt, temperature, max_tokens, stream=False):
         send_data = {
-            "prompt": prompt, "image_base64": ""
+            "messages": [
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "stream": stream,
         }
         return send_data
@@ -200,39 +202,82 @@ def generate_inner(self, inputs, **kwargs) -> str:
                 prompt=prompt,
                 image_path=image_path,
                 temperature=self.temperature,
-                max_tokens=self.max_tokens)
+                max_tokens=self.max_tokens,
+                stream=True)
         else:
             send_data = self.get_send_data_no_image(
                 prompt=prompt,
                 temperature=self.temperature,
-                max_tokens=self.max_tokens)
+                max_tokens=self.max_tokens,
+                stream=True)

         json_data = json.dumps(send_data)
-        header_dict = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + self.key}
+        header_dict = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + self.app_code}

-        r = requests.post(self.api_base, headers=header_dict, data=json_data, timeout=3000)
+        r = requests.post(self.api_base, headers=header_dict, data=json_data, timeout=3000, stream=True)
         try:
-            assert r.status_code == 200
-            # r_json = r.json()
-            # output = r_json['choices'][0]['message']['content']
-            output = r.text
-            if self.verbose:
-                self.logger.info(f'inputs: {inputs}\nanswer: {output}')
-
-            return 0, output, 'Succeeded! '
-
-        except:
-            error_msg = f'Error! code {r.status_code} content: {r.content}'
-            error_con = r.content.decode('utf-8')
-            if self.verbose:
-                self.logger.error(error_msg)
-                self.logger.error(error_con)
-                self.logger.error(f'The input messages are {inputs}.')
-            return -1, error_msg, ''
+            if send_data.get('stream', False):
+                # Streaming path
+                chunks = []
+                full_content = ""
+
+                try:
+                    for line in r.iter_lines():
+                        if line:
+                            decoded_line = line.decode('utf-8')
+                            if decoded_line.startswith('data: '):
+                                event_data = decoded_line[6:]
+                                if event_data == '[DONE]':
+                                    break
+                                try:
+                                    chunk = json.loads(event_data)
+                                    chunks.append(chunk)
+
+                                    # Keep only the last valid usage record (not accumulated)
+                                    if 'usage' in chunk:
+                                        _ = chunk['usage']
+
+                                    # Emit content in real time
+                                    if 'choices' in chunk:
+                                        for choice in chunk['choices']:
+                                            if 'delta' in choice and 'content' in choice['delta']:
+                                                content = choice['delta']['content']
+                                                print(content, end='', flush=True)
+                                                full_content += content
+                                except json.JSONDecodeError:
+                                    continue
+                    print("\n")  # trailing newline after the stream ends
+
+                    return 0, full_content, 'Succeeded! '
+
+                except Exception as e:
+                    return -1, f'Error: {str(e)}', ''
+            else:
+                # Non-streaming path
+                try:
+                    r_json = r.json()
+                    output = r_json['choices'][0]['message']['content']
+                    return 0, output, 'Succeeded! '
+                except:
+                    error_msg = f'Error! code {r.status_code} content: {r.content}'
+                    error_con = r.content.decode('utf-8')
+                    if self.verbose:
+                        self.logger.error(error_msg)
+                        self.logger.error(error_con)
+                        self.logger.error(f'The input messages are {inputs}.')
+                    return -1, error_msg, ''
+        except Exception as e:
+            return -1, f'Error: {str(e)}', ''


 class JTVLChatAPI_Mini(JTVLChatWrapper):

     def generate(self, message, dataset=None):
         return super(JTVLChatAPI_Mini, self).generate(message, dataset=dataset)
+
+
+class JTVLChatAPI_2B(JTVLChatWrapper):
+
+    def generate(self, message, dataset=None):
+        return super(JTVLChatAPI_2B, self).generate(message, dataset=dataset)
diff --git a/vlmeval/config.py b/vlmeval/config.py
index 22077b649..442a3bda4 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -462,6 +462,7 @@
     # JiuTian-VL
     "JTVL": partial(JTVLChatAPI, model="jt-vl-chat", temperature=0, retry=10),
     "JTVL-Mini": partial(JTVLChatAPI_Mini, model="jt-vl-chat-mini", temperature=0, retry=10),
+    "JTVL-2B": partial(JTVLChatAPI_2B, model="jt-vl-chat-2b", temperature=0, retry=10),
    "Taiyi": partial(TaiyiAPI, model="taiyi", temperature=0, retry=10),
     # TeleMM
     "TeleMM": partial(TeleMMAPI, model="TeleAI/TeleMM", temperature=0, retry=10),
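
Note for reviewers: the reworked `generate_inner` consumes an OpenAI-style server-sent-event stream. Below is a minimal standalone sketch of the framing it parses; the payloads are made up for illustration and the real endpoint's chunk fields may differ.

```python
# Illustration of the SSE framing parsed by generate_inner (payloads are hypothetical).
import json

raw_events = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"delta": {"content": "lo!"}}], "usage": {"total_tokens": 7}}',
    b'data: [DONE]',
]

full_content = ''
for line in raw_events:
    decoded = line.decode('utf-8')
    if not decoded.startswith('data: '):
        continue                      # skip keep-alives and non-data frames
    event = decoded[len('data: '):]
    if event == '[DONE]':             # sentinel that terminates the stream
        break
    chunk = json.loads(event)
    for choice in chunk.get('choices', []):
        full_content += choice.get('delta', {}).get('content', '')

print(full_content)  # -> "Hello!"
```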
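To smoke-test the new registry entry end to end, something like the following should work — a minimal sketch assuming the `supported_VLM` table in `vlmeval/config.py` and VLMEvalKit's standard message format; the image path and prompt are placeholders.

```python
# Hypothetical smoke test for the new JTVL-2B entry (placeholder image path and prompt).
from vlmeval.config import supported_VLM

model = supported_VLM['JTVL-2B']()  # instantiates JTVLChatAPI_2B via the partial registered above
message = [
    dict(type='image', value='assets/demo.jpg'),     # placeholder image path
    dict(type='text', value='Describe this image.'),
]
print(model.generate(message))
```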