Skip to content

Commit 7626bbf

Browse files
committed
fix: ddp
1 parent 43a642d commit 7626bbf

File tree

2 files changed

+53
-34
lines changed

2 files changed

+53
-34
lines changed
Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,58 @@
11
import logging
22
import os
33
import sys
4+
import torch
45
from transformers import AutoTokenizer
56
from transformers import (
67
HfArgumentParser,
78
set_seed,
89
)
10+
from torch.nn.parallel import DistributedDataParallel as DDP
11+
import torch.distributed as dist
912
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
1013
from tevatron.reranker.modeling import RerankerModel
1114
from tevatron.reranker.dataset import RerankerTrainDataset
1215
from tevatron.reranker.collator import RerankerTrainCollator
13-
from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer
16+
from tevatron.reranker.trainer import RerankerTrainer
1417

1518
logger = logging.getLogger(__name__)
1619

1720

21+
def setup_ddp():
    """Initialize torch.distributed for single-node DDP training.

    Creates the NCCL process group if one is not already initialized, pins
    this process to the CUDA device identified by the ``LOCAL_RANK``
    environment variable (as exported by ``torchrun`` /
    ``torch.distributed.launch``), and returns that rank.

    Returns:
        int: this process's local rank (0 when ``LOCAL_RANK`` is unset).
    """
    if not dist.is_initialized():
        # NCCL is the standard backend for multi-GPU CUDA training;
        # this call blocks until all ranks in the job have joined.
        dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Bind the default CUDA device so subsequent .cuda()/to(device) calls
    # on this rank land on its own GPU.
    torch.cuda.set_device(local_rank)
    return local_rank
27+
28+
1829
def main():
1930
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
2031

32+
parser.add_argument('--bf16', action='store_true', help='Use bfloat16 precision')
33+
2134
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
2235
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
2336
else:
2437
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
2538

26-
if (
27-
os.path.exists(training_args.output_dir)
28-
and os.listdir(training_args.output_dir)
29-
and training_args.do_train
30-
and not training_args.overwrite_output_dir
31-
):
32-
raise ValueError(
33-
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
34-
)
39+
local_rank = -1
40+
if training_args.local_rank != -1:
41+
local_rank = setup_ddp()
3542

3643
# Setup logging
3744
logging.basicConfig(
3845
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
3946
datefmt="%m/%d/%Y %H:%M:%S",
40-
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
47+
level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
4148
)
4249
logger.warning(
4350
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
44-
training_args.local_rank,
51+
local_rank,
4552
training_args.device,
4653
training_args.n_gpu,
47-
bool(training_args.local_rank != -1),
48-
training_args.fp16,
54+
bool(local_rank != -1),
55+
training_args.fp16 or training_args.bf16,
4956
)
5057
logger.info("Training/evaluation parameters %s", training_args)
5158
logger.info("MODEL parameters %s", model_args)
@@ -67,11 +74,16 @@ def main():
6774
cache_dir=model_args.cache_dir,
6875
)
6976

77+
# Move model to GPU
78+
if local_rank != -1:
79+
model = model.to(local_rank)
80+
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
81+
7082
train_dataset = RerankerTrainDataset(data_args)
7183
train_collator = RerankerTrainCollator(data_args, tokenizer)
7284

73-
# Add GradCache-specific arguments to training_args
7485
training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
86+
training_args.grad_cache = getattr(training_args, 'grad_cache', False)
7587

7688
trainer = RerankerTrainer(
7789
model=model,
@@ -81,11 +93,11 @@ def main():
8193
)
8294
train_dataset.trainer = trainer
8395

84-
trainer.train() # TODO: resume training
96+
trainer.train()
8597
trainer.save_model()
8698
if trainer.is_world_process_zero():
8799
tokenizer.save_pretrained(training_args.output_dir)
88100

89101

90102
if __name__ == "__main__":
91-
main()
103+
main()

src/tevatron/reranker/trainer.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from transformers.trainer_utils import PredictionOutput
88

99
from grad_cache import GradCache
10-
1110
from grad_cache.functional import cached, cat_input_tensor
1211
from torch.cuda.amp import autocast
1312

@@ -39,22 +38,27 @@ def split_inputs(model_input, chunk_size):
3938
class RerankerTrainer(Trainer):
4039
def __init__(self, *args, **kwargs):
    """Initialize the reranker trainer, optionally wiring up GradCache.

    Reads two optional attributes from the training arguments:
      - ``gc_chunk_size`` (default 4): sub-batch size handed to GradCache.
      - ``grad_cache`` (default False): whether to enable GradCache at all.

    When GradCache is enabled, the loss is computed in gradient-cached
    chunks via ``self.gc`` inside ``training_step``; otherwise the stock
    ``compute_loss`` path is used.
    """
    super().__init__(*args, **kwargs)
    logger.info("Initializing RerankerTrainer")
    self.args: TrainingArguments

    self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
    self.use_grad_cache = getattr(self.args, 'grad_cache', False)

    if self.use_grad_cache:
        # If the model is wrapped in DDP, GradCache must operate on the
        # underlying module so it can run sub-batch forward passes directly.
        model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model

        self.gc = GradCache(
            models=[model_for_gc],
            chunk_sizes=self.gc_chunk_size,
            loss_fn=contrastive_loss,
            split_input_fn=split_inputs,
            get_rep_fn=lambda x: x.scores,
            fp16=self.args.fp16,
            # NOTE(review): upstream GradCache's constructor takes no ``bf16``
            # kwarg — confirm this fork accepts it, else this raises TypeError.
            bf16=self.args.bf16,
            # A GradScaler only applies to fp16 autocast; under bf16 the HF
            # Trainer does not create one, so gate on fp16 alone to avoid
            # passing a missing/meaningless scaler.
            scaler=self.scaler if self.args.fp16 else None
        )
        logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")

5963
def compute_loss(self, model, inputs, return_outputs=False):
6064
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
@@ -68,8 +72,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
6872
logger.debug("Entering training step")
6973
model.train()
7074
inputs = self._prepare_inputs(inputs)
71-
_distributed = self.args.local_rank > -1
72-
loss = self.gc(inputs, no_sync_except_last=_distributed)
75+
if self.use_grad_cache:
76+
_distributed = self.args.local_rank != -1
77+
loss = self.gc(inputs, no_sync_except_last=_distributed)
78+
else:
79+
loss = self.compute_loss(model, inputs)
7380
logger.debug(f"Training step loss: {loss.item()}")
7481
return loss
7582

0 commit comments

Comments
 (0)