
Commit a4d6637

fix: ddp
1 parent 43a642d commit a4d6637

File tree: 2 files changed (+28 −27 lines)

  src/tevatron/reranker/driver/train.py
  src/tevatron/reranker/trainer.py

src/tevatron/reranker/driver/train.py

Lines changed: 16 additions & 12 deletions
@@ -1,20 +1,29 @@
 import logging
 import os
 import sys
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer

 logger = logging.getLogger(__name__)


+def setup_ddp():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

@@ -23,15 +32,8 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    if training_args.local_rank != -1:
+        setup_ddp()

     # Setup logging
     logging.basicConfig(
@@ -67,10 +69,12 @@ def main():
         cache_dir=model_args.cache_dir,
     )

+    if training_args.local_rank != -1:
+        model = DDP(model, device_ids=[training_args.local_rank], output_device=training_args.local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)

-    # Add GradCache-specific arguments to training_args
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)

     trainer = RerankerTrainer(
@@ -81,7 +85,7 @@ def main():
     )
     train_dataset.trainer = trainer

-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
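
The new setup_ddp() helper and the DDP(...) wrapping assume the script is launched by a distributed launcher such as torchrun, which sets RANK, WORLD_SIZE, and LOCAL_RANK in each worker's environment. A minimal sketch of that lifecycle follows; build_model() is a hypothetical stand-in for the Tevatron model construction, not part of this commit.

    # Sketch only: the DDP lifecycle the patched train.py relies on.
    # Assumes launch via `torchrun --nproc_per_node=<gpus> ...`, which sets
    # RANK, WORLD_SIZE and LOCAL_RANK; build_model() is a placeholder.
    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def build_model() -> torch.nn.Module:
        # placeholder for RerankerModel construction
        return torch.nn.Linear(16, 2)

    def main():
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        dist.init_process_group(backend="nccl")    # one process per GPU
        torch.cuda.set_device(local_rank)          # bind this process to its GPU
        model = build_model().cuda(local_rank)
        # DDP all-reduces gradients across ranks on every backward pass
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        # ... training loop ...
        dist.destroy_process_group()               # clean shutdown

    if __name__ == "__main__":
        main()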

src/tevatron/reranker/trainer.py

Lines changed: 12 additions & 15 deletions
@@ -7,46 +7,43 @@
 from transformers.trainer_utils import PredictionOutput

 from grad_cache import GradCache
-
 from grad_cache.functional import cached, cat_input_tensor
 from torch.cuda.amp import autocast

 logger = logging.getLogger(__name__)

-
 @cached
 @autocast()
 def get_model_rep(model, inputs):
     outputs = model(**inputs)
     return outputs.scores

-
 @cat_input_tensor
 @autocast()
 def contrastive_loss(scores):
     batch_size = scores.size(0) // 2
     labels = torch.arange(batch_size, device=scores.device)
     return nn.CrossEntropyLoss()(scores, labels)

-
 def split_inputs(model_input, chunk_size):
     logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
     keys = list(model_input.keys())
     chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
     return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

-
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         logger.info("Initializing RerankerTrainer with GradCache")
         self.args: TrainingArguments

-        # Add these lines to include the necessary parameters
-        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+
+        # If the model is wrapped in DDP, we need to use the .module attribute
+        model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model

         self.gc = GradCache(
-            models=[self.model],
+            models=[model_for_gc],
             chunk_sizes=self.gc_chunk_size,
             loss_fn=contrastive_loss,
             split_input_fn=split_inputs,
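
The model_for_gc line follows the standard unwrap pattern: DistributedDataParallel exposes the original network as .module, so code that needs the underlying model itself (here, the model handed to GradCache) unwraps it first. A generic helper illustrating the pattern; the name unwrap_model is illustrative and not part of this commit.

    # Illustrative helper (not part of this commit): DDP and similar wrappers
    # expose the wrapped network via `.module`.
    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        return model.module if hasattr(model, "module") else model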
@@ -68,17 +65,17 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         logger.debug("Entering training step")
         model.train()
         inputs = self._prepare_inputs(inputs)
-        _distributed = self.args.local_rank > -1
+        _distributed = self.args.local_rank != -1
         loss = self.gc(inputs, no_sync_except_last=_distributed)
         logger.debug(f"Training step loss: {loss.item()}")
         return loss

     def prediction_step(
-        self,
-        model: nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: bool = None,
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: bool = None,
     ) -> PredictionOutput:
         logger.debug("Entering prediction step")
         inputs = self._prepare_inputs(inputs)
@@ -87,4 +84,4 @@ def prediction_step(
         scores = outputs.scores
         loss = contrastive_loss(scores)
         logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
-        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)
+        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)
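
In training_step, no_sync_except_last=_distributed asks GradCache to skip DDP's gradient all-reduce for every chunk except the last one, so the chunked backward passes of gradient caching behave like local gradient accumulation with a single synchronization at the end. A rough sketch of that idea, not GradCache's actual code; ddp_model is assumed to be a DistributedDataParallel instance and compute_loss a hypothetical loss function.

    # Rough sketch of what no_sync_except_last achieves under DDP: accumulate
    # gradients locally for all but the final chunk, and only let the final
    # backward trigger the cross-process all-reduce.
    import contextlib

    def chunked_backward(ddp_model, chunks, compute_loss):
        for i, chunk in enumerate(chunks):
            last = (i == len(chunks) - 1)
            # ddp_model.no_sync() suppresses gradient synchronization
            ctx = contextlib.nullcontext() if last else ddp_model.no_sync()
            with ctx:
                compute_loss(ddp_model(**chunk)).backward()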
