
Commit 2d128f1

FIX Beam search w/ mixed adapter batches & encoder (#2921)
When using mixed adapter batches (i.e. using different LoRA adapters in the same batch), users have to pass adapter_names. When beam search is used at the same time, these adapter names have to be extended by the number of beams. For encoder-decoder models, however, the encoder part of the model should not use the extended adapter_names even when beam search is applied, because the encoder still operates on the original, non-extended samples.

Whether this special handling is needed used to be determined by checking for model.get_encoder(). With transformers v5, however, every PretrainedModel will have a get_encoder method; the new convention is that it returns self if there is no encoder. This is what is now being checked: huggingface/transformers#42156

Note that said PR contains a small bug that leads to self not always being returned. Therefore, for the full fix of the issue on transformers main, we also need to wait for this PR: huggingface/transformers#42295
1 parent 64f9582 commit 2d128f1
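
For context, the scenario that this fix addresses can be sketched as follows. This is an illustrative example only, not code from the commit; the checkpoint name and adapter names are placeholders.

    # Illustrative sketch only: checkpoint and adapter names are placeholders.
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model

    model_id = "google/flan-t5-small"  # any encoder-decoder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    base = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    # Two LoRA adapters on the same base model.
    model = get_peft_model(base, LoraConfig(task_type="SEQ_2_SEQ_LM"), adapter_name="adapter_a")
    model.add_adapter("adapter_b", LoraConfig(task_type="SEQ_2_SEQ_LM"))

    inputs = tokenizer(
        ["translate English to German: Hello.", "summarize: A long article about PEFT."],
        return_tensors="pt",
        padding=True,
    )

    # One adapter name per sample in the batch. With num_beams > 1, PEFT extends this
    # list by the number of beams for the decoder, while the encoder must keep the
    # original, non-extended names because it only sees the original samples.
    with torch.no_grad():
        output = model.generate(**inputs, adapter_names=["adapter_a", "adapter_b"], num_beams=4)
    print(tokenizer.batch_decode(output, skip_special_tokens=True))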

File tree

1 file changed: +16 -2 lines changed


src/peft/tuners/lora/model.py

Lines changed: 16 additions & 2 deletions
@@ -66,6 +66,19 @@ def _alora_offsets_pre_forward_hook(target, args, kwargs, alora_offsets):
     return args, kwargs
 
 
+def _get_encoder(model: nn.Module) -> nn.Module | None:
+    """Check if the model has an encoder and if it has, returns it; otherwise returns None"""
+    if not hasattr(model, "get_encoder"):
+        return None
+
+    encoder = model.get_encoder()
+    # https://github.com/huggingface/transformers/pull/42156
+    # new logic in transformers v5: all PretrainedModels return a model here, but it is self if there is no encoder
+    if encoder is model:
+        return None
+    return encoder
+
+
 class LoraModel(BaseTuner):
     """
     Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
@@ -438,10 +451,11 @@ def backward_hook(name, module, *grad_output, **kwargs):
                 handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                 hook_handles.append(handle)
 
-        if uses_beam_search and hasattr(self.model, "get_encoder"):
+        encoder = _get_encoder(self.model)
+        if uses_beam_search and (encoder is not None):
             # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
             # the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
-            for module in self.model.get_encoder().modules():
+            for module in encoder.modules():
                 if isinstance(module, LoraLayer) or isinstance(module, AuxiliaryTrainingWrapper):
                     # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
                     # trying to exclude the encoder.
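
The convention that the new _get_encoder helper relies on can be illustrated with a small, self-contained sketch; the toy modules below are made up for illustration and are not part of PEFT or transformers.

    # Toy illustration of the behavior of the private _get_encoder helper added here:
    # under the transformers v5 convention, get_encoder() always exists and returns the
    # model itself when there is no separate encoder.
    import torch.nn as nn

    from peft.tuners.lora.model import _get_encoder


    class DecoderOnlyToy(nn.Module):
        def get_encoder(self):
            # v5-style behavior for a model without an encoder
            return self


    class EncoderDecoderToy(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)

        def get_encoder(self):
            return self.encoder


    print(_get_encoder(DecoderOnlyToy()))     # None: treated as "no encoder"
    print(_get_encoder(EncoderDecoderToy()))  # Linear(...): the real encoder is returned
    print(_get_encoder(nn.Linear(4, 4)))      # None: no get_encoder method at all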

0 commit comments