 from transformers import TrainingArguments
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 
-
 from tevatron.reranker.arguments import ModelArguments
 
 import logging
@@ -22,6 +21,7 @@ class RerankerOutput(ModelOutput):
     loss: Optional[Tensor] = None
     scores: Optional[Tensor] = None
 
+
 class RerankerModel(nn.Module):
     TRANSFORMER_CLS = AutoModelForSequenceClassification
 
@@ -49,17 +49,18 @@ def forward(self, pair: Dict[str, Tensor] = None):
             grouped_logits = ranker_logits.view(self.train_batch_size, -1)
             loss = self.cross_entropy(grouped_logits, self.target_label)
             return RerankerOutput(
-                loss = loss,
-                scores = ranker_logits
+                loss = loss,
+                scores = ranker_logits
             )
 
         return RerankerOutput(
-            loss = None,
-            scores = ranker_logits
+            loss = None,
+            scores = ranker_logits
         )
-
+
     def gradient_checkpointing_enable(self, **kwargs):
-        self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
+        return False
+        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
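For context, the training branch above computes a listwise cross-entropy over each query's group of candidate documents: the per-pair logits are reshaped into one row per query and scored against a target index. Below is a minimal standalone sketch of that computation; the zero targets assume the positive document sits at index 0 of each group (as in Tevatron's training data layout), and the batch/group sizes are illustrative only.

import torch
import torch.nn as nn

# Sketch of the grouped loss in forward() above (illustrative sizes).
train_batch_size, group_size = 4, 8
ranker_logits = torch.randn(train_batch_size * group_size, 1)   # one score per (query, doc) pair

grouped_logits = ranker_logits.view(train_batch_size, -1)       # (batch, group) score matrix
target_label = torch.zeros(train_batch_size, dtype=torch.long)  # assumes the positive doc is first in each group
loss = nn.CrossEntropyLoss(reduction='mean')(grouped_logits, target_label)
print(loss.item())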
@@ -79,7 +80,9 @@ def build(
                 base_model.enable_input_require_grads()
             if model_args.lora_name_or_path:
                 lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
-                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
+                                                       torch_dtype=torch.bfloat16,
+                                                       attn_implementation="flash_attention_2")
             else:
                 lora_config = LoraConfig(
                     base_model_name_or_path=model_args.model_name_or_path,
@@ -107,7 +110,9 @@ def load(cls,
              model_name_or_path: str,
              lora_name_or_path: str = None,
              **hf_kwargs):
-        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
+                                                         torch_dtype=torch.bfloat16,
+                                                         attn_implementation="flash_attention_2")
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
         if lora_name_or_path:
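Taken together, the load() change bakes bfloat16 and FlashAttention-2 into model loading instead of leaving them to the caller. A hedged usage sketch follows, assuming the class lives in tevatron.reranker.modeling as in upstream Tevatron and using placeholder checkpoint paths (FlashAttention-2 additionally requires a compatible GPU and the flash-attn package):

from tevatron.reranker.modeling import RerankerModel

# Placeholder paths; dtype and attention implementation are now set inside load().
model = RerankerModel.load(
    "path/to/base-reranker-checkpoint",
    lora_name_or_path="path/to/lora-adapter",  # optional; omit to load the full model
)
model.eval()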