@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import torch
 from torch import nn, Tensor
@@ -43,8 +43,16 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
                 param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
                 logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
 
-    def forward(self, pair: Dict[str, Tensor] = None):
-        ranker_logits = self.hf_model(**pair, return_dict=True).logits
+    def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, token_type_ids: Tensor = None, **kwargs):
+        model_inputs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+        }
+        if token_type_ids is not None:
+            model_inputs['token_type_ids'] = token_type_ids
+
+        ranker_logits = self.hf_model(**model_inputs, return_dict=True).logits
+
         if self.train_batch_size:
             grouped_logits = ranker_logits.view(self.train_batch_size, -1)
             loss = self.cross_entropy(grouped_logits, self.target_label)
@@ -60,7 +68,6 @@ def forward(self, pair: Dict[str, Tensor] = None):
 
     def gradient_checkpointing_enable(self, **kwargs):
         return False
-        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
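A minimal usage sketch of the new forward signature: the model now takes unpacked tokenizer output (input_ids, attention_mask, and optionally token_type_ids) instead of a single pair dict. The class name RerankerModel, the checkpoint name, and the example texts below are assumptions for illustration only, not part of this change.

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# placeholder checkpoint; any single-label sequence-classification model works here
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
hf_model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)
model = RerankerModel(hf_model=hf_model)  # assumed class name for the wrapper defined in this file

# tokenizer output unpacks directly into the new keyword arguments;
# token_type_ids is only forwarded to hf_model when the tokenizer produces it
pair = tokenizer('what is a reranker?', 'a reranker scores query-document pairs',
                 return_tensors='pt')
output = model(**pair)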