 from transformers.trainer_utils import PredictionOutput
 
 from grad_cache import GradCache
-
 from grad_cache.functional import cached, cat_input_tensor
 from torch.cuda.amp import autocast
 
 logger = logging.getLogger(__name__)
 
-
 @cached
 @autocast()
 def get_model_rep(model, inputs):
     outputs = model(**inputs)
     return outputs.scores
 
-
 @cat_input_tensor
 @autocast()
 def contrastive_loss(scores):
     batch_size = scores.size(0) // 2
     labels = torch.arange(batch_size, device=scores.device)
     return nn.CrossEntropyLoss()(scores, labels)
 
-
 def split_inputs(model_input, chunk_size):
     logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
     keys = list(model_input.keys())
     chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
     return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]
 
-
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         logger.info("Initializing RerankerTrainer with GradCache")
         self.args: TrainingArguments
 
-        # Add these lines to include the necessary parameters
-        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)  # default to 4 if not provided
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+
+        # If the model is wrapped in DDP, we need to use the .module attribute
+        model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model
 
         self.gc = GradCache(
-            models=[self.model],
+            models=[model_for_gc],
             chunk_sizes=self.gc_chunk_size,
             loss_fn=contrastive_loss,
             split_input_fn=split_inputs,
@@ -68,17 +65,17 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         logger.debug("Entering training step")
         model.train()
         inputs = self._prepare_inputs(inputs)
-        _distributed = self.args.local_rank > -1
+        _distributed = self.args.local_rank != -1
         loss = self.gc(inputs, no_sync_except_last=_distributed)
         logger.debug(f"Training step loss: {loss.item()}")
         return loss
 
     def prediction_step(
-        self,
-        model: nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: bool = None,
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: bool = None,
     ) -> PredictionOutput:
         logger.debug("Entering prediction step")
         inputs = self._prepare_inputs(inputs)
@@ -87,4 +84,4 @@ def prediction_step(
         scores = outputs.scores
         loss = contrastive_loss(scores)
         logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
-        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)
+        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)
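A minimal sketch (not part of the commit) of what split_inputs produces for a toy tokenized batch; the tensor shapes below are illustrative assumptions. GradCache runs the forward pass one chunk at a time and replays the cached gradients afterwards, so the full contrastive batch never has to fit on the device at once.

import torch

batch = {
    "input_ids": torch.randint(0, 1000, (8, 16)),            # 8 examples, 16 tokens each (assumed shapes)
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}

chunks = split_inputs(batch, chunk_size=4)
# Two dicts, each holding tensors of shape (4, 16), which GradCache feeds
# through get_model_rep one at a time before computing contrastive_loss
# over the concatenated scores.
print(len(chunks), chunks[0]["input_ids"].shape)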
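The __init__ change reads gc_chunk_size off self.args with getattr, which implies the training arguments are expected to carry that field. One way to expose it is a custom TrainingArguments subclass; the class name and help text below are hypothetical and not part of this commit, only the gc_chunk_size attribute name is implied by the code.

from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class RerankerTrainingArguments(TrainingArguments):
    # Hypothetical field: surfaces the chunk size the trainer looks up via getattr.
    gc_chunk_size: int = field(
        default=4,
        metadata={"help": "Sub-batch size GradCache uses for each forward pass."},
    )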