diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c041bcaf398..333fa759734f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2205,26 +2205,78 @@ def disable_input_require_grads(self): """ self._require_grads_hook.remove() + def get_encoder(self, modality: Optional[str] = None): + """ + Best-effort lookup of the *encoder* module. If a `modality` argument is provided, + it looks for a modality-specific encoder in multimodal models (e.g. "vision_tower" for images). + By default the function returns the model's text encoder if there is one, and otherwise returns `self`. + + Possible `modality` values are "image", "video" and "audio". + """ + # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely + if modality in ["image", "video"]: + possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"] + elif modality == "audio": + possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"] + elif modality is None: + possible_module_names = ["text_encoder", "encoder"] + else: + raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}') + + for name in possible_module_names: + if hasattr(self, name): + return getattr(self, name) + + if self.base_model is not self and hasattr(self.base_model, "get_encoder"): + return self.base_model.get_encoder(modality=modality) + + # If this is a base transformer model (no encoder/model attributes), return self + return self + + def set_encoder(self, encoder, modality: Optional[str] = None): + """ + Symmetric setter. Mirrors the lookup logic used in `get_encoder`. + """ + + # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely + if modality in ["image", "video"]: + possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"] + elif modality == "audio": + possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"] + elif modality is None: + possible_module_names = ["text_encoder", "encoder"] + else: + raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}') + + for name in possible_module_names: + if hasattr(self, name): + setattr(self, name, encoder) + return + + if self.base_model is not self: + if hasattr(self.base_model, "set_encoder"): + self.base_model.set_encoder(encoder, modality=modality) + else: + self.model = encoder + def get_decoder(self): """ Best-effort lookup of the *decoder* module. Order of attempts (covers ~85 % of current usages): - 1. `self.decoder` - 2. `self.model` (many wrappers store the decoder here) - 3. `self.model.get_decoder()` (nested wrappers) + 1. `self.decoder` / `self.language_model` / `self.text_model` + 2. `self.base_model` (many wrappers store the decoder here) + 3. `self.base_model.get_decoder()` (nested wrappers) 4. 
fallback: raise for the few exotic models that need a bespoke rule """ - if hasattr(self, "decoder"): - return self.decoder + possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"] + for name in possible_module_names: + if hasattr(self, name): + return getattr(self, name) - if hasattr(self, "model"): - inner = self.model - # See: https://github.com/huggingface/transformers/issues/40815 - if hasattr(inner, "get_decoder") and type(inner) is not type(self): - return inner.get_decoder() - return inner + if self.base_model is not self and hasattr(self.base_model, "get_decoder"): + return self.base_model.get_decoder() # If this is a base transformer model (no decoder/model attributes), return self # This handles cases like MistralModel which is itself the decoder @@ -2235,19 +2287,17 @@ def set_decoder(self, decoder): Symmetric setter. Mirrors the lookup logic used in `get_decoder`. """ - if hasattr(self, "decoder"): - self.decoder = decoder - return + possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"] + for name in possible_module_names: + if hasattr(self, name): + setattr(self, name, decoder) + return - if hasattr(self, "model"): - inner = self.model - if hasattr(inner, "set_decoder"): - inner.set_decoder(decoder) + if self.base_model is not self: + if hasattr(self.base_model, "set_decoder"): + self.base_model.set_decoder(decoder) else: self.model = decoder - return - - return @torch.no_grad() def _init_weights(self, module): diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 3c0abffe2430..dbf73ce85b0a 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -910,12 +910,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -1075,12 +1069,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -1093,19 +1081,6 @@ def get_image_features( vision_feature_layer=vision_feature_layer, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 581985fa0c4c..c026c8a6d206 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1342,9 +1342,6 @@ def create_network_inputs( ) return reshaped_lagged_sequence, features, loc, scale, static_feat - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1588,12 +1585,6 @@ def __init__(self, config: AutoformerConfig): def output_params(self, decoder_output): return self.parameter_projection(decoder_output[:, -self.config.prediction_length :, :]) - def 
get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @torch.jit.ignore def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: sliced_params = params diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 734dff382416..742d7374aef2 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -181,12 +181,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -357,12 +351,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -377,19 +365,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 6e25fa4d30fc..d663821578a2 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -905,9 +905,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1037,12 +1034,6 @@ def __init__(self, config: BartConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1498,12 +1489,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9fb140bd1e2b..263d308c47d1 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2083,9 +2083,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -2205,12 +2202,6 @@ def __init__(self, config: BigBirdPegasusConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def 
get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -2609,12 +2600,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 63021a27ca48..833d529671c9 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -869,9 +869,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1009,12 +1006,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1189,12 +1180,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 85e342619ff9..92f2f993fafa 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -842,9 +842,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -969,12 +966,6 @@ def __init__(self, config: BlenderbotSmallConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1149,12 +1140,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 542eaca5ca29..d6ee68a6680b 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1058,11 +1058,11 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() - def get_encoder(self): - return self.language_model.get_encoder() - - def get_decoder(self): 
- return self.language_model.get_decoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) @filter_out_non_signature_kwargs() @auto_docstring @@ -1579,11 +1579,11 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() - def get_encoder(self): - return self.language_model.get_encoder() - - def get_decoder(self): - return self.language_model.get_decoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) def _preprocess_accelerate(self): r""" diff --git a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py index 083fde1f9197..2ebfa7f044cd 100644 --- a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py @@ -167,12 +167,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features(self, pixel_values: torch.FloatTensor): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -285,28 +279,9 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values: torch.FloatTensor): return self.model.get_image_features(pixel_values=pixel_values) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @check_model_inputs() @auto_docstring def forward( diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 922d9bf3bd86..3d0ca8d3a84b 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -171,10 +171,10 @@ def forward( # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. if inputs_embeds is None: - inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) + inputs_embeds = self.vlm.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.vlm.model.visual(pixel_values, grid_thw=image_grid_thw) image_mask = ( (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) ) diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index d7474bfd6211..6e468e6ae3fa 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -352,10 +352,10 @@ def forward( # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. 
if inputs_embeds is None: - inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) + inputs_embeds = self.vlm.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.vlm.model.visual(pixel_values, grid_thw=image_grid_thw) image_mask = ( (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) ) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 9843dda12a60..3721678dd688 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1324,9 +1324,6 @@ def __init__(self, config: ConditionalDetrConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index d61b96ce566b..86953cd47ecf 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -1198,9 +1198,6 @@ def __init__(self, config: DFineConfig): self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for param in self.backbone.parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index e1d25bc121d0..a1337acefd7a 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -1202,9 +1202,6 @@ def __init__(self, config: DabDetrConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index e92ce998e8b8..a8133c85b573 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1359,9 +1359,6 @@ def __init__(self, config: DeformableDetrConfig): self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 968859bf3d21..1727c9b33700 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1058,9 +1058,6 @@ def __init__(self, config: DetrConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 6db8b1e9766f..3a0ddf6e3f90 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ 
b/src/transformers/models/dia/modeling_dia.py @@ -686,9 +686,6 @@ def __init__(self, config: DiaConfig): self.decoder = DiaDecoder(config.decoder_config) self.post_init() - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -825,12 +822,6 @@ def __init__(self, config: DiaConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @auto_docstring @can_return_tuple def forward( diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 514c40d79816..3c6a6b3d17cb 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -478,9 +478,6 @@ def __init__(self, config: DiaConfig): self.decoder = DiaDecoder(config.decoder_config) self.post_init() - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -617,12 +614,6 @@ def __init__(self, config: DiaConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @auto_docstring @can_return_tuple def forward( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 7d900cf4a27a..750c844e65ae 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1354,12 +1354,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.text_model = decoder - - def get_decoder(self): - return self.text_model - def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts @@ -1514,25 +1508,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - # Make modules available through conditional class for BC - @property - def text_model(self): - return self.model.text_model - - @property - def vqmodel(self): - return self.model.vqmodel - - @property - def vocabulary_mapping(self): - return self.model.vocabulary_mapping - def decode_image_tokens(self, **kwargs): return self.model.decode_image_tokens(**kwargs) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 1ef5e23a4436..9e1e2ba4f5f0 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -908,12 +908,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.text_model = decoder - - def get_decoder(self): - return self.text_model - def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. 
Converts @@ -1068,25 +1062,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - # Make modules available through conditional class for BC - @property - def text_model(self): - return self.model.text_model - - @property - def vqmodel(self): - return self.model.vqmodel - - @property - def vocabulary_mapping(self): - return self.model.vocabulary_mapping - def decode_image_tokens(self, **kwargs): return self.model.decode_image_tokens(**kwargs) diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index a02ee98e9a30..50fb2c56dce8 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -498,9 +498,6 @@ def __init__(self, config: EncodecConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def _encode_frame( self, input_values: torch.Tensor, bandwidth: float ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index fe3bec2f57cc..64b972c69047 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -173,9 +173,6 @@ def _init_weights(self, module): elif module in self.decoder.modules(): self.decoder._init_weights(module) - def get_encoder(self): - return self.encoder - def get_input_embeddings(self): return self.encoder.get_input_embeddings() diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index 4d6d38c5db29..9e761c0a8c20 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -652,12 +652,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model.get_decoder() - def get_image_features(self, pixel_values: torch.Tensor, **kwargs): """ Obtains image last hidden states from the vision tower and apply multimodal projection. 
@@ -775,8 +769,11 @@ def forward( image_hidden_states=image_features if pixel_values is not None else None, ) - def get_encoder(self): - return self.language_model.get_encoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): @@ -821,28 +818,9 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values: torch.Tensor, **kwargs): return self.model.get_image_features(pixel_values=pixel_values, **kwargs) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( @@ -979,9 +957,6 @@ def prepare_inputs_for_generation( return model_inputs - def get_encoder(self): - return self.model.get_encoder() - def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py index d2bf13544b1f..86bd02bda077 100644 --- a/src/transformers/models/florence2/modular_florence2.py +++ b/src/transformers/models/florence2/modular_florence2.py @@ -1513,11 +1513,11 @@ def __init__(self, config: Florence2Config): super().__init__(config) self.vision_tower = Florence2VisionBackbone(config=config.vision_config) - def get_encoder(self): - return self.language_model.get_encoder() - - def get_decoder(self): - return self.language_model.get_decoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) def get_image_features(self, pixel_values: torch.Tensor, **kwargs): """ @@ -1624,9 +1624,6 @@ class Florence2ForConditionalGeneration(LlavaForConditionalGeneration): "lm_head.weight": "model.language_model.shared.weight", } - def get_encoder(self): - return self.model.get_encoder() - def get_image_features(self, pixel_values: torch.Tensor, **kwargs): return self.model.get_image_features(pixel_values=pixel_values, **kwargs) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index cc4607023fbd..4bac57096cb4 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -827,9 +827,6 @@ def __init__(self, config: FSMTConfig): self.decoder = FSMTDecoder(config) self.post_init() - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1066,12 +1063,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - def get_output_embeddings(self): return self.model.decoder.embed_tokens diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 
e0a969fea3a8..c4983e007ba7 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -72,12 +72,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def gather_continuous_embeddings( self, word_embeddings: torch.Tensor, @@ -260,12 +254,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index c623324226ac..8e93ef9231b5 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -872,12 +872,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Projects the last hidden state from the vision model into language model space. @@ -1062,28 +1056,9 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @auto_docstring def forward( self, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 0b3088aadec7..cc0b919bc85c 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -2107,12 +2107,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Projects the last hidden state from the vision model into language model space. 
@@ -2361,28 +2355,9 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - raise AttributeError("Use embed_vision instead of multi_modal_projector.") - @can_return_tuple @auto_docstring def forward( @@ -2557,10 +2532,6 @@ def prepare_inputs_for_generation( return model_inputs - @property - def audio_tower(self): - return self.model.audio_tower - __all__ = [ "Gemma3nAudioEncoder", diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 02cf9b9f4833..9dc324b69411 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -2421,14 +2421,6 @@ def get_audio_features( class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration): _checkpoint_conversion_mapping = {} - @property - def audio_tower(self): - return self.model.audio_tower - - @property - def multi_modal_projector(self): - raise AttributeError("Use embed_vision instead of multi_modal_projector.") - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 8c20786c955e..2f00d22fb040 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -100,12 +100,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -533,12 +527,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -547,15 +535,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 9942770d70e3..ff5e0a00cc0d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -948,12 +948,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - 
return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1381,12 +1375,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1395,15 +1383,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index a212aed44e81..3ea0ef86c23f 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1112,12 +1112,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1598,12 +1592,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1612,15 +1600,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @auto_docstring @check_model_inputs() def forward( diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 3ac6bd177220..3fd5a2c2dbfb 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -551,12 +551,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -684,12 +678,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -704,19 +692,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return 
self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index dcb05efb752f..7e3bc1d577bf 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1278,9 +1278,6 @@ def create_network_inputs( return transformer_inputs, loc, scale, static_feat - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1550,12 +1547,6 @@ def __init__(self, config: InformerConfig): def output_params(self, dec_output): return self.parameter_projection(dec_output) - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @torch.jit.ignore def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: sliced_params = params diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 5432283b4b0c..5e4a2a7e864b 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1142,8 +1142,11 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() - def get_encoder(self): - return self.language_model.get_encoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) def get_decoder(self): return self.language_model.get_decoder() diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 0375ec5042cf..2268ba28893b 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1172,8 +1172,11 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() - def get_encoder(self): - return self.language_model.get_encoder() + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) def get_decoder(self): return self.language_model.get_decoder() diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 6e586ce999d5..51691e0ba4ab 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -549,12 +549,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -787,12 +781,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def 
get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -807,19 +795,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 40d6b82c9ede..409caf2091ad 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1771,9 +1771,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1907,12 +1904,6 @@ def __init__(self, config: LEDConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.led.get_encoder() - - def get_decoder(self): - return self.led.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: diff --git a/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py b/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py index 27ba1ece7af4..ce46c62baeab 100755 --- a/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py +++ b/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py @@ -161,12 +161,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -324,12 +318,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -344,19 +332,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple def forward( self, diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 1f8a2a9645ea..2947295a4775 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -147,12 +147,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -332,12 +326,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def 
get_image_features( self, pixel_values: torch.FloatTensor, @@ -352,19 +340,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index a83821d98f96..fffd56a941c5 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -279,12 +279,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -562,12 +556,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): return self.model.pack_image_features( image_features=image_features, @@ -590,19 +578,6 @@ def get_image_features( vision_feature_select_strategy=vision_feature_select_strategy, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 7f6fbffaec07..6e79d602eb94 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -331,12 +331,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. 
@@ -701,12 +695,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): return self.model.pack_image_features( image_features=image_features, @@ -729,19 +717,6 @@ def get_image_features( vision_feature_select_strategy=vision_feature_select_strategy, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index efe4f6fb1ba6..70d17ff3e6d4 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -290,12 +290,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -689,12 +683,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): return self.model.pack_image_features( image_features=image_features, @@ -717,19 +705,6 @@ def get_image_features( vision_feature_select_strategy=vision_feature_select_strategy, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index df716432bc26..12c7df3310eb 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1602,9 +1602,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1769,9 +1766,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1943,9 +1937,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings 
self.encoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 676938c6789f..4fe228b4021d 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -925,9 +925,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1032,12 +1029,6 @@ def __init__(self, config: M2M100Config): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @auto_docstring def forward( self, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 3a211ddbe715..960898693b4d 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -885,9 +885,6 @@ def set_decoder_input_embeddings(self, value): ) self.decoder.embed_tokens = value - def get_encoder(self): - return self.encoder - def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: if self.config.share_encoder_decoder_embeddings: raise ValueError( @@ -1051,12 +1048,6 @@ def __init__(self, config: MarianConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1281,12 +1272,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b4a3ed8ebd02..dd8619fc6c94 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -903,9 +903,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1022,12 +1019,6 @@ def __init__(self, config: MBartConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1473,12 +1464,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 7a367e05261a..82f2868a1421 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ 
b/src/transformers/models/mimi/modeling_mimi.py @@ -1453,9 +1453,6 @@ def __init__(self, config: MimiConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def _encode_frame( self, input_values: torch.Tensor, diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 6a67f21f216f..0f6e2a1d3efc 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -213,12 +213,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -383,12 +377,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -403,19 +391,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 9f811fa7f010..a2d303782bdd 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1465,12 +1465,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - @check_model_inputs() @can_return_tuple @auto_docstring @@ -1600,21 +1594,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_model(self): - return self.model.vision_model - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 0840c1623489..373e1db4a217 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -860,9 +860,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.decoder.embed_tokens = value - def get_encoder(self): - return self.encoder - def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will @@ -1019,12 +1016,6 @@ def __init__(self, config: MoonshineConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - 
return self.model.get_decoder() - def get_output_embeddings(self): return self.proj_out diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 38314c4535a6..1922428f8f8a 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -774,12 +774,6 @@ def __init__(self, config: MoonshineConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def get_output_embeddings(self): return self.proj_out diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index eb379f0bcf1e..259caa7a5daf 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1601,9 +1601,6 @@ def __init__(self, config: MoshiConfig): self.num_codebooks = config.num_codebooks self.post_init() - def get_audio_encoder(self): - return self.audio_encoder - def get_depth_decoder(self): return self.depth_decoder diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index cf4b1189ee95..1f05521accbe 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -881,10 +881,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring # Copied from transformers.models.t5.modeling_t5.T5Model.forward with google-t5/->google/, T5->MT5, t5->mt5 def forward( @@ -1067,10 +1063,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with google-t5/->google/, T5->MT5, t5->mt5 def forward( @@ -1266,10 +1258,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with google-t5/->google/, T5->MT5, t5->mt5 def forward( @@ -1582,10 +1570,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward def forward( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 35969de777ab..1bb423c63d8b 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -815,12 +815,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_heads = 
new_embeddings - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, @@ -1398,16 +1392,6 @@ def __init__( # tie text encoder, decoder weights if config set accordingly self.post_init() - def get_audio_encoder(self): - return self.audio_encoder - - def get_text_encoder(self): - return self.text_encoder - - def get_encoder(self): - # get the text encoder to compute the encoder hidden-states for generation - return self.get_text_encoder() - def get_input_embeddings(self): return self.text_encoder.get_input_embeddings() @@ -1898,7 +1882,7 @@ def _prepare_text_encoder_kwargs_for_generation( generation_config: GenerationConfig, ) -> dict[str, Any]: # 1. get text encoder - encoder = self.get_text_encoder() + encoder = self.get_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): @@ -1943,7 +1927,7 @@ def _prepare_audio_encoder_kwargs_for_generation( self, input_values, model_kwargs, model_input_name: Optional[str] = None ): # 1. get audio encoder - encoder = self.get_audio_encoder() + encoder = self.get_encoder(modality="audio") # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index bd8f29a86a5d..279c984fe6b6 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -769,12 +769,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_heads = new_embeddings - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring # Ignore copy def forward( @@ -1318,13 +1312,6 @@ def _init_weights(self, module): if module.bias is not None: init.zeros_(module.bias) - def get_text_encoder(self): - return self.text_encoder - - def get_encoder(self): - # get the text encoder to compute the conditioning hidden-states for generation - return self.get_text_encoder() - def get_input_embeddings(self): return self.text_encoder.get_input_embeddings() @@ -1824,7 +1811,7 @@ def _prepare_encoder_hidden_states_kwargs_for_generation( # 1. condition on text if inputs_tensor is not None: - encoder = self.get_text_encoder() + encoder = self.get_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 28fbf6960545..1d315d305912 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -893,9 +893,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - def set_lightweight_tuning(self): assert self.use_prompt, "If you want to use lightweight tuning, make sure that `use_prompt=True`." 
@@ -1031,12 +1028,6 @@ def __init__(self, config: MvpConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1537,12 +1528,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - def set_lightweight_tuning(self): self.model.set_lightweight_tuning() self.lm_head.requires_grad_(False) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 7f0de0ce824c..a457abddc6cb 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -893,9 +893,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -1066,12 +1063,6 @@ def __init__(self, config: NllbMoeConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 025ea7d2cd80..6368c1b1131d 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -720,12 +720,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/ovis2/modeling_ovis2.py b/src/transformers/models/ovis2/modeling_ovis2.py index 710f0a5603bf..d990e08190b6 100644 --- a/src/transformers/models/ovis2/modeling_ovis2.py +++ b/src/transformers/models/ovis2/modeling_ovis2.py @@ -518,12 +518,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -688,28 +682,9 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values: torch.FloatTensor): return self.model.get_image_features(pixel_values=pixel_values) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - raise AttributeError("Not needed for Ovis2") - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/ovis2/modular_ovis2.py b/src/transformers/models/ovis2/modular_ovis2.py index 
f6277dc91a0a..0acf99daafe6 100644 --- a/src/transformers/models/ovis2/modular_ovis2.py +++ b/src/transformers/models/ovis2/modular_ovis2.py @@ -338,10 +338,6 @@ def __init__(self, config: Ovis2Config): super().__init__(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - @property - def multi_modal_projector(self): - raise AttributeError("Not needed for Ovis2") - def get_image_features(self, pixel_values: torch.FloatTensor): return self.model.get_image_features(pixel_values=pixel_values) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index f1464bde5fb2..63538043506b 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -257,12 +257,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features(self, pixel_values: torch.FloatTensor): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -450,28 +444,9 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 34205b2e5972..41ff955d574f 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -907,9 +907,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings matrix of the model if `new_num_position_embeddings != @@ -1058,12 +1055,6 @@ def __init__(self, config: PegasusConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1253,12 +1244,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - def get_position_embeddings(self) -> nn.Embedding: """ Returns the position embeddings matrix diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 53e44ef414d0..ffe26a3be7ce 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ 
b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1202,9 +1202,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings matrix of the model if `new_num_position_embeddings != @@ -1351,12 +1348,6 @@ def __init__(self, config: PegasusXConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings matrix of the model if `new_num_position_embeddings != diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 0a601deac183..6e6724288b06 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -180,12 +180,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -340,12 +334,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index 2b50b8242202..823163e23749 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -424,15 +424,6 @@ def forward( def get_image_features(self, **kwargs): raise AttributeError("Not needed for PerceptionLM") - def language_model(self): - raise AttributeError("Not needed for PerceptionLM") - - def vision_tower(self): - raise AttributeError("Not needed for PerceptionLM") - - def multi_modal_projector(self): - raise AttributeError("Not needed for PerceptionLM") - __all__ = [ "PerceptionLMForConditionalGeneration", diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 8a58a353617f..79565625b6ff 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1343,9 +1343,6 @@ def get_output_embeddings(self) -> nn.Module: def set_output_embeddings(self, new_embeddings): self.decoder.set_output_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index e1b9f375c542..13f033b67fcc 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -851,9 +851,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def 
forward( self, @@ -972,12 +969,6 @@ def __init__(self, config: PLBartConfig): self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: @@ -1309,12 +1300,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 796f63d54632..e060fbb1d278 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -92,9 +92,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -213,12 +210,6 @@ def __init__(self, config: PLBartConfig): self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 493b9d4776ec..546e64af7550 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -984,9 +984,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - def get_mel_conditioner_outputs( self, input_features: torch.FloatTensor, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 619769b9d2b7..5aec96458d76 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1400,9 +1400,6 @@ def set_input_embeddings(self, value): self.encoder.word_embeddings = self.word_embeddings self.decoder.word_embeddings = self.word_embeddings - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1671,11 +1668,11 @@ def _compute_loss(self, logits, labels, ignore_index=-100): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def get_encoder(self): - return self.prophetnet.encoder - - def get_decoder(self): - return self.prophetnet.decoder + def get_encoder(self, modality=None): + if modality is None: + return self.prophetnet.encoder + else: + return super().get_encoder(modality=modality) @auto_docstring( @@ -1711,12 +1708,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.prophetnet.decoder.word_embeddings = value - def set_decoder(self, decoder): - self.prophetnet.decoder = decoder - - def get_decoder(self): - return self.prophetnet.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 
4c035c3144eb..b18e1e9f24dd 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -972,12 +972,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1390,12 +1384,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1404,15 +1392,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index bad9709e32c1..e058cceb1fa9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -943,12 +943,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1288,12 +1282,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1302,15 +1290,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index b2502161c0ea..d215f689da65 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -937,12 +937,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1318,12 +1312,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): 
self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1332,15 +1320,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @check_model_inputs() def forward( self, diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index efd0e8d24926..4d7b41f290a5 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1094,12 +1094,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -1528,12 +1522,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -1542,15 +1530,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def visual(self): - return self.model.visual - @check_model_inputs() def forward( self, diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index d91c8fd92d90..bfdf4dd75fae 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -1543,9 +1543,6 @@ def __init__(self, config: RTDetrConfig): self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for param in self.backbone.parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index 1a5a510ce11c..33cefeb1729c 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -1439,9 +1439,6 @@ def __init__(self, config: RTDetrV2Config): self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for param in self.backbone.parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py index fba702ecc342..7f317bb3b5f0 100644 --- a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py +++ 
b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -70,12 +70,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.model.language_model.set_output_embeddings(new_embeddings) - def set_decoder(self, decoder): - self.model.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.model.language_model.get_decoder() - @auto_docstring def forward( self, diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index c53034a66510..5a2babb00d7b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -148,9 +148,6 @@ def __init__( self.post_init() - def get_encoder(self): - return self.encoder - def get_input_embeddings(self): return self.decoder.get_input_embeddings() diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a0f017842559..e2078910afe5 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -884,9 +884,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.decoder.embed_tokens = value - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1022,12 +1019,6 @@ def __init__(self, config: Speech2TextConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @auto_docstring def forward( self, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 114c4d444124..6c7ce5b911ae 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1882,9 +1882,6 @@ def set_input_embeddings(self, value): if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): self.decoder.set_input_embeddings(value) - def get_encoder(self): - return self.encoder - def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -2021,12 +2018,6 @@ def __init__(self, config: SpeechT5Config): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.speecht5.get_encoder() - - def get_decoder(self): - return self.speecht5.get_decoder() - def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -2348,12 +2339,6 @@ def can_generate(cls) -> bool: # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase. 
return True - def get_encoder(self): - return self.speecht5.get_encoder() - - def get_decoder(self): - return self.speecht5.get_decoder() - @auto_docstring def forward( self, @@ -2685,12 +2670,6 @@ def __init__(self, config: SpeechT5Config): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.speecht5.get_encoder() - - def get_decoder(self): - return self.speecht5.get_decoder() - def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 8792f50d4e38..9867677e703a 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -940,9 +940,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -1098,9 +1095,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -1246,9 +1240,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @check_model_inputs() def forward( diff --git a/src/transformers/models/switch_transformers/modular_switch_transformers.py b/src/transformers/models/switch_transformers/modular_switch_transformers.py index 33009c21a15a..e5004f41d201 100644 --- a/src/transformers/models/switch_transformers/modular_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modular_switch_transformers.py @@ -696,9 +696,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -789,9 +786,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @can_return_tuple def forward( @@ -937,9 +931,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring @check_model_inputs() def forward( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index f3cb42b4bc23..051fd8a5e7d0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -863,9 +863,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1030,9 +1027,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def 
get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1206,9 +1200,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1513,9 +1504,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 5f70baa4812e..ac9b64929280 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -885,9 +885,6 @@ def __init__(self, config: T5GemmaConfig): self.post_init() - def get_encoder(self): - return self.encoder - def get_input_embeddings(self): return self.encoder.get_input_embeddings() @@ -1014,12 +1011,6 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self): return self.lm_head.out_proj - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index b3ed006f007e..3fa256a3c99f 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -892,9 +892,6 @@ def __init__(self, config: T5GemmaConfig): self.post_init() - def get_encoder(self): - return self.encoder - def get_input_embeddings(self): return self.encoder.get_input_embeddings() @@ -1021,12 +1018,6 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self): return self.lm_head.out_proj - def get_encoder(self): - return self.model.encoder - - def get_decoder(self): - return self.model.decoder - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 8c4777e74420..b788b2c66359 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -1023,9 +1023,6 @@ def __init__(self, config: TableTransformerConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.encoder - def freeze_backbone(self): for name, param in self.backbone.conv_encoder.model.named_parameters(): param.requires_grad_(False) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index ab4976f1121d..d2476457d7f9 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1057,9 +1057,6 @@ def create_network_inputs( return transformer_inputs, loc, scale, static_feat - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1297,12 +1294,6 @@ def __init__(self, config: TimeSeriesTransformerConfig): def output_params(self, dec_output): return self.parameter_projection(dec_output) - def get_encoder(self): - return 
self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - @torch.jit.ignore def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: sliced_params = params diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index e3a513f8687e..b96892271356 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -671,12 +671,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.output_projection = new_embeddings - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 4007d9b0926c..c20fae835ad5 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1455,9 +1455,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1635,9 +1632,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1815,9 +1809,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 69f6c77e06ae..0b7b03ec1c08 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -950,10 +950,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1134,10 +1130,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, @@ -1330,10 +1322,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with T5->UMT5, google-t5/t5-small->google/umt5-small, t5#training->umt5#training def forward( @@ -1645,10 +1633,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder - def get_encoder(self): - return self.encoder - @auto_docstring def forward( self, diff --git 
a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 0b6d0277ad4a..8442a4a7c175 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -537,12 +537,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_video_features( self, pixel_values_videos: torch.FloatTensor, @@ -761,12 +755,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): @@ -775,11 +763,6 @@ def get_video_features( def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): return self.model.get_image_features(pixel_values, image_grid_thw) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - @can_return_tuple @auto_docstring def forward( @@ -1103,10 +1086,6 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs - @property - def vision_model(self): - return self.model.vision_model - __all__ = [ "VideoLlama3VisionModel", diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index cfcaf0dec655..0bd05ec9acf8 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -750,13 +750,6 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, config: VideoLlama3Config): super().__init__(config) # just to add type hint on config - def visual(self): - raise AttributeError("Not needed for VideoLLaMA3") - - @property - def vision_model(self): - return self.model.vision_model - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 559c30ef1f65..29184ca8a165 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -172,12 +172,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values_images: torch.FloatTensor, @@ -432,12 +426,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values_images: torch.FloatTensor, @@ -450,23 +438,6 @@ def get_image_features( vision_feature_select_strategy=vision_feature_select_strategy, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - 
return self.model.language_model - - @property - def video_tower(self): - return self.model.video_tower - - @property - def image_tower(self): - return self.model.image_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 02fdf4c4638c..1f1d5b04b5a5 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -150,12 +150,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None ): @@ -310,30 +304,11 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None ): return self.model.get_image_features(pixel_values=pixel_values, vision_feature_layers=vision_feature_layers) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3c64e1848df..87fb80dd2aa8 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -147,9 +147,6 @@ def __init__( self.post_init() - def get_encoder(self): - return self.encoder - def get_input_embeddings(self): return self.decoder.get_input_embeddings() diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 22eabb6d9299..1f1b9a96a94e 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1265,9 +1265,6 @@ def __init__(self, config: VitsConfig): # Initialize weights and apply final processing self.post_init() - def get_encoder(self): - return self.text_encoder - @auto_docstring def forward( self, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index cff7e10b4b2f..c143e7a2ec8f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -916,9 +916,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.decoder.embed_tokens = value - def get_encoder(self): - return self.encoder - def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will @@ -1099,12 +1096,6 @@ def __init__(self, config: WhisperConfig): # Initialize weights and 
apply final processing self.post_init() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - def get_output_embeddings(self): return self.proj_out @@ -1294,12 +1285,6 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - @auto_docstring def forward( self, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 036fea5ee3d1..a7a40d887787 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3030,7 +3030,7 @@ def test_sdpa_can_dispatch_composite_models(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_sdpa = model_class.from_pretrained(tmpdirname) - model_sdpa = model_sdpa.eval().to(torch_device) + model_sdpa = model_sdpa.base_model vision_model_names = {"visual", "image_tower", "vision_tower", "vision_model"} language_model_names = {"language_model", "model", "text_model"} @@ -3048,7 +3048,7 @@ def test_sdpa_can_dispatch_composite_models(self): self.assertTrue(vision_model_sdpa.config._attn_implementation == vision_attn) model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") - model_eager = model_eager.eval().to(torch_device) + model_eager = model_eager.base_model self.assertTrue(getattr(model_eager, language_model_name).config._attn_implementation == "eager") self.assertTrue(getattr(model_eager, vision_model_name).config._attn_implementation == "eager") diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 9c0be2abb8ac..cb6111722a33 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -40,6 +40,7 @@ AutoModelForSequenceClassification, BartConfig, BartForConditionalGeneration, + BartModel, CLIPTextModelWithProjection, DynamicCache, GPT2Config, @@ -109,6 +110,8 @@ BertModel, CLIPTextModel, GenerationMixin, + MusicgenConfig, + MusicgenForConditionalGeneration, PreTrainedModel, T5Config, T5ForConditionalGeneration, @@ -578,37 +581,37 @@ def test_model_from_config_dtype_composite(self): """ # Load without dtype specified model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA) - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.float32) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.float32) self.assertIsInstance(model.config.dtype, torch.dtype) # should be able to set dtype as a simple string and the model loads it correctly model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, dtype="float32") - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.float32) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.float32) self.assertIsInstance(model.config.dtype, torch.dtype) model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, dtype=torch.float16) - self.assertEqual(model.language_model.dtype, torch.float16) - self.assertEqual(model.vision_tower.dtype, torch.float16) + self.assertEqual(model.model.language_model.dtype, torch.float16) + self.assertEqual(model.model.vision_tower.dtype, torch.float16) self.assertIsInstance(model.config.dtype, 
torch.dtype) # should be able to set dtype as a dict for each sub-config model = LlavaForConditionalGeneration.from_pretrained( TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"} ) - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.float16) - self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.float16) + self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) self.assertIsInstance(model.config.dtype, torch.dtype) # should be able to set the values as torch.dtype (not str) model = LlavaForConditionalGeneration.from_pretrained( TINY_LLAVA, dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16} ) - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.float16) - self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.float16) + self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) self.assertIsInstance(model.config.dtype, torch.dtype) # should be able to set the values in configs directly and pass it to `from_pretrained` @@ -617,17 +620,17 @@ def test_model_from_config_dtype_composite(self): config.vision_config.dtype = torch.bfloat16 config.dtype = torch.float16 model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto") - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.bfloat16) - self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16) + self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float16) self.assertIsInstance(model.config.dtype, torch.dtype) # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto") - self.assertEqual(model.language_model.dtype, torch.float32) - self.assertEqual(model.vision_tower.dtype, torch.bfloat16) - self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) + self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16) + self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float32) self.assertIsInstance(model.config.dtype, torch.dtype) # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type @@ -3131,7 +3134,7 @@ def test_nested_wrapper_recursion(self): model = GPT2LMHeadModel(cfg) dec = model.get_decoder() - assert dec is model, f"GPT2 get_decoder() should return self (fallback), got {type(dec)}" + assert dec is model.transformer, f"GPT2 get_decoder() should return model.transformer, got {type(dec)}" def test_model_without_get_decoder(self): """Test edge case where model has model attribute but no get_decoder method.""" @@ -3191,4 +3194,211 @@ def
test_vision_language_model(self): model = LlavaForConditionalGeneration(cfg) dec = model.get_decoder() - assert dec is model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}" + assert dec is model.model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}" + + +class TestGetEncoder(unittest.TestCase): + def test_seq2seq_lm_get_encoder_returns_encoder(self): + cfg = BartConfig( + vocab_size=128, + d_model=32, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=64, + decoder_ffn_dim=64, + ) + model = BartForConditionalGeneration(cfg) + encoder = model.get_encoder() + + assert encoder is model.model.encoder, ( + f"Expected get_encoder() to return model.model.encoder, got {type(encoder)}" + ) + + def test_base_model_returns_encoder(self): + cfg = BartConfig( + vocab_size=128, + d_model=32, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=64, + decoder_ffn_dim=64, + ) + model = BartModel(cfg) + encoder = model.get_encoder() + + assert encoder is model.encoder, f"Expected get_encoder() to return model.encoder, got {type(encoder)}" + + def test_decoder_only_model_returns_self(self): + """Test that decoder-only models (no encoder) return self.""" + cfg = MistralConfig( + vocab_size=128, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=4, + ) + model = MistralModel(cfg) + encoder = model.get_encoder() + + assert encoder is model, f"Base model get_encoder() should return self, got {type(encoder)}" + + def test_when_encoder_has_different_name(self): + """Test models with non-standard name for encoder modular (Musicgen has `self.model.text_encoder`).""" + cfg = MusicgenConfig( + text_encoder={ + "model_type": "t5", + "vocab_size": 99, + "d_model": 32, + "d_ff": 37, + "num_layers": 2, + "num_heads": 2, + }, + audio_encoder={ + "model_type": "encodec", + "hidden_size": 99, + "compress": 1, + "num_filters": 2, + "codebook_size": 32, + "codebook_dim": 32, + }, + decoder={ + "vocab_size": 99, + "ffn_dim": 32, + "num_attention_heads": 2, + "hidden_size": 32, + "num_hidden_layers": 2, + }, + ) + model = MusicgenForConditionalGeneration(cfg) + encoder = model.get_encoder() + + assert encoder is model.text_encoder, ( + f"MusicgenForConditionalGeneration get_encoder() should return model.model.text_encoder, got {type(encoder)}" + ) + + def test_audio_encoder(self): + """Test models with multiple modality encoders (Musicgen has `self.model.audio_encoder`).""" + cfg = MusicgenConfig( + text_encoder={ + "model_type": "t5", + "vocab_size": 99, + "d_model": 32, + "d_ff": 37, + "num_layers": 2, + "num_heads": 2, + }, + audio_encoder={ + "model_type": "encodec", + "hidden_size": 99, + "compress": 1, + "num_filters": 2, + "codebook_size": 32, + "codebook_dim": 32, + }, + decoder={ + "vocab_size": 99, + "ffn_dim": 32, + "num_attention_heads": 2, + "hidden_size": 32, + "num_hidden_layers": 2, + }, + ) + model = MusicgenForConditionalGeneration(cfg) + encoder = model.get_encoder(modality="audio") + + assert encoder is model.audio_encoder, ( + f"MusicgenForConditionalGeneration get_encoder(modality='audio') should return model.model.audio_encoder, got {type(encoder)}" + ) + + def test_non_existant_modality_throws_error(self): + """Test that an error is thrown when a rquested modality does not exist.""" + cfg = MistralConfig( + vocab_size=128, + hidden_size=32, + intermediate_size=64, + 
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralModel(cfg)
+        with self.assertRaises(ValueError):
+            _ = model.get_encoder(modality="3d")
+
+    def test_encoder_return_self_when_modality_not_found(self):
+        """Test that `self` is returned if the model has no encoder for the requested modality."""
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralModel(cfg)
+        encoder = model.get_encoder(modality="image")
+
+        assert encoder is model, f"Mistral get_encoder(modality='image') should return self, got {type(encoder)}"
+
+    def test_model_without_get_encoder(self):
+        """Test edge case where model has model attribute but no get_encoder method."""
+
+        class MockInnerModel:
+            """Mock model without get_encoder method."""
+
+            pass
+
+        class MockWrapperModel:
+            """Mock wrapper with model attribute but inner has no get_encoder."""
+
+            def __init__(self):
+                self.model = MockInnerModel()
+
+            def get_encoder(self):
+                if hasattr(self, "encoder"):
+                    return self.encoder
+                if hasattr(self, "model"):
+                    inner = self.model
+                    if hasattr(inner, "get_encoder") and type(inner) is not type(self):
+                        return inner.get_encoder()
+                    return inner
+                return self
+
+        wrapper = MockWrapperModel()
+        encoder = wrapper.get_encoder()
+
+        assert encoder is wrapper.model, f"Should return inner model when no get_encoder, got {type(encoder)}"
+
+    def test_vision_language_model(self):
+        """Test that vision-language models like LLaVA can find the modality encoder ("image")."""
+        text_config = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+
+        vision_config = {
+            "hidden_size": 32,
+            "intermediate_size": 64,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 4,
+            "num_channels": 3,
+            "image_size": 224,
+            "patch_size": 16,
+        }
+
+        cfg = LlavaConfig(
+            text_config=text_config.to_dict(),
+            vision_config=vision_config,
+            vocab_size=128,
+        )
+
+        model = LlavaForConditionalGeneration(cfg)
+        image_encoder = model.get_encoder(modality="image")
+
+        assert image_encoder is model.model.vision_tower, (
+            f"LLaVA get_encoder(modality='image') should return vision_tower, got {type(image_encoder)}"
+        )
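
For reference, a minimal sketch (not part of the patch) of the behaviour the new tests pin down, assuming a checkout that includes the `get_encoder` change and reusing the tiny Mistral config from the tests above; the sizes are illustrative only:

    from transformers import MistralConfig, MistralModel

    # Decoder-only model: there is no encoder submodule, so the best-effort
    # lookup falls through every candidate attribute name and returns the model itself.
    cfg = MistralConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    model = MistralModel(cfg)

    assert model.get_encoder() is model                  # no text encoder -> self
    assert model.get_encoder(modality="image") is model  # no vision tower -> self

    # An unrecognized modality string raises a ValueError, mirroring the
    # "non-existent modality" test above.
    try:
        model.get_encoder(modality="3d")
    except ValueError:
        pass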