7
7
HfArgumentParser ,
8
8
set_seed ,
9
9
)
10
- from transformers import TrainingArguments
11
10
12
- from tevatron .reranker .arguments import ModelArguments , DataArguments , \
13
- TevatronTrainingArguments
11
+ from tevatron .reranker .arguments import ModelArguments , DataArguments , TevatronTrainingArguments
14
12
from tevatron .reranker .modeling import RerankerModel
15
13
from tevatron .reranker .dataset import RerankerTrainDataset
16
14
from tevatron .reranker .collator import RerankerTrainCollator
@@ -22,17 +20,9 @@ def main():
22
20
parser = HfArgumentParser ((ModelArguments , DataArguments , TevatronTrainingArguments ))
23
21
24
22
if len (sys .argv ) == 2 and sys .argv [1 ].endswith (".json" ):
25
- model_args , data_args , training_args , tevatron_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
23
+ model_args , data_args , training_args = parser .parse_json_file (json_file = os .path .abspath (sys .argv [1 ]))
26
24
else :
27
- model_args , data_args , training_args , tevatron_args = parser .parse_args_into_dataclasses ()
28
- model_args : ModelArguments
29
- data_args : DataArguments
30
- training_args : TrainingArguments
31
- tevatron_args : TevatronTrainingArguments
32
-
33
- # Combine TrainingArguments and TevatronTrainingArguments
34
- for key , value in vars (tevatron_args ).items ():
35
- setattr (training_args , key , value )
25
+ model_args , data_args , training_args = parser .parse_args_into_dataclasses ()
36
26
37
27
if (
38
28
os .path .exists (training_args .output_dir )
@@ -60,7 +50,6 @@ def main():
60
50
)
61
51
logger .info ("Training/evaluation parameters %s" , training_args )
62
52
logger .info ("MODEL parameters %s" , model_args )
63
- logger .info ("Tevatron parameters %s" , tevatron_args )
64
53
65
54
set_seed (training_args .seed )
66
55
0 commit comments