Skip to content

Commit e2fb8d6

Browse files
authored
🚨 Generalize get_decoder() for multimodal and delete redundant code 🔪 (#42156)
* update some models * update the rest * add helper for encoder * delete encoder code from models * fix copies * fix some tests but VLM will fail * add encoder tests similar to decoder * no print * fix overwritten models * and a million exceptions with old audio models, revert back
1 parent a5c903f commit e2fb8d6

File tree

101 files changed

+346
-1330
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+346
-1330
lines changed

‎src/transformers/modeling_utils.py‎

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,26 +2205,78 @@ def disable_input_require_grads(self):
22052205
"""
22062206
self._require_grads_hook.remove()
22072207

2208+
def get_encoder(self, modality: Optional[str] = None):
    """
    Best-effort lookup of the *encoder* module.

    If provided with a `modality` argument, it looks for a modality-specific encoder
    in multimodal models (e.g. "image_encoder"). By default the function returns the
    model's text encoder if any, and otherwise returns `self`.

    Args:
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"`. `None` looks up the text encoder.

    Raises:
        ValueError: If `modality` is not `None` and not one of the recognized values.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    elif modality == "audio":
        possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        # Typo fixed: "Unnrecognized" -> "Unrecognized"
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    # Recurse into the wrapped base model (e.g. task-specific wrappers around a base model).
    if self.base_model is not self and hasattr(self.base_model, "get_encoder"):
        return self.base_model.get_encoder(modality=modality)

    # If this is a base transformer model (no encoder/model attributes), return self
    return self
2235+
2236+
def set_encoder(self, encoder, modality: Optional[str] = None):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_encoder`.

    Args:
        encoder (`nn.Module`):
            The module to install as the encoder.
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"`. `None` targets the text encoder.

    Raises:
        ValueError: If `modality` is not `None` and not one of the recognized values.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    # BUG FIX: this branch was `if`, which detached the image/video case from the chain
    # and made modality="image"/"video" fall into the final `else` and raise.
    elif modality == "audio":
        # "speech_encoder" added to stay symmetric with `get_encoder`.
        possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        # Typo fixed: "Unnrecognized" -> "Unrecognized"
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            setattr(self, name, encoder)
            return

    if self.base_model is not self:
        if hasattr(self.base_model, "set_encoder"):
            self.base_model.set_encoder(encoder, modality=modality)
        else:
            self.model = encoder
2261+
22082262
def get_decoder(self):
    """
    Best-effort lookup of the *decoder* module.

    Order of attempts (covers ~85 % of current usages):
    1. `self.decoder/self.language_model/self.text_model`
    2. `self.base_model` (many wrappers store the decoder here)
    3. `self.base_model.get_decoder()` (nested wrappers)
    4. fallback: return `self` (base models such as `MistralModel` are themselves
       the decoder)
    """
    # Docstring step 4 previously said "raise", contradicting the actual fallback below.
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    if self.base_model is not self and hasattr(self.base_model, "get_decoder"):
        return self.base_model.get_decoder()

    # If this is a base transformer model (no decoder/model attributes), return self
    # This handles cases like MistralModel which is itself the decoder
    return self
def set_decoder(self, decoder):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_decoder`.

    Args:
        decoder (`nn.Module`):
            The module to install as the decoder.
    """
    # "text_decoder" added so the lookup list matches `get_decoder`; otherwise models
    # exposing `text_decoder` would fall through to the base_model branch.
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            # BUG FIX: removed leftover debug `print(name)` statement.
            setattr(self, name, decoder)
            return

    if self.base_model is not self:
        if hasattr(self.base_model, "set_decoder"):
            self.base_model.set_decoder(decoder)
        else:
            self.model = decoder
22512302

22522303
@torch.no_grad()
22532304
def _init_weights(self, module):

‎src/transformers/models/aria/modeling_aria.py‎

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -910,12 +910,6 @@ def get_input_embeddings(self):
910910
def set_input_embeddings(self, value):
911911
self.language_model.set_input_embeddings(value)
912912

913-
def set_decoder(self, decoder):
914-
self.language_model = decoder
915-
916-
def get_decoder(self):
917-
return self.language_model
918-
919913
def get_image_features(
920914
self,
921915
pixel_values: torch.FloatTensor,
@@ -1075,12 +1069,6 @@ def set_input_embeddings(self, value):
10751069
def get_output_embeddings(self) -> nn.Module:
10761070
return self.lm_head
10771071

1078-
def set_decoder(self, decoder):
1079-
self.model.set_decoder(decoder)
1080-
1081-
def get_decoder(self):
1082-
return self.model.get_decoder()
1083-
10841072
def get_image_features(
10851073
self,
10861074
pixel_values: torch.FloatTensor,
@@ -1093,19 +1081,6 @@ def get_image_features(
10931081
vision_feature_layer=vision_feature_layer,
10941082
)
10951083

1096-
# Make modules available through conditional class for BC
1097-
@property
1098-
def language_model(self):
1099-
return self.model.language_model
1100-
1101-
@property
1102-
def vision_tower(self):
1103-
return self.model.vision_tower
1104-
1105-
@property
1106-
def multi_modal_projector(self):
1107-
return self.model.multi_modal_projector
1108-
11091084
@can_return_tuple
11101085
@auto_docstring
11111086
def forward(

‎src/transformers/models/autoformer/modeling_autoformer.py‎

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,9 +1342,6 @@ def create_network_inputs(
13421342
)
13431343
return reshaped_lagged_sequence, features, loc, scale, static_feat
13441344

1345-
def get_encoder(self):
1346-
return self.encoder
1347-
13481345
@auto_docstring
13491346
def forward(
13501347
self,
@@ -1588,12 +1585,6 @@ def __init__(self, config: AutoformerConfig):
15881585
def output_params(self, decoder_output):
15891586
return self.parameter_projection(decoder_output[:, -self.config.prediction_length :, :])
15901587

1591-
def get_encoder(self):
1592-
return self.model.get_encoder()
1593-
1594-
def get_decoder(self):
1595-
return self.model.get_decoder()
1596-
15971588
@torch.jit.ignore
15981589
def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution:
15991590
sliced_params = params

‎src/transformers/models/aya_vision/modeling_aya_vision.py‎

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,6 @@ def get_input_embeddings(self):
181181
def set_input_embeddings(self, value):
182182
self.language_model.set_input_embeddings(value)
183183

184-
def set_decoder(self, decoder):
185-
self.language_model = decoder
186-
187-
def get_decoder(self):
188-
return self.language_model
189-
190184
def get_image_features(
191185
self,
192186
pixel_values: torch.FloatTensor,
@@ -357,12 +351,6 @@ def set_input_embeddings(self, value):
357351
def get_output_embeddings(self) -> nn.Module:
358352
return self.lm_head
359353

360-
def set_decoder(self, decoder):
361-
self.model.set_decoder(decoder)
362-
363-
def get_decoder(self):
364-
return self.model.get_decoder()
365-
366354
def get_image_features(
367355
self,
368356
pixel_values: torch.FloatTensor,
@@ -377,19 +365,6 @@ def get_image_features(
377365
**kwargs,
378366
)
379367

380-
# Make modules available through conditional class for BC
381-
@property
382-
def language_model(self):
383-
return self.model.language_model
384-
385-
@property
386-
def vision_tower(self):
387-
return self.model.vision_tower
388-
389-
@property
390-
def multi_modal_projector(self):
391-
return self.model.multi_modal_projector
392-
393368
@can_return_tuple
394369
@auto_docstring
395370
def forward(

‎src/transformers/models/bart/modeling_bart.py‎

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -905,9 +905,6 @@ def set_input_embeddings(self, value):
905905
self.encoder.embed_tokens = self.shared
906906
self.decoder.embed_tokens = self.shared
907907

908-
def get_encoder(self):
909-
return self.encoder
910-
911908
@auto_docstring
912909
def forward(
913910
self,
@@ -1037,12 +1034,6 @@ def __init__(self, config: BartConfig):
10371034
# Initialize weights and apply final processing
10381035
self.post_init()
10391036

1040-
def get_encoder(self):
1041-
return self.model.get_encoder()
1042-
1043-
def get_decoder(self):
1044-
return self.model.get_decoder()
1045-
10461037
def resize_token_embeddings(
10471038
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
10481039
) -> nn.Embedding:
@@ -1498,12 +1489,6 @@ def get_input_embeddings(self):
14981489
def set_input_embeddings(self, value):
14991490
self.model.decoder.embed_tokens = value
15001491

1501-
def set_decoder(self, decoder):
1502-
self.model.decoder = decoder
1503-
1504-
def get_decoder(self):
1505-
return self.model.decoder
1506-
15071492
@auto_docstring
15081493
def forward(
15091494
self,

‎src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py‎

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,9 +2083,6 @@ def set_input_embeddings(self, value):
20832083
self.encoder.embed_tokens = self.shared
20842084
self.decoder.embed_tokens = self.shared
20852085

2086-
def get_encoder(self):
2087-
return self.encoder
2088-
20892086
@auto_docstring
20902087
def forward(
20912088
self,
@@ -2205,12 +2202,6 @@ def __init__(self, config: BigBirdPegasusConfig):
22052202
# Initialize weights and apply final processing
22062203
self.post_init()
22072204

2208-
def get_encoder(self):
2209-
return self.model.get_encoder()
2210-
2211-
def get_decoder(self):
2212-
return self.model.get_decoder()
2213-
22142205
def resize_token_embeddings(
22152206
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
22162207
) -> nn.Embedding:
@@ -2609,12 +2600,6 @@ def get_input_embeddings(self):
26092600
def set_input_embeddings(self, value):
26102601
self.model.decoder.embed_tokens = value
26112602

2612-
def set_decoder(self, decoder):
2613-
self.model.decoder = decoder
2614-
2615-
def get_decoder(self):
2616-
return self.model.decoder
2617-
26182603
@auto_docstring
26192604
def forward(
26202605
self,

‎src/transformers/models/blenderbot/modeling_blenderbot.py‎

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -869,9 +869,6 @@ def set_input_embeddings(self, value):
869869
self.encoder.embed_tokens = self.shared
870870
self.decoder.embed_tokens = self.shared
871871

872-
def get_encoder(self):
873-
return self.encoder
874-
875872
@auto_docstring
876873
def forward(
877874
self,
@@ -1009,12 +1006,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10091006

10101007
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
10111008

1012-
def get_encoder(self):
1013-
return self.model.get_encoder()
1014-
1015-
def get_decoder(self):
1016-
return self.model.get_decoder()
1017-
10181009
def resize_token_embeddings(
10191010
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
10201011
) -> nn.Embedding:
@@ -1189,12 +1180,6 @@ def get_input_embeddings(self):
11891180
def set_input_embeddings(self, value):
11901181
self.model.decoder.embed_tokens = value
11911182

1192-
def set_decoder(self, decoder):
1193-
self.model.decoder = decoder
1194-
1195-
def get_decoder(self):
1196-
return self.model.decoder
1197-
11981183
@auto_docstring
11991184
def forward(
12001185
self,

‎src/transformers/models/blenderbot_small/modeling_blenderbot_small.py‎

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,6 @@ def set_input_embeddings(self, value):
842842
self.encoder.embed_tokens = self.shared
843843
self.decoder.embed_tokens = self.shared
844844

845-
def get_encoder(self):
846-
return self.encoder
847-
848845
@auto_docstring
849846
def forward(
850847
self,
@@ -969,12 +966,6 @@ def __init__(self, config: BlenderbotSmallConfig):
969966
# Initialize weights and apply final processing
970967
self.post_init()
971968

972-
def get_encoder(self):
973-
return self.model.get_encoder()
974-
975-
def get_decoder(self):
976-
return self.model.get_decoder()
977-
978969
def resize_token_embeddings(
979970
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
980971
) -> nn.Embedding:
@@ -1149,12 +1140,6 @@ def get_input_embeddings(self):
11491140
def set_input_embeddings(self, value):
11501141
self.model.decoder.embed_tokens = value
11511142

1152-
def set_decoder(self, decoder):
1153-
self.model.decoder = decoder
1154-
1155-
def get_decoder(self):
1156-
return self.model.decoder
1157-
11581143
@auto_docstring
11591144
def forward(
11601145
self,

0 commit comments

Comments
 (0)