System Info
- `transformers` version: 4.53.2
- Platform: Ubuntu 22.04, Linux 5.15.0-139-generic
- Python 3.10.18 + ipykernel 6.29.5
- PyTorch 2.7.1+cu118
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I want to build a new MT model with a BERT-based encoder and a decoder taken from `opus-mt-en-zh` (loaded as `MarianMTModel`), but when I execute `Trainer.train()` it raises `ValueError: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time`. The code for my model and trainer is below. Thanks for helping!
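The base objects `manchu_model`, `chn_model`, and `manchu_tok` used below are created in an earlier notebook cell, roughly like this (the ManchuBERT checkpoint path is a placeholder):

```python
# Rough sketch of the earlier setup; the ManchuBERT path is a placeholder.
from transformers import AutoTokenizer, BertModel, MarianMTModel

manchu_tok = AutoTokenizer.from_pretrained("path/to/manchu-bert")        # ManchuBERT tokenizer (placeholder path)
manchu_model = BertModel.from_pretrained("path/to/manchu-bert")          # BERT-based encoder
chn_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-zh")  # source of the decoder + lm_head
```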
```python
# ManchuBERT Encoder + Opus-MT-zh Decoder
import torch
from torch import nn
from transformers.modeling_outputs import Seq2SeqLMOutput


def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32):
    """
    attention_mask: [B, seq_len]
    return: [B, 1, 1, seq_len]
    """
    mask = attention_mask[:, None, None, :]  # [B, 1, 1, seq_len]
    mask = mask.to(dtype=dtype)
    mask = (1.0 - mask) * -10000.0
    return mask


class ManchuZhMT(nn.Module):
    def __init__(self, bert, marian):
        super().__init__()
        self.decoder_embeddings = marian.model.decoder.embed_tokens
        self.embeddings = bert.embeddings
        self.encoder = bert.encoder
        self.decoder = marian.model.decoder
        self.lm_head = marian.lm_head
        self.final_logits_bias = marian.final_logits_bias
        self.config = marian.config

    def forward(self,
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None,
                **kwargs):
        hidden_states = self.embeddings(input_ids=input_ids)
        attention_mask = attention_mask.to(dtype=torch.float32)
        extended_mask = get_extended_attention_mask(attention_mask, input_ids.shape, input_ids.device)
        enc_out = self.encoder(hidden_states=hidden_states,
                               attention_mask=extended_mask,
                               return_dict=True)
        dec_out = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=enc_out.last_hidden_state,
            encoder_attention_mask=extended_mask,
            return_dict=True)
        logits = self.lm_head(dec_out.last_hidden_state) + self.final_logits_bias
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return Seq2SeqLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.decoder.prepare_inputs_for_generation(*args, **kwargs)

    def _prepare_encoder_decoder_kwargs_for_generation(self, *args, **kwargs):
        return self.decoder._prepare_encoder_decoder_kwargs_for_generation(*args, **kwargs)
```
```python
model = ManchuZhMT(manchu_model, chn_model)
print(model)

# freeze Decoder + LM Head
for p in model.decoder.parameters():
    p.requires_grad = False
for p in model.lm_head.parameters():
    p.requires_grad = False

# Add LoRA for Encoder
from peft import LoraConfig, get_peft_model, TaskType

num_layers = len(model.encoder.layer)
target_modules = []
for i in range(num_layers):
    target_modules.extend([
        f"encoder.layer.{i}.attention.self.query",
        f"encoder.layer.{i}.attention.self.key",
        f"encoder.layer.{i}.attention.self.value",
        f"encoder.layer.{i}.attention.output.dense",
        f"encoder.layer.{i}.intermediate.dense",
        f"encoder.layer.{i}.output.dense",
    ])

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```
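The data loading and tokenization also live in earlier cells; `batch_size`, `tokenized_ds`, and the target-side tokenizer come from a preprocessing step roughly like this (file and column names are placeholders):

```python
# Simplified sketch of the data preprocessing; file/column names are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

raw_ds = load_dataset("json", data_files={"train": "train.json", "val": "val.json"})
chn_tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
batch_size = 32

def preprocess(batch):
    # Source side (Manchu) goes through the ManchuBERT tokenizer,
    # target side (Chinese) through the opus-mt tokenizer.
    model_inputs = manchu_tok(batch["manchu"], max_length=128, truncation=True)
    labels = chn_tok(batch["chinese"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_ds = raw_ds.map(preprocess, batched=True, remove_columns=raw_ds["train"].column_names)
```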
```python
# Start Train!
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="./lora_with_bert",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    learning_rate=3e-4,
    fp16=True,
    save_strategy="epoch",
    predict_with_generate=True,
    logging_steps=100,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["val"],
    tokenizer=manchu_tok,
)
trainer.train()
trainer.save_model("./lora_with_bert/final")
```
Expected behavior
I expected the script to train normally, just as it does when using `opus-mt-en-zh` alone, and to produce the LoRA checkpoint.
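By "using `opus-mt-en-zh` alone" I mean a baseline roughly like the following, with LoRA applied directly to the `MarianMTModel`; that setup trained without this error (module names and hyperparameters here are illustrative):

```python
# Baseline sketch (illustrative): LoRA directly on opus-mt-en-zh, no custom encoder.
from transformers import MarianMTModel
from peft import LoraConfig, get_peft_model, TaskType

base = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
baseline_lora = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q_proj", "v_proj"],  # Marian attention projections
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
baseline = get_peft_model(base, baseline_lora)
# Trained with the same Seq2SeqTrainer / Seq2SeqTrainingArguments as above;
# this path does not hit the decoder_input_ids / decoder_inputs_embeds error.
```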