- from tevatron.reranker.modeling import RerankerOutput
- from tevatron.retriever.trainer import TevatronTrainer
+ import logging
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import Trainer
+
+ from tevatron.arguments import TevatronTrainingArguments
from grad_cache import GradCache

+ logger = logging.getLogger(__name__)
+
def split_inputs(model_input, chunk_size):
+     logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
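+     # Each value in model_input is a batch-first tensor; splitting every one
+     # along dim 0 yields chunk_size-sized sub-batches for GradCache to replay.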
    keys = list(model_input.keys())
    chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
    return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

- def get_rep(x: RerankerOutput):
-     return x.scores
+ def get_rep(model_output):
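+     # GradCache uses this hook to pull the cached representation (here, the
+     # reranker scores) out of each sub-batch's model output.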
+     logger.debug(f"Getting representation from model output: {type(model_output)}")
+     return model_output.scores

- class RerankerTrainer(TevatronTrainer):
+ class RerankerTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
-         loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
+         logger.info("Initializing RerankerTrainer")
+         self.args: TevatronTrainingArguments
+
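+         # Grouped softmax cross-entropy, assuming Tevatron's convention that each
+         # query's candidates are contiguous with the positive document at index 0.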
+         def loss_fn(scores, labels=None):  # labels unused: targets are implied by the ordering
+             grouped_scores = scores.view(-1, self.args.train_group_size)
+             target = torch.zeros(grouped_scores.size(0), dtype=torch.long, device=scores.device)
+             return nn.CrossEntropyLoss()(grouped_scores, target)
+
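+         # GradCache re-runs the forward pass in gc_chunk_size pieces and caches the
+         # scores, keeping peak memory flat while preserving the full-batch loss.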
        self.gc = GradCache(
            models=[self.model],
            chunk_sizes=[self.args.gc_chunk_size],
@@ -23,14 +41,37 @@ def __init__(self, *args, **kwargs):
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )
+         logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}")

    def compute_loss(self, model, inputs, return_outputs=False):
+         logger.debug(f"Computing loss with inputs: {inputs.keys()}")
        outputs = model(**inputs)
        loss = outputs.loss
+         logger.debug(f"Computed loss: {loss.item()}")
        return (loss, outputs) if return_outputs else loss

-     def training_step(self, model, inputs):
+     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+         logger.debug("Entering training step")
+         model.train()
+         inputs = self._prepare_inputs(inputs)
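+         # GradCache performs the chunked forward and backward passes itself, so no
+         # explicit loss.backward() is needed here; no_sync_except_last defers DDP
+         # gradient synchronization until the final chunk.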
        _distributed = self.args.local_rank > -1
        self.gc.models = [model]
        loss = self.gc(inputs, no_sync_except_last=_distributed)
+         logger.debug(f"Training step loss: {loss.item()}")
        return loss
+
+     def prediction_step(
+         self,
+         model: nn.Module,
+         inputs: Dict[str, Union[torch.Tensor, Any]],
+         prediction_loss_only: bool,
+         ignore_keys: Optional[List[str]] = None,
+     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+         logger.debug("Entering prediction step")
+         inputs = self._prepare_inputs(inputs)
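+         # Scoring needs no gradients; the raw reranker scores serve as the logits
+         # handed back to the Trainer's evaluation loop.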
+         with torch.no_grad():
+             outputs = model(**inputs)
+             loss = outputs.loss
+             logits = outputs.scores
+         logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
+         if prediction_loss_only:
+             return (loss, None, None)
+         return (loss, logits, inputs.get("labels"))