diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c92fed507a6d..ffd0a0b497aa 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1090,6 +1090,8 @@ title: InstructBlipVideo - local: model_doc/internvl title: InternVL + - local: model_doc/internvl_flash + title: InternVLFlash - local: model_doc/janus title: Janus - local: model_doc/kosmos-2 diff --git a/docs/source/en/model_doc/internvl_flash.md b/docs/source/en/model_doc/internvl_flash.md new file mode 100644 index 000000000000..209da9743130 --- /dev/null +++ b/docs/source/en/model_doc/internvl_flash.md @@ -0,0 +1,74 @@ + +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-25.* + + +# [InternVLFlash] + +## Overview + +The [InternVLFlash] model was proposed in []() by . + + +The abstract from the paper is the following: + + + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + +## Usage examples + + + +## InternvlFlashVisionConfig + +[[autodoc]] InternvlFlashVisionConfig + +## InternvlFlashConfig + +[[autodoc]] InternvlFlashConfig + +## InternvlFlashVisionPreTrainedModel + +[[autodoc]] InternvlFlashVisionPreTrainedModel + - forward + +## InternvlFlashVisionModel + +[[autodoc]] InternvlFlashVisionModel + - forward + +## InternvlFlashPreTrainedModel + +[[autodoc]] InternvlFlashPreTrainedModel + - forward + +## InternvlFlashModel + +[[autodoc]] InternvlFlashModel + - forward + +## InternvlFlashForConditionalGeneration + +[[autodoc]] InternvlFlashForConditionalGeneration diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3534ce6719d0..47374e7e6a49 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -176,6 +176,7 @@ from .instructblip import * from .instructblipvideo import * from .internvl import * + from .internvl_flash import * from .jamba import * from .janus import * from .jetmoe import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9a3b2ec5ecc2..80da971f8d79 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -214,6 +214,8 @@ ("instructblip", "InstructBlipConfig"), ("instructblipvideo", "InstructBlipVideoConfig"), ("internvl", "InternVLConfig"), + ("internvl_flash", "InternvlFlashConfig"), + ("internvl_flash_vision", "InternvlFlashVisionConfig"), ("internvl_vision", "InternVLVisionConfig"), ("jamba", "JambaConfig"), ("janus", "JanusConfig"), @@ -663,6 +665,8 @@ ("instructblip", "InstructBLIP"), ("instructblipvideo", "InstructBlipVideo"), ("internvl", "InternVL"), + ("internvl_flash", "InternVLFlash"), + ("internvl_flash_vision", "InternVLFlashVision"), ("internvl_vision", "InternVLVision"), ("jamba", "Jamba"), ("janus", "Janus"), @@ -994,6 +998,7 @@ ("rt_detr_resnet", "rt_detr"), ("granitevision", "llava_next"), ("internvl_vision", "internvl"), + ("internvl_flash_vision", "internvl_flash"), ("qwen2_5_vl_text", "qwen2_5_vl"), ("qwen2_vl_text", "qwen2_vl"), ("qwen3_vl_text", "qwen3_vl"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 257fb95fdea7..4cdd73d6392e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -215,6 +215,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("instructblip", "InstructBlipModel"), ("instructblipvideo", "InstructBlipVideoModel"), ("internvl", "InternVLModel"), + ("internvl_flash", "InternvlFlashModel"), + ("internvl_flash_vision", "InternvlFlashVisionModel"), ("internvl_vision", "InternVLVisionModel"), ("jamba", "JambaModel"), ("janus", "JanusModel"), @@ -1040,6 +1042,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("idefics3", "Idefics3ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), ("internvl", "InternVLForConditionalGeneration"), + ("internvl_flash", "InternvlFlashForConditionalGeneration"), ("janus", "JanusForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), diff --git a/src/transformers/models/internvl_flash/__init__.py b/src/transformers/models/internvl_flash/__init__.py new file mode 100644 index 000000000000..ca1e2daa424f --- /dev/null +++ b/src/transformers/models/internvl_flash/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_internvl_flash import * + from .modeling_internvl_flash import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/internvl_flash/configuration_internvl_flash.py b/src/transformers/models/internvl_flash/configuration_internvl_flash.py new file mode 100644 index 000000000000..4f2d571deaf9 --- /dev/null +++ b/src/transformers/models/internvl_flash/configuration_internvl_flash.py @@ -0,0 +1,230 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/internvl_flash/modular_internvl_flash.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_internvl_flash.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PreTrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class InternvlFlashVisionConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternvlFlashVisionModel`]. It is used to instantiate an InternvlFlashVisionModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of the InternvlFlash3-1B. + e.g. [OpenGVLab/InternvlFlash3-1B-hf](https://huggingface.co/OpenGVLab/InternvlFlash3-1B-hf) + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether to apply normalization to the queries and keys before the attention operation. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + projection_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the projection layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`): + The size (resolution) of each image. + patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`): + Whether to use BERT-style absolute position embeddings. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + + Example: + + ```python + >>> from transformers import InternvlFlashVisionConfig, InternvlFlashVisionModel + + >>> # Initializing a InternvlFlashVisionModel OpenGVLab/InternvlFlash3-1B-hf style configuration + >>> configuration = InternvlFlashVisionConfig() + + >>> # Initializing a model (with random weights) from the OpenGVLab/InternvlFlash3-1B-hf configuration + >>> model = InternvlFlashVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "internvl_flash_vision" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + attention_bias=False, + use_qk_norm=False, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_dropout=0.0, + projection_dropout=0.0, + initializer_range=0.02, + norm_type="layer_norm", + layer_norm_eps=1e-06, + image_size=[448, 448], + patch_size=[14, 14], + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=True, + layer_scale_init_value=0.1, + use_mean_pooling=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_bias = attention_bias + self.use_qk_norm = use_qk_norm + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout = attention_dropout + self.projection_dropout = projection_dropout + self.initializer_range = initializer_range + self.norm_type = norm_type + self.layer_norm_eps = layer_norm_eps + + image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size) + self.image_size = image_size + self.patch_size = patch_size + + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.layer_scale_init_value = layer_scale_init_value + self.use_mean_pooling = use_mean_pooling + + +class InternvlFlashConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternvlFlashForConditionalGeneration`]. It is used to instantiate a + InternvlFlash model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of InternvlFlash3-1B. + e.g. [OpenGVLab/InternvlFlash3-1B-hf](https://huggingface.co/OpenGVLab/InternvlFlash3-1B-hf) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVisonConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`): + The config object or dictionary of the text backbone. + image_token_id (`int`, *optional*, defaults to 151667): + The image token index to encode the image prompt. + image_seq_length (`int`, *optional*, defaults to 256): + Number of image tokens to use per image patch. + downsample_ratio (`float`, *optional*, defaults to 0.5): + Factor by which to downsample the image. + projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the projector. + vision_feature_layer (`int`, *optional*, defaults to -1): + The index of the layer to use as the image features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + + ```python + >>> from transformers import InternvlFlashForConditionalGeneration, InternvlFlashConfig + + >>> # Initializing a InternvlFlash style configuration + >>> configuration = InternvlFlashConfig() + + >>> # Initializing a model (with random weights) from the OpenGVLab/InternvlFlash3-1B-hf configuration + >>> model = InternvlFlashForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "internvl_flash" + sub_configs = {"text_config": AutoConfig, "vision_config": InternvlFlashVisionConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_id=151667, + image_seq_length=256, + downsample_ratio=0.5, + projector_hidden_act="gelu", + vision_feature_layer=-1, + vision_feature_select_strategy="default", + **kwargs, + ): + self.image_token_id = image_token_id + self.image_seq_length = image_seq_length + self.downsample_ratio = downsample_ratio + self.projector_hidden_act = projector_hidden_act + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + + if isinstance(vision_config, dict): + self.vision_config = InternvlFlashVisionConfig(**vision_config) + elif isinstance(vision_config, InternvlFlashVisionConfig): + self.vision_config = vision_config + elif vision_config is None: + self.vision_config = InternvlFlashVisionConfig() + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "qwen2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]() + + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["InternvlFlashConfig", "InternvlFlashVisionConfig"] diff --git a/src/transformers/models/internvl_flash/modeling_internvl_flash.py b/src/transformers/models/internvl_flash/modeling_internvl_flash.py new file mode 100644 index 000000000000..63db67bb4113 --- /dev/null +++ b/src/transformers/models/internvl_flash/modeling_internvl_flash.py @@ -0,0 +1,1324 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/internvl_flash/modular_internvl_flash.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_internvl_flash.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_int +from ...utils.generic import check_model_inputs +from ..auto import AutoModel +from .configuration_internvl_flash import InternvlFlashConfig, InternvlFlashVisionConfig + + +class InternvlFlashMLP(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.dense_in = nn.Linear(in_dim, out_dim) + self.act_fn = nn.GELU() + self.dropout_in = nn.Dropout(dropout) + self.dense_out = nn.Linear(out_dim, in_dim) + self.dropout_out = nn.Dropout(dropout) + self.norm = nn.LayerNorm(in_dim) + + def forward(self, hidden_states): + hidden_states = self.dense_in(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.dropout_in(hidden_states) + hidden_states = self.dense_out(hidden_states) + hidden_states = self.dropout_out(hidden_states) + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class InternvlFlashMLP2(nn.Module): + def __init__(self, vit_hidden_size, llm_hidden_size, config): + super().__init__() + + in_dim = vit_hidden_size * int(1 / config.downsample_ratio) ** 4 + mid_dim = llm_hidden_size * 2 + out_dim = llm_hidden_size + self.norm = nn.LayerNorm(in_dim) + self.dense1 = nn.Linear(in_dim, mid_dim) + self.act_fn1 = nn.GELU() + self.dropout1 = nn.Dropout(0.1) + self.dense2 = nn.Linear(mid_dim, mid_dim) + self.act_fn2 = nn.GELU() + self.dropout2 = nn.Dropout(0.1) + self.dense3 = nn.Linear(mid_dim, out_dim) + + def forward(self, hidden_states): + hidden_states = self.norm(hidden_states) + hidden_states = self.dense1(hidden_states) + hidden_states = self.act_fn1(hidden_states) + hidden_states = self.dropout1(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.act_fn2(hidden_states) + hidden_states = self.dropout2(hidden_states) + hidden_states = self.dense3(hidden_states) + + return hidden_states + + +class InternvlFlashGating(nn.Module): + def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1, use_checkpoint=True): + super().__init__() + self.use_checkpoint = use_checkpoint + mid_dim = hidden_size * expansion_factor + + self.block1 = InternvlFlashMLP(hidden_size, mid_dim) + self.block2 = InternvlFlashMLP(hidden_size, mid_dim) + self.block3 = InternvlFlashMLP(hidden_size, mid_dim) + self.block4 = InternvlFlashMLP(hidden_size, mid_dim) + self.gate_norm = nn.LayerNorm(hidden_size) + self.gate_proj = nn.Linear(hidden_size, 2) + + def forward(self, x): + x = x + self.block1(x) + x = x + self.block2(x) + x = x + self.block3(x) + x = x + self.block4(x) + logits = self.gate_proj(self.gate_norm(x)) + probs = torch.softmax(logits, dim=-1) # 每δΈͺ token ηš„ expert ι€‰ζ‹©ζ¦‚ηŽ‡ + return probs + + +class InternvlFlashTextAttention(nn.Module): + """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`.""" + + # Ignore copy + def __init__(self, hidden_size, num_attention_heads, dropout): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention " + f"heads ({num_attention_heads})" + ) + + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.out_proj = nn.Linear(hidden_size, hidden_size) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in InternvlFlashTextModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class InternvlFlashCrossAttentionPooling(nn.Module): + def __init__(self, dim, num_heads=16): + super().__init__() + self.query_token = nn.Parameter(torch.randn(1, dim)) # [1, D] + self.attn1 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm1 = nn.LayerNorm(dim) + self.attn2 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm2 = nn.LayerNorm(dim) + self.attn3 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm3 = nn.LayerNorm(dim) + self.attn4 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm4 = nn.LayerNorm(dim) + + def forward(self, batched_tokens: list[torch.Tensor]): + """ + batched_tokens: List of Tensors of shape [Ti, D], length = B + """ + B = len(batched_tokens) + if B == 0: + return torch.empty( + 0, self.query_token.shape[-1], device=self.query_token.device, dtype=self.query_token.dtype + ) + + D = batched_tokens[0].shape[-1] + device = batched_tokens[0].device + # 1. Padding + max_len = max(t.shape[0] for t in batched_tokens) + dtype = self.query_token.dtype + padded = torch.zeros(B, max_len, D, dtype=dtype, device=device) + padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device) + for i, t in enumerate(batched_tokens): + L = t.shape[0] + padded[i, :L] = t + padding_mask[i, :L] = False + # 2. Query token: [B, 1, D] + query = self.query_token.unsqueeze(0).expand(B, -1, -1) # learnable token for each sample + + attention_mask = torch.zeros_like(padding_mask, dtype=query.dtype) + min_value = torch.finfo(query.dtype).min + attention_mask.masked_fill_(padding_mask, min_value) + + # 3. Adjust Attention Score: [B, Num_Heads, Q_Len, K_Len] + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + + # 4. Attention layers + out1 = self.attn1(query, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out1 = self.norm1(out1) + out2 = self.attn2(out1, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out2 = self.norm2(out2) + out3 = self.attn3(out2, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out3 = self.norm3(out3) + out4 = self.attn4(out3, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out4 = self.norm4(out4) + return out4.squeeze(1) + + +class InternvlFlashMultiModalProjector(nn.Module): + def __init__(self, config: InternvlFlashConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size) + + def forward(self, image_features): + hidden_states = self.layer_norm(image_features) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for InternvlFlash outputs, with hidden states and attentions. + """ +) +class InternvlFlashModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@use_kernel_forward_from_hub("RMSNorm") +class InternvlFlashVisionRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + InternvlFlashVisionRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = key + value_states = value + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # No upcasting of the attention weights to float32 in this implementation + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class InternvlFlashVisionAttention(nn.Module): + """Attention Class for InternvlFlash Vision Encoder""" + + def __init__(self, config: InternvlFlashVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + proj_dropout = config.projection_dropout + qk_norm = config.use_qk_norm + + # Needed for flash attention + self.is_causal = False + + self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim) + self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity() + + self.q_norm = InternvlFlashVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() + self.k_norm = InternvlFlashVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + batch_size, seq_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + is_causal=False, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) + + output = self.projection_layer(attn_output) + output = self.projection_dropout(output) + + return output, attn_weights + + +class InternvlFlashVisionPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + embeddings = self.projection(pixel_values) + patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, (patch_height, patch_width) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class InternvlFlashVisionEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + + """ + + def __init__(self, config: InternvlFlashVisionConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = InternvlFlashVisionPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size[0] + new_width = width // self.patch_size[1] + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + _, _, height, width = pixel_values.shape + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings, (patch_height, patch_width) + + +class InternvlFlashVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternvlFlashVisionRMSNorm} + + +class InternvlFlashVisionLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: InternvlFlashVisionConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = InternvlFlashVisionAttention(config) + self.mlp = InternvlFlashVisionMLP(config) + # InternvlFlash uses different layernorm implementations for different models + self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + attention_output, _ = self.attention( + self.layernorm_before(hidden_states), # in InternvlFlashVision, layernorm is applied before self-attention + ) + + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = attention_output + hidden_states + + # in InternvlFlashVision, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.mlp(layer_output) + layer_output = self.dropout(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = layer_output + hidden_states + + return layer_output + + +@auto_docstring +class InternvlFlashVisionPreTrainedModel(PreTrainedModel): + config: InternvlFlashVisionConfig + base_model_prefix = "internvl_flash_vision" + main_input_name = "pixel_values" + input_modalities = ["image", "video"] + supports_gradient_checkpointing = True + _no_split_modules = ["InternvlFlashVisionLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": InternvlFlashVisionLayer, + "attentions": InternvlFlashVisionAttention, + } + + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + if isinstance(module, InternvlFlashVisionEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, InternvlFlashVisionLayer): + module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`InternvlFlashVisionModel`]. + """ +) +class InternvlFlashVisionModelOutputWithPooling(BaseModelOutputWithPooling): + r""" + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + """ + + +class InternvlFlashVisionEncoder(nn.Module): + def __init__(self, config: InternvlFlashVisionConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([InternvlFlashVisionLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Union[tuple, BaseModelOutput]: + for layer_module in self.layer: + hidden_states = layer_module(hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +@auto_docstring +class InternvlFlashVisionModel(InternvlFlashVisionPreTrainedModel): + def __init__(self, config: InternvlFlashVisionConfig) -> None: + super().__init__(config) + self.config = config + + self.embeddings = InternvlFlashVisionEmbeddings(config) + self.encoder = InternvlFlashVisionEncoder(config) + + self.layernorm = ( + nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + ) -> Union[tuple, InternvlFlashVisionModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder(embedding_output) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + return InternvlFlashVisionModelOutputWithPooling( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class InternvlFlashPreTrainedModel(PreTrainedModel): + config: InternvlFlashConfig + base_model_prefix = "" + input_modalities = ["image", "text", "video"] + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +@auto_docstring( + custom_intro=""" + The InternvlFlash model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class InternvlFlashModel(InternvlFlashPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: InternvlFlashConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = InternvlFlashMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + + vit_hidden_size = config.vision_config.hidden_size + self.pooling_before_gating = InternvlFlashCrossAttentionPooling(dim=vit_hidden_size) + self.gating = InternvlFlashGating(hidden_size=vit_hidden_size) + + llm_hidden_size = config.text_config.hidden_size + self.mlp2 = InternvlFlashMLP2(vit_hidden_size, llm_hidden_size, config) + self.flash_relative_threshold = config.flash_relative_threshold + self.flash_absolute_threshold = config.flash_absolute_threshold + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + lengths: torch.Tensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int` or `list[int]`): + Layer index or list of layer indices to extract features from. + Returns: + vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. + """ + vit_embeds_1024 = self.vision_tower(pixel_values=pixel_values).last_hidden_state + + vit_embeds_1024 = vit_embeds_1024[:, 1:, :] + h = w = int(vit_embeds_1024.shape[1] ** 0.5) + vit_embeds_1024 = vit_embeds_1024.reshape(vit_embeds_1024.shape[0], h, w, -1) + + # begin moe + lengths = [int(x) for x in lengths.tolist()] + tile_splits = torch.split(vit_embeds_1024, lengths, dim=0) + vit_embeds_1024_split_and_merge = [x.reshape(-1, x.shape[-1]) for x in tile_splits] + + gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge) + gate = self.gating(gate) + + vit_embeds_256 = vit_embeds_1024.clone() + + vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.config.downsample_ratio**2) + vit_embeds_64 = vit_embeds_64.reshape(vit_embeds_64.shape[0], -1, vit_embeds_64.shape[-1]) + vit_embeds_64 = self.mlp2(vit_embeds_64) + + vit_embeds_256 = self.pixel_shuffle(vit_embeds_256, scale_factor=self.config.downsample_ratio) + vit_embeds_256 = vit_embeds_256.reshape(vit_embeds_256.shape[0], -1, vit_embeds_256.shape[-1]) + vit_embeds_256 = self.multi_modal_projector(vit_embeds_256) + + relative_threshold_value = torch.quantile(gate[:, 0].to(torch.float32), self.flash_relative_threshold) + gate_mask = (gate[:, 0] > relative_threshold_value) & (gate[:, 0] >= self.flash_absolute_threshold) + + selected_embeds = [] + for i in range(gate_mask.size(0)): + prob = gate[i, 0] + + if gate_mask[i]: + feat = vit_embeds_64[i] + else: + feat = vit_embeds_256[i] + + if self.training: + feat = feat + prob - prob.detach() # straight through estimator for backpropagation + + selected_embeds.append(feat) + + vit_embeds = torch.cat(selected_embeds, dim=0) + + if self.training: + vit_embeds = vit_embeds + 0.0 * vit_embeds_64.sum() + 0.0 * vit_embeds_256.sum() + + return vit_embeds, gate_mask + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, InternvlFlashModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # image feature is vit embeds + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and input_ids is not None: + lengths, starts, batch_indices = self.get_image_num_per_sample(input_ids) + lengths_copy = lengths.clone() + lengths = lengths // 256 + + lengths_sum = torch.ones(int(lengths.sum().item()), dtype=torch.int64) + vit_embeds, gate_result = self.get_image_features(pixel_values, lengths_sum) + + B, N, C = inputs_embeds.shape + inputs_embeds = inputs_embeds.reshape(B * N, C) + + input_ids = input_ids.reshape(B * N) + + global_starts = starts + (batch_indices * N) + + inputs_embeds, input_ids, keep_mask = self.compress_visual_tokens_in_sentence( + input_embeds=inputs_embeds, + input_ids=input_ids, + img_context_token_id=self.config.image_token_id, + gate_result=gate_result, + lengths=lengths_copy, + starts=global_starts, + ) + if isinstance(attention_mask, dict): # add support for StaticCache + attention_mask = attention_mask["full_attention"] + + if attention_mask is not None: + if attention_mask.dim() > 2 or attention_mask.numel() != input_ids.numel(): + pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else 0 + attention_mask = input_ids.ne(pad_token_id) + else: + attention_mask = attention_mask.reshape(B * N) + + attention_mask = attention_mask[keep_mask].to(inputs_embeds.device) + + inputs_embeds = self._scatter_image_embeddings( + inputs_embeds=inputs_embeds, + input_ids=input_ids, + vit_embeds=vit_embeds, + ) + + inputs_embeds, attention_mask = self._reconstruct_batch( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + gate_result=gate_result, + lengths=lengths, + batch_indices=batch_indices, + N=N, + B=B, + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return InternvlFlashModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=inputs_embeds if pixel_values is not None else None, + ) + + def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): + """Perform pixel shuffle downsampling on vision features. + + Args: + vision_features (`torch.Tensor`): + Input tensor of shape (batch_size, width, height, channels). + scale_factor (`float`, *optional*, defaults to `0.5`): + Factor by which to downsample. Default is 0.5, which halves the dimensions. + + Returns: + vision_features (`torch.Tensor`): + Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)). + """ + batch_size, width, height, channels = vision_features.size() + + if height % scale_factor != 0 or width % scale_factor != 0: + raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.") + + # Reshape to allow downsampling + vision_features = vision_features.view( + batch_size, width, int(height * scale_factor), int(channels / scale_factor) + ) + # Permute dimensions to align downsampled axis correctly + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + # Reshape to achieve final downsampled dimensions + vision_features = vision_features.view( + batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2)) + ) + + # Swap height and width back for proper orientation + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + return vision_features + + def compress_visual_tokens_in_sentence( + self, + input_embeds: torch.Tensor, + input_ids: torch.Tensor, + lengths: torch.Tensor, + starts: torch.Tensor, + img_context_token_id: int, + gate_result, + ) -> tuple: + N, C = input_embeds.shape + + keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device) + + delete_flags = torch.zeros(N, dtype=torch.int32, device=input_embeds.device) + + if (lengths % 256 != 0).any(): + raise ValueError(f"lengths % 256 != 0, lengths = {lengths}") + + block_counts = lengths // 256 + total_blocks = block_counts.sum() + + starts_expended = torch.repeat_interleave(starts, block_counts) + global_range = torch.arange(total_blocks, device=input_embeds.device) + cumsum_blocks = torch.cumsum(block_counts, dim=0) + group_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=lengths.device), cumsum_blocks[:-1]]) + local_block_indices = global_range - torch.repeat_interleave(group_starts, block_counts) + + all_block_starts = starts_expended + (local_block_indices * 256) + + compressed_starts = all_block_starts[gate_result] + + if compressed_starts.numel() > 0: + offsets = torch.arange(64, 256, device=input_embeds.device) + indices_to_remove = (compressed_starts.unsqueeze(1) + offsets.unsqueeze(0)).view(-1) + + keep_mask[indices_to_remove] = False + delete_flags[indices_to_remove] = 1 + + new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :] + new_input_ids = input_ids[keep_mask.to(input_ids.device)] + + return new_input_embeds, new_input_ids, keep_mask + + def get_image_num_per_sample( + self, + input_ids: torch.Tensor, + ): + if input_ids is None: + raise ValueError("input_ids cannot be None when pixel_values are provided. ") + if input_ids.dim() == 1: + input_ids = input_ids.squeeze(0) # (N,) #todo add batch size support + selected = input_ids == self.config.image_token_id + + padded = F.pad(selected.int(), (1, 1), value=0) + diff = torch.diff(padded, dim=1) + + starts_coords = (diff == 1).nonzero(as_tuple=False) + ends_coords = (diff == -1).nonzero(as_tuple=False) + + batch_indices = starts_coords[:, 0] + + starts = starts_coords[:, 1] + ends = ends_coords[:, 1] + lengths = ends - starts + + return lengths, starts, batch_indices + + def _scatter_image_embeddings( + self, + inputs_embeds: torch.Tensor, + input_ids: torch.Tensor, + vit_embeds: torch.Tensor, + ) -> torch.Tensor: + selected_mask = input_ids == self.config.image_token_id + if selected_mask.sum() == 0: + return inputs_embeds + inputs_embeds[selected_mask] = vit_embeds.to(inputs_embeds.device) + return inputs_embeds + + def _reconstruct_batch( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + gate_result: torch.Tensor, + lengths: torch.Tensor, + batch_indices: torch.Tensor, + N: int, + B: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reconstructs the batch by removing compressed visual tokens based on gating results. + """ + device = inputs_embeds.device + + gate_result_split = torch.split(gate_result, lengths.tolist()) + + compressed_blocks_per_image = torch.tensor( + [g.sum().item() for g in gate_result_split], device=device, dtype=torch.int64 + ) + + compressed_blocks_per_sample = torch.zeros(B, dtype=torch.int64, device=device) + compressed_blocks_per_sample.index_add_(0, batch_indices, compressed_blocks_per_image) + + tokens_removed = compressed_blocks_per_sample * 192 + new_lengths = N - tokens_removed + + max_len = int(new_lengths.max().item()) + hidden_dim = inputs_embeds.shape[-1] + + out_embeds = torch.zeros((B, max_len, hidden_dim), dtype=inputs_embeds.dtype, device=device) + out_mask = torch.zeros((B, max_len), dtype=attention_mask.dtype, device=device) + + split_embeds = torch.split(inputs_embeds, new_lengths.tolist()) + split_masks = torch.split(attention_mask, new_lengths.tolist()) + + for i, (emb, mask, length) in enumerate(zip(split_embeds, split_masks, new_lengths)): + L = int(length.item()) + out_embeds[i, :L] = emb + out_mask[i, :L] = mask + + return out_embeds, out_mask + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for InternvlFlash causal language model (or autoregressive) outputs. + """ +) +class InternvlFlashCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring( + custom_intro=""" + The INTERNVL_FLASH model which consists of a vision backbone and a language model. + """ +) +class InternvlFlashForConditionalGeneration(InternvlFlashPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: InternvlFlashConfig): + super().__init__(config) + self.model = InternvlFlashModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + **kwargs, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, InternvlFlashCausalLMOutputWithPast]: + r""" + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + + >>> torch_device = "cuda" + >>> processor = AutoProcessor.from_pretrained("chenhaoguan/InternVL3_5-2B-Flash-hf") + >>> model = AutoModelForImageTextToText.from_pretrained( + ... "chenhaoguan/InternVL3_5-2B-Flash-hf", dtype=torch.bfloat16, device_map=torch_device + ... ) + + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "image", + ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + ... }, + ... { + ... "type": "image", + ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + ... }, + ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"}, + ... ], + ... }, + ... ] + + >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device) + >>> generate_ids = model.generate(**inputs, max_new_tokens=200) + >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)) + The images depict the Statue of Liberty and the Golden Gate Bridge. + ```""" + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + cache_position=cache_position, + image_sizes=image_sizes, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return InternvlFlashCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = [ + "InternvlFlashVisionPreTrainedModel", + "InternvlFlashVisionModel", + "InternvlFlashPreTrainedModel", + "InternvlFlashModel", + "InternvlFlashForConditionalGeneration", +] diff --git a/src/transformers/models/internvl_flash/modular_internvl_flash.py b/src/transformers/models/internvl_flash/modular_internvl_flash.py new file mode 100644 index 000000000000..27af21383734 --- /dev/null +++ b/src/transformers/models/internvl_flash/modular_internvl_flash.py @@ -0,0 +1,580 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...cache_utils import Cache +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ..internvl.configuration_internvl import InternVLConfig, InternVLVisionConfig +from ..internvl.modeling_internvl import ( + InternVLModel, + InternVLModelOutputWithPast, + InternVLMultiModalProjector, + InternVLPreTrainedModel, + InternVLVisionModel, + InternVLVisionPreTrainedModel, +) +from ..llava.modeling_llava import ( + LlavaForConditionalGeneration, +) +from ..zoedepth.modeling_zoedepth import ZoeDepthMultiheadAttention + + +class InternvlFlashMLP(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.dense_in = nn.Linear(in_dim, out_dim) + self.act_fn = nn.GELU() + self.dropout_in = nn.Dropout(dropout) + self.dense_out = nn.Linear(out_dim, in_dim) + self.dropout_out = nn.Dropout(dropout) + self.norm = nn.LayerNorm(in_dim) + + def forward(self, hidden_states): + hidden_states = self.dense_in(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.dropout_in(hidden_states) + hidden_states = self.dense_out(hidden_states) + hidden_states = self.dropout_out(hidden_states) + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class InternvlFlashMLP2(nn.Module): + def __init__(self, vit_hidden_size, llm_hidden_size, config): + super().__init__() + + in_dim = vit_hidden_size * int(1 / config.downsample_ratio) ** 4 + mid_dim = llm_hidden_size * 2 + out_dim = llm_hidden_size + self.norm = nn.LayerNorm(in_dim) + self.dense1 = nn.Linear(in_dim, mid_dim) + self.act_fn1 = nn.GELU() + self.dropout1 = nn.Dropout(0.1) + self.dense2 = nn.Linear(mid_dim, mid_dim) + self.act_fn2 = nn.GELU() + self.dropout2 = nn.Dropout(0.1) + self.dense3 = nn.Linear(mid_dim, out_dim) + + def forward(self, hidden_states): + hidden_states = self.norm(hidden_states) + hidden_states = self.dense1(hidden_states) + hidden_states = self.act_fn1(hidden_states) + hidden_states = self.dropout1(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.act_fn2(hidden_states) + hidden_states = self.dropout2(hidden_states) + hidden_states = self.dense3(hidden_states) + + return hidden_states + + +class InternvlFlashGating(nn.Module): + def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1, use_checkpoint=True): + super().__init__() + self.use_checkpoint = use_checkpoint + mid_dim = hidden_size * expansion_factor + + self.block1 = InternvlFlashMLP(hidden_size, mid_dim) + self.block2 = InternvlFlashMLP(hidden_size, mid_dim) + self.block3 = InternvlFlashMLP(hidden_size, mid_dim) + self.block4 = InternvlFlashMLP(hidden_size, mid_dim) + self.gate_norm = nn.LayerNorm(hidden_size) + self.gate_proj = nn.Linear(hidden_size, 2) + + def forward(self, x): + x = x + self.block1(x) + x = x + self.block2(x) + x = x + self.block3(x) + x = x + self.block4(x) + logits = self.gate_proj(self.gate_norm(x)) + probs = torch.softmax(logits, dim=-1) # 每δΈͺ token ηš„ expert ι€‰ζ‹©ζ¦‚ηŽ‡ + return probs + + +class InternvlFlashTextAttention(ZoeDepthMultiheadAttention): + pass + + +class InternvlFlashCrossAttentionPooling(nn.Module): + def __init__(self, dim, num_heads=16): + super().__init__() + self.query_token = nn.Parameter(torch.randn(1, dim)) # [1, D] + self.attn1 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm1 = nn.LayerNorm(dim) + self.attn2 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm2 = nn.LayerNorm(dim) + self.attn3 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm3 = nn.LayerNorm(dim) + self.attn4 = InternvlFlashTextAttention(hidden_size=dim, num_attention_heads=num_heads, dropout=0.0) + self.norm4 = nn.LayerNorm(dim) + + def forward(self, batched_tokens: list[torch.Tensor]): + """ + batched_tokens: List of Tensors of shape [Ti, D], length = B + """ + B = len(batched_tokens) + if B == 0: + return torch.empty( + 0, self.query_token.shape[-1], device=self.query_token.device, dtype=self.query_token.dtype + ) + + D = batched_tokens[0].shape[-1] + device = batched_tokens[0].device + # 1. Padding + max_len = max(t.shape[0] for t in batched_tokens) + dtype = self.query_token.dtype + padded = torch.zeros(B, max_len, D, dtype=dtype, device=device) + padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device) + for i, t in enumerate(batched_tokens): + L = t.shape[0] + padded[i, :L] = t + padding_mask[i, :L] = False + # 2. Query token: [B, 1, D] + query = self.query_token.unsqueeze(0).expand(B, -1, -1) # learnable token for each sample + + attention_mask = torch.zeros_like(padding_mask, dtype=query.dtype) + min_value = torch.finfo(query.dtype).min + attention_mask.masked_fill_(padding_mask, min_value) + + # 3. Adjust Attention Score: [B, Num_Heads, Q_Len, K_Len] + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + + # 4. Attention layers + out1 = self.attn1(query, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out1 = self.norm1(out1) + out2 = self.attn2(out1, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out2 = self.norm2(out2) + out3 = self.attn3(out2, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out3 = self.norm3(out3) + out4 = self.attn4(out3, padded, padded, attention_mask=attention_mask)[0] # [B, 1, D] + out4 = self.norm4(out4) + return out4.squeeze(1) + + +class InternvlFlashVisionConfig(InternVLVisionConfig): + pass + + +class InternvlFlashConfig(InternVLConfig): + pass + + +class InternvlFlashMultiModalProjector(InternVLMultiModalProjector): + pass + + +class InternvlFlashModelOutputWithPast(InternVLModelOutputWithPast): + pass + + +class InternvlFlashVisionPreTrainedModel(InternVLVisionPreTrainedModel): + pass + + +@auto_docstring +class InternvlFlashVisionModel(InternVLVisionModel): + pass + + +class InternvlFlashPreTrainedModel(InternVLPreTrainedModel): + pass + + +class InternvlFlashModel(InternVLModel): + def __init__(self, config: InternvlFlashConfig): + super().__init__(config) + + vit_hidden_size = config.vision_config.hidden_size + self.pooling_before_gating = InternvlFlashCrossAttentionPooling(dim=vit_hidden_size) + self.gating = InternvlFlashGating(hidden_size=vit_hidden_size) + + llm_hidden_size = config.text_config.hidden_size + self.multi_modal_projector = InternvlFlashMultiModalProjector(config) + self.mlp2 = InternvlFlashMLP2(vit_hidden_size, llm_hidden_size, config) + self.flash_relative_threshold = config.flash_relative_threshold + self.flash_absolute_threshold = config.flash_absolute_threshold + + def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): + """Perform pixel shuffle downsampling on vision features. + + Args: + vision_features (`torch.Tensor`): + Input tensor of shape (batch_size, width, height, channels). + scale_factor (`float`, *optional*, defaults to `0.5`): + Factor by which to downsample. Default is 0.5, which halves the dimensions. + + Returns: + vision_features (`torch.Tensor`): + Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)). + """ + batch_size, width, height, channels = vision_features.size() + + if height % scale_factor != 0 or width % scale_factor != 0: + raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.") + + # Reshape to allow downsampling + vision_features = vision_features.view( + batch_size, width, int(height * scale_factor), int(channels / scale_factor) + ) + # Permute dimensions to align downsampled axis correctly + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + # Reshape to achieve final downsampled dimensions + vision_features = vision_features.view( + batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2)) + ) + + # Swap height and width back for proper orientation + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + return vision_features + + def compress_visual_tokens_in_sentence( + self, + input_embeds: torch.Tensor, + input_ids: torch.Tensor, + lengths: torch.Tensor, + starts: torch.Tensor, + img_context_token_id: int, + gate_result, + ) -> tuple: + N, C = input_embeds.shape + + keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device) + + delete_flags = torch.zeros(N, dtype=torch.int32, device=input_embeds.device) + + if (lengths % 256 != 0).any(): + raise ValueError(f"lengths % 256 != 0, lengths = {lengths}") + + block_counts = lengths // 256 + total_blocks = block_counts.sum() + + starts_expended = torch.repeat_interleave(starts, block_counts) + global_range = torch.arange(total_blocks, device=input_embeds.device) + cumsum_blocks = torch.cumsum(block_counts, dim=0) + group_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=lengths.device), cumsum_blocks[:-1]]) + local_block_indices = global_range - torch.repeat_interleave(group_starts, block_counts) + + all_block_starts = starts_expended + (local_block_indices * 256) + + compressed_starts = all_block_starts[gate_result] + + if compressed_starts.numel() > 0: + offsets = torch.arange(64, 256, device=input_embeds.device) + indices_to_remove = (compressed_starts.unsqueeze(1) + offsets.unsqueeze(0)).view(-1) + + keep_mask[indices_to_remove] = False + delete_flags[indices_to_remove] = 1 + + new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :] + new_input_ids = input_ids[keep_mask.to(input_ids.device)] + + return new_input_embeds, new_input_ids, keep_mask + + def get_image_num_per_sample( + self, + input_ids: torch.Tensor, + ): + if input_ids is None: + raise ValueError("input_ids cannot be None when pixel_values are provided. ") + if input_ids.dim() == 1: + input_ids = input_ids.squeeze(0) # (N,) #todo add batch size support + selected = input_ids == self.config.image_token_id + + padded = F.pad(selected.int(), (1, 1), value=0) + diff = torch.diff(padded, dim=1) + + starts_coords = (diff == 1).nonzero(as_tuple=False) + ends_coords = (diff == -1).nonzero(as_tuple=False) + + batch_indices = starts_coords[:, 0] + + starts = starts_coords[:, 1] + ends = ends_coords[:, 1] + lengths = ends - starts + + return lengths, starts, batch_indices + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + lengths: torch.Tensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int` or `list[int]`): + Layer index or list of layer indices to extract features from. + Returns: + vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. + """ + vit_embeds_1024 = self.vision_tower(pixel_values=pixel_values).last_hidden_state + + vit_embeds_1024 = vit_embeds_1024[:, 1:, :] + h = w = int(vit_embeds_1024.shape[1] ** 0.5) + vit_embeds_1024 = vit_embeds_1024.reshape(vit_embeds_1024.shape[0], h, w, -1) + + # begin moe + lengths = [int(x) for x in lengths.tolist()] + tile_splits = torch.split(vit_embeds_1024, lengths, dim=0) + vit_embeds_1024_split_and_merge = [x.reshape(-1, x.shape[-1]) for x in tile_splits] + + gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge) + gate = self.gating(gate) + + vit_embeds_256 = vit_embeds_1024.clone() + + vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.config.downsample_ratio**2) + vit_embeds_64 = vit_embeds_64.reshape(vit_embeds_64.shape[0], -1, vit_embeds_64.shape[-1]) + vit_embeds_64 = self.mlp2(vit_embeds_64) + + vit_embeds_256 = self.pixel_shuffle(vit_embeds_256, scale_factor=self.config.downsample_ratio) + vit_embeds_256 = vit_embeds_256.reshape(vit_embeds_256.shape[0], -1, vit_embeds_256.shape[-1]) + vit_embeds_256 = self.multi_modal_projector(vit_embeds_256) + + relative_threshold_value = torch.quantile(gate[:, 0].to(torch.float32), self.flash_relative_threshold) + gate_mask = (gate[:, 0] > relative_threshold_value) & (gate[:, 0] >= self.flash_absolute_threshold) + + selected_embeds = [] + for i in range(gate_mask.size(0)): + prob = gate[i, 0] + + if gate_mask[i]: + feat = vit_embeds_64[i] + else: + feat = vit_embeds_256[i] + + if self.training: + feat = feat + prob - prob.detach() # straight through estimator for backpropagation + + selected_embeds.append(feat) + + vit_embeds = torch.cat(selected_embeds, dim=0) + + if self.training: + vit_embeds = vit_embeds + 0.0 * vit_embeds_64.sum() + 0.0 * vit_embeds_256.sum() + + return vit_embeds, gate_mask + + def _scatter_image_embeddings( + self, + inputs_embeds: torch.Tensor, + input_ids: torch.Tensor, + vit_embeds: torch.Tensor, + ) -> torch.Tensor: + selected_mask = input_ids == self.config.image_token_id + if selected_mask.sum() == 0: + return inputs_embeds + inputs_embeds[selected_mask] = vit_embeds.to(inputs_embeds.device) + return inputs_embeds + + def _reconstruct_batch( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + gate_result: torch.Tensor, + lengths: torch.Tensor, + batch_indices: torch.Tensor, + N: int, + B: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reconstructs the batch by removing compressed visual tokens based on gating results. + """ + device = inputs_embeds.device + + gate_result_split = torch.split(gate_result, lengths.tolist()) + + compressed_blocks_per_image = torch.tensor( + [g.sum().item() for g in gate_result_split], device=device, dtype=torch.int64 + ) + + compressed_blocks_per_sample = torch.zeros(B, dtype=torch.int64, device=device) + compressed_blocks_per_sample.index_add_(0, batch_indices, compressed_blocks_per_image) + + tokens_removed = compressed_blocks_per_sample * 192 + new_lengths = N - tokens_removed + + max_len = int(new_lengths.max().item()) + hidden_dim = inputs_embeds.shape[-1] + + out_embeds = torch.zeros((B, max_len, hidden_dim), dtype=inputs_embeds.dtype, device=device) + out_mask = torch.zeros((B, max_len), dtype=attention_mask.dtype, device=device) + + split_embeds = torch.split(inputs_embeds, new_lengths.tolist()) + split_masks = torch.split(attention_mask, new_lengths.tolist()) + + for i, (emb, mask, length) in enumerate(zip(split_embeds, split_masks, new_lengths)): + L = int(length.item()) + out_embeds[i, :L] = emb + out_mask[i, :L] = mask + + return out_embeds, out_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, InternvlFlashModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # image feature is vit embeds + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and input_ids is not None: + lengths, starts, batch_indices = self.get_image_num_per_sample(input_ids) + lengths_copy = lengths.clone() + lengths = lengths // 256 + + lengths_sum = torch.ones(int(lengths.sum().item()), dtype=torch.int64) + vit_embeds, gate_result = self.get_image_features(pixel_values, lengths_sum) + + B, N, C = inputs_embeds.shape + inputs_embeds = inputs_embeds.reshape(B * N, C) + + input_ids = input_ids.reshape(B * N) + + global_starts = starts + (batch_indices * N) + + inputs_embeds, input_ids, keep_mask = self.compress_visual_tokens_in_sentence( + input_embeds=inputs_embeds, + input_ids=input_ids, + img_context_token_id=self.config.image_token_id, + gate_result=gate_result, + lengths=lengths_copy, + starts=global_starts, + ) + if isinstance(attention_mask, dict): # add support for StaticCache + attention_mask = attention_mask["full_attention"] + + if attention_mask is not None: + if attention_mask.dim() > 2 or attention_mask.numel() != input_ids.numel(): + pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else 0 + attention_mask = input_ids.ne(pad_token_id) + else: + attention_mask = attention_mask.reshape(B * N) + + attention_mask = attention_mask[keep_mask].to(inputs_embeds.device) + + inputs_embeds = self._scatter_image_embeddings( + inputs_embeds=inputs_embeds, + input_ids=input_ids, + vit_embeds=vit_embeds, + ) + + inputs_embeds, attention_mask = self._reconstruct_batch( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + gate_result=gate_result, + lengths=lengths, + batch_indices=batch_indices, + N=N, + B=B, + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return InternvlFlashModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=inputs_embeds if pixel_values is not None else None, + ) + + +class InternvlFlashForConditionalGeneration(LlavaForConditionalGeneration): + def __init__(self, config: InternvlFlashConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = InternvlFlashModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def forward(**super_kwargs): + r""" + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + + >>> torch_device = "cuda" + >>> processor = AutoProcessor.from_pretrained("chenhaoguan/InternVL3_5-2B-Flash-hf") + >>> model = AutoModelForImageTextToText.from_pretrained( + ... "chenhaoguan/InternVL3_5-2B-Flash-hf", dtype=torch.bfloat16, device_map=torch_device + ... ) + + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "image", + ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + ... }, + ... { + ... "type": "image", + ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + ... }, + ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"}, + ... ], + ... }, + ... ] + + >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device) + >>> generate_ids = model.generate(**inputs, max_new_tokens=200) + >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)) + The images depict the Statue of Liberty and the Golden Gate Bridge. + ```""" + super().forward(**super_kwargs) + + +__all__ = [ + "InternvlFlashConfig", + "InternvlFlashVisionConfig", + "InternvlFlashVisionPreTrainedModel", + "InternvlFlashVisionModel", + "InternvlFlashPreTrainedModel", + "InternvlFlashModel", + "InternvlFlashForConditionalGeneration", +] diff --git a/tests/models/internvl_flash/__init__.py b/tests/models/internvl_flash/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/internvl_flash/test_modeling_internvl_flash.py b/tests/models/internvl_flash/test_modeling_internvl_flash.py new file mode 100644 index 000000000000..6683c68245bc --- /dev/null +++ b/tests/models/internvl_flash/test_modeling_internvl_flash.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch InternvlFlash model.""" + +import unittest + +import pytest + +from transformers import ( + InternvlFlashConfig, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import InternvlFlashForConditionalGeneration, InternvlFlashModel + + +class InternvlFlashVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=1, + seq_length=7, + image_seq_length=256, + vision_feature_layer=-1, + ignore_index=-100, + image_token_id=1, + num_channels=3, + image_size=128, + model_type="internvl_flash", + is_training=True, + text_config={ + "model_type": "qwen2", + "vocab_size": 99, + "hidden_size": 128, + "intermediate_size": 37, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "output_channels": 64, + "hidden_act": "silu", + "max_position_embeddings": 512, + "rope_theta": 10000, + "mlp_ratio": 4, + "tie_word_embeddings": True, + "bos_token_id": 3, + "eos_token_id": 4, + "pad_token_id": 5, + }, + vision_config={ + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 128, + "image_size": (128, 128), + "patch_size": (4, 4), + "num_channels": 3, + "hidden_act": "quick_gelu", + "use_absolute_position_embeddings": True, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.image_token_id = image_token_id + self.model_type = model_type + self.text_config = text_config + self.vision_config = vision_config + self.batch_size = batch_size + self.vision_feature_layer = vision_feature_layer + self.is_training = is_training + self.image_seq_length = image_seq_length + self.num_channels = num_channels + self.image_size = image_size + self.seq_length = seq_length + image_seq_length + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + + self.flash_absolute_threshold = 0.5 + self.flash_relative_threshold = 0.1 + self.text_seq_len = 15 + + def get_config(self): + return InternvlFlashConfig( + text_config=self.text_config, + vision_config=self.vision_config, + model_type=self.model_type, + image_token_id=self.image_token_id, + image_seq_length=self.image_seq_length, + vision_feature_layer=self.vision_feature_layer, + flash_relative_threshold=self.flash_relative_threshold, + flash_absolute_threshold=self.flash_absolute_threshold, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]).to( + torch_device + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[:, : self.image_seq_length] = self.image_token_id + + attention_mask = torch.ones_like(input_ids) + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask): + model = InternvlFlashForConditionalGeneration(config=config) + model.to(torch_device) + model.half() + model.eval() + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask): + config.dtype = torch.float16 + model = InternvlFlashForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type=torch_device, dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class InternvlFlashModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (InternvlFlashForConditionalGeneration, InternvlFlashModel) if is_torch_available() else () + all_generative_model_classes = (InternvlFlashForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-text-to-text": InternvlFlashForConditionalGeneration, + } + if is_torch_available() + else {} + ) + + def setUp(self): + self.model_tester = InternvlFlashVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=InternvlFlashConfig, has_text_modality=False) + + def test_flex_attention_with_grads(self): + pass + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="Compile not yet supported because in LLava models") + @pytest.mark.torch_compile_test + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("Skipping compilation test: fails with batch_size=0 reshape error") + def test_generate_compile_model_forward_fullgraph(self): + pass + + @unittest.skip("query token in InternVLFlashCrossAttentionPooling generate randomly") + def test_can_init_all_missing_weights(self): + pass + + @unittest.skip("InternVLFlash model requires input_ids to process pixel_values and merge visual features.") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip( + "The Flash model relies on input_ids for visual compression and cannot support multimodal generation using inputs_embeds alone." + ) + def test_generate_from_inputs_embeds_1_beam_search(self): + pass diff --git a/utils/check_repo.py b/utils/check_repo.py index 762b50e12ceb..eb87fd146ed0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -183,6 +183,7 @@ "Qwen2VLTextModel", # Building part of bigger (tested) model "Qwen2_5_VLTextModel", # Building part of bigger (tested) model "InternVLVisionModel", # Building part of bigger (tested) model + "InternvlFlashVisionModel", # Building part of bigger (tested) model "JanusVisionModel", # Building part of bigger (tested) model "TimesFmModel", # Building part of bigger (tested) model "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.