Skip to content

Commit 27e90d8

Browse files
committed
fix: forward method
1 parent ca4b04b commit 27e90d8

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/tevatron/reranker/modeling.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from dataclasses import dataclass
3-
from typing import Dict, Optional
3+
from typing import Dict, Optional, Union
44

55
import torch
66
from torch import nn, Tensor
@@ -43,8 +43,16 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
4343
param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
4444
logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
4545

46-
def forward(self, pair: Dict[str, Tensor] = None):
47-
ranker_logits = self.hf_model(**pair, return_dict=True).logits
46+
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, token_type_ids: Tensor = None, **kwargs):
47+
model_inputs = {
48+
'input_ids': input_ids,
49+
'attention_mask': attention_mask,
50+
}
51+
if token_type_ids is not None:
52+
model_inputs['token_type_ids'] = token_type_ids
53+
54+
ranker_logits = self.hf_model(**model_inputs, return_dict=True).logits
55+
4856
if self.train_batch_size:
4957
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
5058
loss = self.cross_entropy(grouped_logits, self.target_label)
@@ -60,7 +68,6 @@ def forward(self, pair: Dict[str, Tensor] = None):
6068

6169
def gradient_checkpointing_enable(self, **kwargs):
6270
return False
63-
# self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
6471

6572
@classmethod
6673
def build(

0 commit comments

Comments (0)