diff --git a/examples/example_rankllama.md b/examples/example_rankllama.md
index 31275087..ae7e1b5f 100644
--- a/examples/example_rankllama.md
+++ b/examples/example_rankllama.md
@@ -18,4 +18,5 @@ deepspeed --include localhost:4,5,6,7 --master_port 60000 --module tevatron.rera
   --num_train_epochs 1 \
   --logging_steps 10 \
-  --overwrite_output_dir
+  --overwrite_output_dir \
+  --grad_cache
 ```
\ No newline at end of file
diff --git a/src/tevatron/reranker/arguments.py b/src/tevatron/reranker/arguments.py
index a2089322..a48b4468 100644
--- a/src/tevatron/reranker/arguments.py
+++ b/src/tevatron/reranker/arguments.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Optional
-
+from transformers import TrainingArguments
 
 @dataclass
 class ModelArguments:
@@ -116,3 +116,10 @@ class DataArguments:
            "enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
        },
    )
+
+@dataclass
+class TevatronTrainingArguments(TrainingArguments):
+    warmup_ratio: float = field(default=0.1)
+
+    grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache"})
+    gc_chunk_size: Optional[int] = field(default=2, metadata={"help": "Chunk size for gradient cache"})
diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py
index eaa17734..7c924001 100644
--- a/src/tevatron/reranker/driver/train.py
+++ b/src/tevatron/reranker/driver/train.py
@@ -1,57 +1,60 @@
 import logging
 import os
 import sys
-
+import torch
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
-from transformers import TrainingArguments
-
-from tevatron.reranker.arguments import ModelArguments, DataArguments
-
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
-from tevatron.reranker.trainer import RerankerTrainer
 from tevatron.reranker.collator import RerankerTrainCollator
+from tevatron.reranker.trainer import RerankerTrainer
 
 logger = logging.getLogger(__name__)
 
+
+def setup_ddp():
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        # We're running in a distributed environment
+        import torch.distributed as dist
+        rank = int(os.environ['RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        dist.init_process_group(backend="nccl")
+        return rank
+    else:
+        # We're not running in a distributed environment
+        return -1
+
+
 def main():
-    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
 
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-        model_args: ModelArguments
-        data_args: DataArguments
-        training_args: TrainingArguments
-
-    if (
-            os.path.exists(training_args.output_dir)
-            and os.listdir(training_args.output_dir)
-            and training_args.do_train
-            and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+
+    local_rank = setup_ddp()
+    training_args.local_rank = local_rank
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -60,20 +63,30 @@ def main():
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir
+        cache_dir=model_args.cache_dir,
+        trust_remote_code=True
     )
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.unk_token_id
     tokenizer.padding_side = 'right'
+
     model = RerankerModel.build(
         model_args,
         training_args,
         cache_dir=model_args.cache_dir,
     )
+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
 
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)
 
+    training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)
+
     trainer = RerankerTrainer(
         model=model,
         args=training_args,
@@ -82,7 +95,7 @@ def main():
     )
     train_dataset.trainer = trainer
 
-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py
index 95929390..6007dffc 100644
--- a/src/tevatron/reranker/modeling.py
+++ b/src/tevatron/reranker/modeling.py
@@ -1,6 +1,6 @@
-import os
+import logging
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 from torch import nn, Tensor
@@ -9,11 +9,8 @@
 from transformers import TrainingArguments
 
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 
-
 from tevatron.reranker.arguments import ModelArguments
 
-import logging
-
 
 logger = logging.getLogger(__name__)
@@ -22,44 +19,24 @@ class RerankerOutput(ModelOutput):
     loss: Optional[Tensor] = None
     scores: Optional[Tensor] = None
 
+
 class RerankerModel(nn.Module):
     TRANSFORMER_CLS = AutoModelForSequenceClassification
 
-    def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
+    def __init__(self, hf_model: PreTrainedModel):
         super().__init__()
+        logger.info("Initializing RerankerModel")
         self.config = hf_model.config
         self.hf_model = hf_model
-        self.train_batch_size = train_batch_size
-        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
-        if train_batch_size:
-            self.register_buffer(
-                'target_label',
-                torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
-            )
-        for name, param in self.hf_model.named_parameters():
-            # for some reason, ds zero 3 left some weights empty
-            if 'modules_to_save' in name and param.numel() == 0:
-                logger.warning(f'parameter {name}, shape {param.shape} is empty')
-                param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
-                logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
+        logger.info(f"RerankerModel initialized with config: {self.config}")
 
-    def forward(self, pair: Dict[str, Tensor] = None):
-        ranker_logits = self.hf_model(**pair, return_dict=True).logits
-        if self.train_batch_size:
-            grouped_logits = ranker_logits.view(self.train_batch_size, -1)
-            loss = self.cross_entropy(grouped_logits, self.target_label)
-            return RerankerOutput(
-                loss = loss,
-                scores = ranker_logits
-            )
+    def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
+        logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
+        outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
         return RerankerOutput(
-            loss = None,
-            scores = ranker_logits
+            scores=outputs.logits
         )
-
-    def gradient_checkpointing_enable(self, **kwargs):
-        self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
@@ -68,19 +45,27 @@ def build(
             train_args: TrainingArguments,
             **hf_kwargs,
     ):
+        logger.info(f"Building RerankerModel with args: {model_args}")
         base_model = cls.TRANSFORMER_CLS.from_pretrained(
             model_args.model_name_or_path,
             **hf_kwargs,
         )
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
+
         if model_args.lora or model_args.lora_name_or_path:
+            logger.info("Applying LoRA")
             if train_args.gradient_checkpointing:
                 base_model.enable_input_require_grads()
             if model_args.lora_name_or_path:
+                logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
                 lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
-                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
+                                                       torch_dtype=torch.bfloat16,
+                                                       attn_implementation="flash_attention_2")
             else:
+                logger.info("Initializing new LoRA")
                 lora_config = LoraConfig(
                     base_model_name_or_path=model_args.model_name_or_path,
                     task_type=TaskType.SEQ_CLS,
@@ -91,15 +76,10 @@ def build(
                     inference_mode=False,
                 )
                 lora_model = get_peft_model(base_model, lora_config)
-            model = cls(
-                hf_model=lora_model,
-                train_batch_size=train_args.per_device_train_batch_size,
-            )
+            model = cls(hf_model=lora_model)
         else:
-            model = cls(
-                hf_model=base_model,
-                train_batch_size=train_args.per_device_train_batch_size,
-            )
+            logger.info("Building model without LoRA")
+            model = cls(hf_model=base_model)
         return model
 
     @classmethod
@@ -107,21 +87,24 @@ def load(cls,
              model_name_or_path: str,
              lora_name_or_path: str = None,
              **hf_kwargs):
-        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+        logger.info(f"Loading RerankerModel from {model_name_or_path}")
+        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
+                                                         torch_dtype=torch.bfloat16,
+                                                         attn_implementation="flash_attention_2")
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
+            logger.info("Set pad_token_id to 0")
         if lora_name_or_path:
+            logger.info(f"Loading LoRA from {lora_name_or_path}")
             lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
             lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
             lora_model = lora_model.merge_and_unload()
-            model = cls(
-                hf_model=lora_model,
-            )
+            model = cls(hf_model=lora_model)
         else:
-            model = cls(
-                hf_model=base_model,
-            )
+            logger.info("Loading model without LoRA")
+            model = cls(hf_model=base_model)
         return model
 
     def save(self, output_dir: str):
-        self.hf_model.save_pretrained(output_dir)
+        logger.info(f"Saving model to {output_dir}")
+        self.hf_model.save_pretrained(output_dir)
\ No newline at end of file
diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py
index f49e1baf..8534b68a 100644
--- a/src/tevatron/reranker/trainer.py
+++ b/src/tevatron/reranker/trainer.py
@@ -1,40 +1,96 @@
-import os
-from typing import Optional
+import logging
+from typing import Dict, Union, Any
 
 import torch
+from torch import nn
+from transformers import Trainer, TrainingArguments
+from transformers.trainer_utils import PredictionOutput
 
-from transformers.trainer import Trainer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
-from peft import get_peft_model_state_dict
+from grad_cache import GradCache
+from grad_cache.functional import cached, cat_input_tensor
+from torch.cuda.amp import autocast
 
-import logging
 logger = logging.getLogger(__name__)
 
 
+@cached
+@autocast()
+def get_model_rep(model, inputs):
+    outputs = model(**inputs)
+    return outputs.scores
+
+
+@cat_input_tensor
+@autocast()
+def contrastive_loss(scores):
+    batch_size = scores.size(0) // 2
+    labels = torch.arange(batch_size, device=scores.device)
+    return nn.CrossEntropyLoss()(scores, labels)
+
+
+def split_inputs(model_input, chunk_size):
+    logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
+    keys = list(model_input.keys())
+    chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
+    return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]
+
+
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
-        super(RerankerTrainer, self).__init__(*args, **kwargs)
-
-    def _save(self, output_dir: Optional[str] = None, state_dict=None):
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        os.makedirs(output_dir, exist_ok=True)
-        logger.info("Saving model checkpoint to %s", output_dir)
-        self.model.save(output_dir)
-
-        if is_deepspeed_zero3_enabled():
-            if state_dict is None:
-                state_dict = self.model.state_dict()
-            prefix = 'hf_model.'
-            assert all(
-                k.startswith(prefix) or k == "target_label"
-                for k in state_dict.keys()
-            ), list(state_dict.keys())
-            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
-            lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict)
-            if self.args.process_index <= 0:
-                torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
-                print(f"Save adapter model at {output_dir}")
-
-
-    def compute_loss(self, model, inputs):
-        return model(inputs).loss
+        super().__init__(*args, **kwargs)
+        logger.info("Initializing RerankerTrainer")
+        self.args: TrainingArguments
+
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+        self.use_grad_cache = getattr(self.args, 'grad_cache', False)
+
+        if self.use_grad_cache:
+            # If the model is wrapped in DDP, we need to use the .module attribute
+            model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model
+
+            self.gc = GradCache(
+                models=[model_for_gc],
+                chunk_sizes=self.gc_chunk_size,
+                loss_fn=contrastive_loss,
+                split_input_fn=split_inputs,
+                get_rep_fn=lambda x: x.scores,
+                fp16=self.args.fp16,
+                # scaler: GradScaler = None,
+            )
+            logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        logger.debug(f"Computing loss with inputs: {inputs.keys()}")
+        outputs = model(**inputs)
+        scores = outputs.scores
+        loss = contrastive_loss(scores)
+        logger.debug(f"Computed loss: {loss.item()}")
+        return (loss, outputs) if return_outputs else loss
+
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        logger.debug("Entering training step")
+        model.train()
+        inputs = self._prepare_inputs(inputs)
+        if self.use_grad_cache:
+            _distributed = self.args.local_rank != -1
+            loss = self.gc(inputs, no_sync_except_last=_distributed)
+        else:
+            loss = self.compute_loss(model, inputs)
+        logger.debug(f"Training step loss: {loss.item()}")
+        return loss
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: bool = None,
+    ) -> PredictionOutput:
+        logger.debug("Entering prediction step")
+        inputs = self._prepare_inputs(inputs)
+        with torch.no_grad():
+            outputs = model(**inputs)
+            scores = outputs.scores
+            loss = contrastive_loss(scores)
+        logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
+        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)