7
7
HfArgumentParser ,
8
8
set_seed ,
9
9
)
10
- from transformers import TrainingArguments
11
10
12
- from tevatron .reranker .arguments import ModelArguments , DataArguments , \
13
- TevatronTrainingArguments
11
+ from tevatron .reranker .arguments import ModelArguments , DataArguments , TevatronTrainingArguments
14
12
from tevatron .reranker .modeling import RerankerModel
15
13
from tevatron .reranker .dataset import RerankerTrainDataset
16
14
from tevatron .reranker .collator import RerankerTrainCollator
17
15
from tevatron .reranker .trainer import RerankerTrainer
16
+ from tevatron .reranker .gc_trainer import GradCacheTrainer
18
17
19
18
logger = logging .getLogger (__name__ )
20
19
21
20
def main ():
22
21
parser = HfArgumentParser ((ModelArguments , DataArguments , TevatronTrainingArguments ))
23
22
24
23
if len (sys .argv ) == 2 and sys .argv [1 ].endswith (".json" ):
25
- model_args , data_args , training_args , tevatron_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
24
+ model_args , data_args , training_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
26
25
else :
27
- model_args , data_args , training_args , tevatron_args = parser .parse_args_into_dataclasses ()
28
- model_args : ModelArguments
29
- data_args : DataArguments
30
- training_args : TrainingArguments
31
- tevatron_args : TevatronTrainingArguments
32
-
33
- # Combine TrainingArguments and TevatronTrainingArguments
34
- for key , value in vars (tevatron_args ).items ():
35
- setattr (training_args , key , value )
26
+ model_args , data_args , training_args = parser .parse_args_into_dataclasses ()
36
27
37
28
if (
38
29
os .path .exists (training_args .output_dir )
@@ -60,7 +51,6 @@ def main():
60
51
)
61
52
logger .info ("Training/evaluation parameters %s" , training_args )
62
53
logger .info ("MODEL parameters %s" , model_args )
63
- logger .info ("Tevatron parameters %s" , tevatron_args )
64
54
65
55
set_seed (training_args .seed )
66
56
@@ -81,7 +71,9 @@ def main():
81
71
train_dataset = RerankerTrainDataset (data_args )
82
72
train_collator = RerankerTrainCollator (data_args , tokenizer )
83
73
84
- trainer = RerankerTrainer (
74
+ # Choose the appropriate trainer based on the grad_cache flag
75
+ trainer_cls = GradCacheTrainer if training_args .grad_cache else RerankerTrainer
76
+ trainer = trainer_cls (
85
77
model = model ,
86
78
args = training_args ,
87
79
train_dataset = train_dataset ,
0 commit comments