
Commit ca4b04b

hotfix: gradient_checkpointing_enable
1 parent: bfb8e5a, commit: ca4b04b

1 file changed, 14 insertions(+), 9 deletions(-)

src/tevatron/reranker/modeling.py

Lines changed: 14 additions & 9 deletions
@@ -9,7 +9,6 @@
 from transformers import TrainingArguments
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 
-
 from tevatron.reranker.arguments import ModelArguments
 
 import logging
@@ -22,6 +21,7 @@ class RerankerOutput(ModelOutput):
     loss: Optional[Tensor] = None
     scores: Optional[Tensor] = None
 
+
 class RerankerModel(nn.Module):
     TRANSFORMER_CLS = AutoModelForSequenceClassification
 
@@ -49,17 +49,18 @@ def forward(self, pair: Dict[str, Tensor] = None):
             grouped_logits = ranker_logits.view(self.train_batch_size, -1)
             loss = self.cross_entropy(grouped_logits, self.target_label)
             return RerankerOutput(
-                loss = loss,
-                scores = ranker_logits
+                loss=loss,
+                scores=ranker_logits
             )
 
         return RerankerOutput(
-            loss = None,
-            scores = ranker_logits
+            loss=None,
+            scores=ranker_logits
         )
-
+
     def gradient_checkpointing_enable(self, **kwargs):
-        self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
+        return False
+        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
@@ -79,7 +80,9 @@ def build(
                 base_model.enable_input_require_grads()
             if model_args.lora_name_or_path:
                 lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
-                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
+                                                       torch_dtype=torch.bfloat16,
+                                                       attn_implementation="flash_attention_2")
             else:
                 lora_config = LoraConfig(
                     base_model_name_or_path=model_args.model_name_or_path,
@@ -107,7 +110,9 @@ def load(cls,
              model_name_or_path: str,
              lora_name_or_path: str = None,
              **hf_kwargs):
-        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
+                                                         torch_dtype=torch.bfloat16,
+                                                         attn_implementation="flash_attention_2")
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
         if lora_name_or_path:
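
Context for the hotfixed method: when TrainingArguments.gradient_checkpointing is enabled, the Hugging Face Trainer calls gradient_checkpointing_enable() on whatever model object it was handed, so a wrapper module like RerankerModel has to expose that name. This commit stubs the override out (it now returns False) rather than forwarding to self.hf_model.base_model.model. Below is a minimal sketch of the forwarding variant, assuming the wrapped hf_model is an ordinary transformers PreTrainedModel; the wrapper class name and the example checkpoint are placeholders for illustration, not the repository's code.

import torch.nn as nn
from transformers import AutoModelForSequenceClassification


class CheckpointingWrapper(nn.Module):
    """Illustrative wrapper that forwards gradient checkpointing to the wrapped HF model."""

    def __init__(self, hf_model: nn.Module):
        super().__init__()
        self.hf_model = hf_model

    def gradient_checkpointing_enable(self, **kwargs):
        # PreTrainedModel.gradient_checkpointing_enable accepts
        # gradient_checkpointing_kwargs; forwarding **kwargs keeps the
        # Trainer's call signature intact.
        self.hf_model.gradient_checkpointing_enable(**kwargs)


# Hypothetical usage (placeholder checkpoint):
base = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=1)
model = CheckpointingWrapper(base)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})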
