Skip to content

Commit 17b889a

Browse files
committed
fix: trainer
1 parent bb5d87c commit 17b889a

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

src/tevatron/reranker/driver/train.py

Lines changed: 7 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -7,32 +7,23 @@
77
HfArgumentParser,
88
set_seed,
99
)
10-
from transformers import TrainingArguments
1110

12-
from tevatron.reranker.arguments import ModelArguments, DataArguments, \
13-
TevatronTrainingArguments
11+
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
1412
from tevatron.reranker.modeling import RerankerModel
1513
from tevatron.reranker.dataset import RerankerTrainDataset
1614
from tevatron.reranker.collator import RerankerTrainCollator
1715
from tevatron.reranker.trainer import RerankerTrainer
16+
from tevatron.reranker.gc_trainer import GradCacheTrainer
1817

1918
logger = logging.getLogger(__name__)
2019

2120
def main():
2221
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
2322

2423
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
25-
model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
24+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
2625
else:
27-
model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses()
28-
model_args: ModelArguments
29-
data_args: DataArguments
30-
training_args: TrainingArguments
31-
tevatron_args: TevatronTrainingArguments
32-
33-
# Combine TrainingArguments and TevatronTrainingArguments
34-
for key, value in vars(tevatron_args).items():
35-
setattr(training_args, key, value)
26+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
3627

3728
if (
3829
os.path.exists(training_args.output_dir)
@@ -60,7 +51,6 @@ def main():
6051
)
6152
logger.info("Training/evaluation parameters %s", training_args)
6253
logger.info("MODEL parameters %s", model_args)
63-
logger.info("Tevatron parameters %s", tevatron_args)
6454

6555
set_seed(training_args.seed)
6656

@@ -81,7 +71,9 @@ def main():
8171
train_dataset = RerankerTrainDataset(data_args)
8272
train_collator = RerankerTrainCollator(data_args, tokenizer)
8373

84-
trainer = RerankerTrainer(
74+
# Choose the appropriate trainer based on the grad_cache flag
75+
trainer_cls = GradCacheTrainer if training_args.grad_cache else RerankerTrainer
76+
trainer = trainer_cls(
8577
model=model,
8678
args=training_args,
8779
train_dataset=train_dataset,

0 commit comments

Comments
 (0)