diff --git a/run.py b/run.py
index 7d6dceb2d..8934111bd 100644
--- a/run.py
+++ b/run.py
@@ -13,9 +13,13 @@ def build_model_from_config(cfg):
     import vlmeval.api
     import vlmeval.vlm
+    import vlmeval.composite
+
     config = cp.deepcopy(cfg)
     assert 'class' in config
     cls_name = config.pop('class')
+    if hasattr(vlmeval.composite, cls_name):
+        return getattr(vlmeval.composite, cls_name)(supported_VLM, **config)
     if hasattr(vlmeval.api, cls_name):
         return getattr(vlmeval.api, cls_name)(**config)
     elif hasattr(vlmeval.vlm, cls_name):
@@ -165,6 +169,14 @@ def main():
     for _, model_name in enumerate(args.model):
         model = None
 
+        if use_config:
+            model = build_model_from_config(cfg['model'][model_name])
+            if model_name == 'Prism':
+                fronted_name = cfg['model']['Prism']['model']['fronted']['model']
+                backend_name = cfg['model']['Prism']['model']['backend']['model']
+                backend_name = backend_name.replace('/', '-')
+                model_name = model_name + '_' + fronted_name + '_' + backend_name
+
         date, commit_id = timestr('day'), githash(digits=8)
         eval_id = f"T{date}_G{commit_id}"
 
@@ -179,9 +191,6 @@ def main():
         if not osp.exists(pred_root):
             os.makedirs(pred_root, exist_ok=True)
 
-        if use_config:
-            model = build_model_from_config(cfg['model'][model_name])
-
         for _, dataset_name in enumerate(args.data):
             try:
                 result_file_base = f'{model_name}_{dataset_name}.xlsx'
@@ -281,6 +290,7 @@ def main():
                 model = model_name  # which is only a name
 
             # Perform the Inference
+
             if dataset.MODALITY == 'VIDEO':
                 model = infer_data_job_video(
                     model,
diff --git a/vlmeval/composite/__init__.py b/vlmeval/composite/__init__.py
new file mode 100644
index 000000000..7dc348768
--- /dev/null
+++ b/vlmeval/composite/__init__.py
@@ -0,0 +1,7 @@
+import torch
+
+# torch.set_grad_enabled(False)
+# torch.manual_seed(1234)
+
+
+from .prism import Prism
diff --git a/vlmeval/composite/prism.py b/vlmeval/composite/prism.py
new file mode 100644
index 000000000..42bfef166
--- /dev/null
+++ b/vlmeval/composite/prism.py
@@ -0,0 +1,190 @@
+import torch
+import re
+from vlmeval.api import OpenAIWrapper, SiliconFlowAPI
+from vlmeval.utils import track_progress_rich
+import os
+
+
+# remap the gpt model name
+gpt_version_map = {
+    'gpt-4-0409': 'gpt-4-turbo-2024-04-09',
+    'gpt-4-0125': 'gpt-4-0125-preview',
+    'gpt-4-turbo': 'gpt-4-1106-preview',
+    'gpt-4-0613': 'gpt-4-0613',
+    'chatgpt-1106': 'gpt-3.5-turbo-1106',
+    'chatgpt-0613': 'gpt-3.5-turbo-0613',
+    'chatgpt-0125': 'gpt-3.5-turbo-0125',
+    'gpt-4o': 'gpt-4o-2024-05-13'
+}
+
+# # map the model name to the api type
+# reasoning_mapping = {
+#     'llama3-70b-chat':'silicon',
+#     'Mixtral-8x22B-chat':'silicon',
+#     'deepseek-ai/DeepSeek-V2-Chat':'silicon',
+# }
+#
+# # stop_tokens for deploying vllm
+# stop_tokens = {
+#     'llama3-70b-chat': ["<|eot_id|>"],
+# }
+
+mapping = {}
+mapping.update(gpt_version_map)
+
+# mapping.update(reasoning_mapping)
+
+prompt_human1 = ('Describe the fine-grained content of the image, including scenes, objects,'
+                 ' relationships, instance location, and any text present.')
+prompt_human2 = ('Describe the fine-grained content of the image, including scenes, objects, '
+                 'relationships, instance location, background and any text present. Please skip '
+                 'generating statements for non-existent contents and describe all you see. ')
+prompt_gpt1 = 'Given the image below, please provide a detailed description of what you see.'
+prompt_gpt2 = 'Analyze the image below and describe the main elements and their relationship.'
+prompt_cot = ('Describe the fine-grained content of the image, including scenes, objects, relationships,'
+              ' instance location, and any text present. Let\'s think step by step.')
+prompt_decompose = ('Decompose the image into several parts and describe the fine-grained content of the '
+                    'image part by part, including scenes, objects, relationships, instance location, and'
+                    ' any text present.')
+
+genric_prompt_mapping = {
+    'generic':prompt_human1,
+    'human1':prompt_human1,
+    'gpt1':prompt_gpt1,
+    'gpt2':prompt_gpt2,
+    'human2':prompt_human2,
+    'cot': prompt_cot,
+    'decompose': prompt_decompose,
+}
+
+
+class Prism():
+
+    def __init__(self, supported_VLM, **kwargs):
+        self.supported_VLM = supported_VLM
+        self.config = kwargs
+
+        self.model_name_fronted = self.config['model']['fronted']['model']
+        self.model_name_backend = self.config['model']['backend']['model']
+        self.fronted_prompt_type = self.config['model']['fronted']['prompt_type']
+
+        self.model_fronted = supported_VLM[self.model_name_fronted]() if (
+            isinstance(self.model_name_fronted, str)) else None
+        self.model_backend = Reasoning(model=self.model_name_backend)
+
+    def set_dump_image(self, dump_image):
+        if hasattr(self.model_fronted, 'set_dump_image'):
+            self.model_fronted.set_dump_image(dump_image)
+
+    def generate(self, message, dataset=None):
+
+        # struct prompt
+        question = message[1]['value']
+        prompt_fronted = self.build_fronted_prompt()
+        message[1]['value'] = prompt_fronted
+
+        # generate fronted
+        is_api = getattr(self.model_fronted, 'is_api', False)
+        if is_api:
+            response_fronted = self.fronted_api(message=message, dataset=dataset)
+        else:
+            response_fronted = self.model_fronted.generate(message=message, dataset=dataset)
+
+        print("----fronted output----\n" + response_fronted + "\n----backend output----")
+
+        # generate backend
+        response_backend = self.model_backend.generate(question, response_fronted)
+
+        return response_backend
+
+    def fronted_api(self, message, dataset=None):
+        result = self.model_fronted.generate(message)
+        # gen_func = self.model_fronted.generate
+        # struct = {}
+        # struct['message'] = message
+        # struct['dataset'] = dataset
+        # result = track_progress_rich(gen_func, [struct])
+        return result
+
+    def build_fronted_prompt(self):
+        prompt = genric_prompt_mapping[self.fronted_prompt_type]
+        return prompt
+
+
+class Reasoning:
+    def __init__(self, model):
+        self.model = LLMWrapper(model)
+
+    def generate(self, question, des):
+        prompt = build_infer_prompt_external(question, des)
+        return self.model.generate(prompt)
+
+
+def build_infer_prompt_external(question, des):
+    if not question.endswith('\n'):
+        question += '\n'
+    if not question.lower().startswith('question:') and not question.lower().startswith('hint:'):
+        question = 'Question: ' + question
+    if not des.endswith('\n'):
+        des += '\n'
+    description = 'Description: ' + des
+    role = ('You are an excellent text-based reasoning expert. You are required to answer the question'
+            ' based on the detailed description of the image.\n\n')
+
+    prompt = role + description + question
+    return prompt
+
+
+class LLMWrapper:
+
+    def __init__(self, model_name, max_tokens=512, verbose=True, retry=5):
+
+        # api bases, openai default
+        # self.deepseek_api_base = 'https://api.deepseek.com/v1/chat/completions'
+
+        # server settings of vllm
+        # self.PORT = 8080
+        # self.vllm_api_base = f'http://localhost:{self.PORT}/v1/chat/completions'
+
+        self.prism_llm_api_base = os.environ['PRISM_LLM_API_BASE']
+
+        if model_name.endswith('-2048'):
+            model_name = model_name.replace('-2048', '')
+            max_tokens = 2048
+
+        if model_name in gpt_version_map:
+            gpt_version = gpt_version_map[model_name]
+            model = OpenAIWrapper(gpt_version, max_tokens=max_tokens, verbose=verbose, retry=retry)
+        else:
+            # use your api
+            api_key = os.environ['PRISM_LLM_API_KEY']
+            model = SiliconFlowAPI(model_name, api_base=self.prism_llm_api_base, key=api_key,
+                                   system_prompt='You are a helpful assistant.', verbose=verbose, retry=retry)
+            # model = OpenAIWrapper(model_name, api_base=self.prism_llm_api_base, key=api_key,
+            #                       max_tokens=max_tokens, system_prompt='You are a helpful assistant.',
+            #                       verbose=verbose, retry=retry)
+
+        # elif reasoning_mapping[model_name] == 'vllm':
+        #     model = OpenAIWrapper(model_name, api_base=self.vllm_api_base, max_tokens=max_tokens,
+        #                           system_prompt='You are a helpful assistant.', verbose=verbose, retry=retry,
+        #                           stop=stop_tokens[model_name])
+        # elif reasoning_mapping[model_name] == 'deepseek':
+        #     deepseek_key = os.environ['SILICON_API_KEY']
+        #     model = OpenAIWrapper(model_name, api_base=self.deepseek_api_base, key=deepseek_key,
+        #                           max_tokens=max_tokens, system_prompt='You are a helpful assistant.', verbose=verbose, retry=retry)
+        #
+        # else:
+        #     print('Unknown API model for inference')
+
+        self.model = model
+
+    def generate(self, prompt, **kwargs):
+        response = self.model.generate(prompt, **kwargs)
+        return response
+
+    @staticmethod
+    def api_models():
+        gpt_models = list(gpt_version_map.keys())
+        api_models = gpt_models.copy()
+        # api_models.extend(list(reasoning_mapping.keys()))
+        return api_models
diff --git a/vlmeval/config.py b/vlmeval/config.py
index a21e2a2c7..c36a4aeec 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -357,6 +357,23 @@
     'h2ovl-mississippi-1b': partial(H2OVLChat, model_path='h2oai/h2ovl-mississippi-800m'),
 }
 
+prismcationer_series = {
+    'prismcaptioner-7b': partial(
+        LLaVA_XTuner_Wrapper,
+        llm_path='internlm/internlm2-chat-7b',
+        llava_path='Yuxuan-Qiao/PrismCaptioner-7B',
+        visual_select_layer=-2,
+        prompt_template='internlm2_chat',
+        visual_encoder_path='google/siglip-so400m-patch14-384'),
+    'prismcaptioner-2b': partial(
+        LLaVA_XTuner_Wrapper,
+        llm_path='internlm/internlm2-chat-1_8b',
+        llava_path='Yuxuan-Qiao/PrismCaptioner-2B',
+        visual_select_layer=-2,
+        prompt_template='internlm2_chat',
+        visual_encoder_path='google/siglip-so400m-patch14-384'),
+}
+
 supported_VLM = {}
 
 model_groups = [
@@ -368,7 +385,7 @@
     mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
     slime_series, eagle_series, moondream_series, llama_series, molmo_series,
     kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
-    smolvlm_series
+    smolvlm_series, prismcationer_series
 ]
 
 for grp in model_groups:
diff --git a/vlmeval/dataset/image_vqa.py b/vlmeval/dataset/image_vqa.py
index 2ab3f99ba..5a3b891f9 100644
--- a/vlmeval/dataset/image_vqa.py
+++ b/vlmeval/dataset/image_vqa.py
@@ -11,7 +11,6 @@
 from .utils import build_judge, DEBUG_MESSAGE
 from ..smp import *
 from ..utils import track_progress_rich
-import ipdb
 
 
 class ImageVQADataset(ImageBaseDataset):
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index d112df8bd..80a2ca8bc 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -59,3 +59,4 @@
 from .h2ovl_mississippi import H2OVLChat
 from .falcon_vlm import Falcon2VLM
 from .smolvlm import SmolVLM
+from .prismcaptioner import LLaVA_XTuner_Wrapper
diff --git a/vlmeval/vlm/prismcaptioner.py b/vlmeval/vlm/prismcaptioner.py
new file mode 100644
index 000000000..bc3919844
--- /dev/null
+++ b/vlmeval/vlm/prismcaptioner.py
@@ -0,0 +1,230 @@
+import os
+import os.path as osp
+import string
+import sys
+import warnings
+
+import pandas as pd
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+                          CLIPImageProcessor, CLIPVisionModel,
+                          GenerationConfig, StoppingCriteriaList,
+                          SiglipImageProcessor, SiglipVisionModel)
+
+from vlmeval.vlm.base import BaseModel
+from vlmeval.smp import cn_string, get_cache_path
+from vlmeval.dataset import DATASET_TYPE
+
+
+class LLaVA_XTuner_Wrapper(BaseModel):
+    INSTALL_REQ = True
+    INTERLEAVE = False
+
+    def __init__(self,
+                 llava_path,
+                 llm_path=None,
+                 visual_encoder_path='openai/clip-vit-large-patch14-336',
+                 visual_select_layer=-2,
+                 prompt_template=None,
+                 stop_words=[],
+                 torch_dtype=torch.float16,
+                 vision_encoder_type=SiglipVisionModel,
+                 image_processor_type=SiglipImageProcessor):
+
+        from peft import PeftModel
+        from xtuner.utils import PROMPT_TEMPLATE, StopWordStoppingCriteria
+
+        if not osp.isdir(llava_path):
+            cache_path = get_cache_path(llava_path)
+            if cache_path is not None:
+                llava_path = cache_path
+            else:
+                llava_path = snapshot_download(repo_id=llava_path)
+        assert osp.exists(llava_path) and osp.isdir(llava_path)
+
+        # build visual_encoder
+        if 'llm' in os.listdir(llava_path):
+            assert llm_path is None, (
+                "Please don't specify the `llm_path` since passed "
+                '`llava_path` contains a LLM!')
+            llm_path = osp.join(llava_path, 'llm')
+        else:
+            assert llm_path is not None, 'Please specify the `llm_path`!'
+
+        llm = AutoModelForCausalLM.from_pretrained(llm_path,
+                                                   trust_remote_code=True,
+                                                   torch_dtype=torch_dtype,
+                                                   device_map='cpu')
+        tokenizer = AutoTokenizer.from_pretrained(llm_path,
+                                                  trust_remote_code=True,
+                                                  encode_special_tokens=True)
+        print(f'Load LLM from {llm_path}')
+
+        # build visual_encoder
+        if 'visual_encoder' in os.listdir(llava_path):
+            assert visual_encoder_path is None, (
+                "Please don't specify the `visual_encoder_path` since passed "
+                '`llava_path` contains a visual encoder!')
+            visual_encoder_path = osp.join(llava_path, 'visual_encoder')
+        else:
+            assert visual_encoder_path is not None, (
+                'Please specify the `visual_encoder_path`!')
+        visual_encoder = vision_encoder_type.from_pretrained(
+            visual_encoder_path, torch_dtype=torch_dtype, device_map='cpu', trust_remote_code=True)
+        image_processor = image_processor_type.from_pretrained(
+            visual_encoder_path, trust_remote_code=True)
+        print(f'Load visual_encoder from {visual_encoder_path}')
+
+        # load adapter
+        if 'llm_adapter' in os.listdir(llava_path):
+            adapter_path = osp.join(llava_path, 'llm_adapter')
+            llm = PeftModel.from_pretrained(llm,
+                                            adapter_path,
+                                            trust_remote_code=True,
+                                            device_map='cpu')
+            print(f'Load LLM adapter from {llava_path}')
+        if 'visual_encoder_adapter' in os.listdir(llava_path):
+            adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
+            visual_encoder = PeftModel.from_pretrained(visual_encoder,
+                                                       adapter_path,
+                                                       trust_remote_code=True,
+                                                       device_map='cpu')
+            print(f'Load visual_encoder adapter from {llava_path}')
+
+        # build projector
+        projector_path = osp.join(llava_path, 'projector')
+        projector = AutoModel.from_pretrained(projector_path,
+                                              trust_remote_code=True,
+                                              torch_dtype=torch_dtype,
+                                              device_map='cpu')
+        print(f'Load projector from {llava_path}')
+
+        llm.eval()
+        visual_encoder.eval()
+        projector.eval()
+
+        self.vision_encoder_type = vision_encoder_type
+        self.llm = llm.cuda()
+        self.tokenizer = tokenizer
+        self.visual_encoder = visual_encoder.cuda()
+        self.image_processor = image_processor
+        self.projector = projector.cuda()
+        self.visual_select_layer = visual_select_layer
+        if prompt_template is not None:
+            self.prompt_template = PROMPT_TEMPLATE[prompt_template]
+            stop_words += self.prompt_template.get('STOP_WORDS', [])
+        else:
+            self.prompt_template = None
+
+        self.stop_criteria = StoppingCriteriaList()
+        for word in stop_words:
+            self.stop_criteria.append(
+                StopWordStoppingCriteria(self.tokenizer, word))
+
+    def build_gen_config(self, dataset):
+        gen_kwargs = dict(max_new_tokens=512,
+                          do_sample=False,
+                          temperature=1,
+                          num_beams=5,
+                          eos_token_id=self.tokenizer.eos_token_id,
+                          pad_token_id=self.tokenizer.pad_token_id
+                          if self.tokenizer.pad_token_id is not None else
+                          self.tokenizer.eos_token_id)
+        # For single word generation
+        if (dataset is not None
+                and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']):
+            gen_kwargs.update(
+                dict(max_new_tokens=5, do_sample=False, num_beams=1))
+        return GenerationConfig(**gen_kwargs)
+
+    def use_custom_prompt(self, dataset):
+        assert dataset is not None
+        if DATASET_TYPE(dataset) == 'multi-choice':
+            return True
+        return False
+
+    def build_prompt(self, line, dataset=None):
+        assert self.use_custom_prompt(dataset)
+        assert dataset is None or isinstance(dataset, str)
+        tgt_path = self.dump_image(line, dataset)
+
+        question = line['question']
+        hint = line['hint'] if ('hint' in line
+                                and not pd.isna(line['hint'])) else None
+        if hint is not None:
+            question = hint + '\n' + question
+
+        options = {
+            cand: line[cand]
+            for cand in string.ascii_uppercase
+            if cand in line and not pd.isna(line[cand])
+        }
+        for key, item in options.items():
+            question += f'\n{key}. {item}'
+
+        if not cn_string(question):
+            prompt = question + '\n' + ("Answer with the option's letter "
+                                        'from the given choices directly.')
+        else:
+            prompt = question + '\n' + '请直接回答选项字母。'
+
+        message = [dict(type='text', value=prompt)]
+        message.extend([dict(type='image', value=s) for s in tgt_path])
+        return message
+
+    def generate_inner(self, message, dataset=None):
+        from xtuner.dataset.utils import expand2square
+        from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+        from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
+        prompt, image_path = self.message_to_promptimg(message)
+        image = Image.open(image_path).convert('RGB')
+        image = expand2square(
+            image,
+            tuple(int(x * 255) for x in self.image_processor.image_mean))
+        image = self.image_processor.preprocess(
+            image, return_tensors='pt')['pixel_values'][0]
+        image = image.cuda().unsqueeze(0)
+        visual_outputs = self.visual_encoder(image.to(self.visual_encoder.dtype), output_hidden_states=True)
+        # try:
+        #     from transformers import SiglipVisionModel
+        #     assert vision_encoder_type is SiglipVisionModel
+        #     pixel_values = self.projector(
+        #         visual_outputs.hidden_states[self.visual_select_layer])
+        # except:
+        pixel_values = self.projector(
+            visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
+
+        inputs = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+
+        if self.prompt_template:
+            inputs = self.prompt_template['INSTRUCTION'].format(input=inputs)
+
+        chunk_encode = []
+        for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+            if idx == 0:
+                cur_encode = self.tokenizer(chunk)
+            else:
+                cur_encode = self.tokenizer(chunk, add_special_tokens=False)
+            chunk_encode.append(cur_encode)
+        assert len(chunk_encode) == 2
+        ids = []
+        for idx, cur_chunk_encode in enumerate(chunk_encode):
+            ids.extend(cur_chunk_encode['input_ids'])
+            if idx != len(chunk_encode) - 1:
+                ids.append(IMAGE_TOKEN_INDEX)
+        ids = torch.tensor(ids).cuda().unsqueeze(0)
+        mm_inputs = prepare_inputs_labels_for_multimodal(
+            llm=self.llm, input_ids=ids, pixel_values=pixel_values)
+
+        gen_config = self.build_gen_config(dataset)
+        generate_output = self.llm.generate(
+            **mm_inputs,
+            generation_config=gen_config,
+            streamer=None,
+            bos_token_id=self.tokenizer.bos_token_id,
+            stopping_criteria=self.stop_criteria)
+        predict = self.tokenizer.decode(generate_output[0],
+                                        skip_special_tokens=True).strip()
+        return predict
diff --git a/vlmeval/vlm/smolvlm.py b/vlmeval/vlm/smolvlm.py
index 665a0dc3a..eeae8b9bc 100644
--- a/vlmeval/vlm/smolvlm.py
+++ b/vlmeval/vlm/smolvlm.py
@@ -15,6 +15,7 @@ class SmolVLM(BaseModel):
 
     def __init__(self, model_path='HuggingFaceTB/SmolVLM-Instruct', **kwargs):
         from transformers import AutoProcessor, Idefics3ForConditionalGeneration
+        from transformers.image_utils import load_image
         assert osp.exists(model_path) or splitlen(model_path) == 2
         self.processor = AutoProcessor.from_pretrained(model_path)
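
Not part of the patch — a minimal sketch of the nested config entry the new Prism code paths appear to expect, for readers wiring this up via --config. The key names ('class', model.fronted.model, model.fronted.prompt_type, model.backend.model) are taken directly from the keys read in run.py and vlmeval/composite/prism.py above; the concrete model choices and the surrounding config-file layout are assumptions, not prescribed by the diff.

prism_entry = {                            # assumed to live under cfg['model']['Prism'] in the --config file
    'class': 'Prism',                      # run.py pops 'class' and dispatches via hasattr(vlmeval.composite, cls_name)
    'model': {
        'fronted': {
            'model': 'prismcaptioner-7b',  # assumption: any supported_VLM key can serve as the captioning front end
            'prompt_type': 'generic',      # selects an entry from genric_prompt_mapping in prism.py
        },
        'backend': {
            'model': 'gpt-4o',             # assumption: a gpt_version_map key, or another LLM name routed to SiliconFlowAPI
        },
    },
}

# With such an entry, run.py would build the composite model roughly as
#   model = build_model_from_config(cfg['model']['Prism'])
# and rename it to 'Prism_prismcaptioner-7b_gpt-4o' for the prediction files.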