|
7 | 7 | HfArgumentParser,
|
8 | 8 | set_seed,
|
9 | 9 | )
|
10 |
| -from transformers import TrainingArguments |
11 | 10 |
|
12 |
| -from tevatron.reranker.arguments import ModelArguments, DataArguments, \ |
13 |
| - TevatronTrainingArguments |
| 11 | +from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments |
14 | 12 | from tevatron.reranker.modeling import RerankerModel
|
15 | 13 | from tevatron.reranker.dataset import RerankerTrainDataset
|
16 | 14 | from tevatron.reranker.collator import RerankerTrainCollator
|
17 | 15 | from tevatron.reranker.trainer import RerankerTrainer
|
| 16 | +from tevatron.reranker.gc_trainer import GradCacheTrainer |
18 | 17 |
|
19 | 18 | logger = logging.getLogger(__name__)
|
20 | 19 |
|
21 | 20 | def main():
|
22 | 21 | parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
|
23 | 22 |
|
24 | 23 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
25 |
| - model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
| 24 | + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
26 | 25 | else:
|
27 |
| - model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses() |
28 |
| - model_args: ModelArguments |
29 |
| - data_args: DataArguments |
30 |
| - training_args: TrainingArguments |
31 |
| - tevatron_args: TevatronTrainingArguments |
32 |
| - |
33 |
| - # Combine TrainingArguments and TevatronTrainingArguments |
34 |
| - for key, value in vars(tevatron_args).items(): |
35 |
| - setattr(training_args, key, value) |
| 26 | + model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
36 | 27 |
|
37 | 28 | if (
|
38 | 29 | os.path.exists(training_args.output_dir)
|
@@ -60,7 +51,6 @@ def main():
|
60 | 51 | )
|
61 | 52 | logger.info("Training/evaluation parameters %s", training_args)
|
62 | 53 | logger.info("MODEL parameters %s", model_args)
|
63 |
| - logger.info("Tevatron parameters %s", tevatron_args) |
64 | 54 |
|
65 | 55 | set_seed(training_args.seed)
|
66 | 56 |
|
|
0 commit comments