diff --git a/src/data/__init__.py b/src/data/__init__.py
index c24b0b03..590cafae 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -102,4 +102,4 @@ def get_collators(collator_cfgs, **kwargs):
 _register_data(ForgetRetainDataset)
 
 # Register collators
-_register_collator(DataCollatorForSupervisedDataset)
+_register_collator(DataCollatorForSupervisedDataset)
\ No newline at end of file
diff --git a/src/data/unlearn.py b/src/data/unlearn.py
index 0cb0bada..c58a0596 100644
--- a/src/data/unlearn.py
+++ b/src/data/unlearn.py
@@ -15,6 +15,15 @@ def __init__(self, forget, retain, anchor="forget"):
         self.forget = forget
         self.retain = retain
         self.anchor = anchor
+        self.generator = torch.Generator()
+
+    def set_rank_seed(self, seed: int):
+        """Set the rank-specific seed for this dataset.
+
+        This should be called after trainer initialization to ensure each rank
+        uses a unique seed for different unanchored data.
+        """
+        self.generator.manual_seed(seed)
 
     def __len__(self):
         """Ensures the sampled dataset matches the anchor dataset's length."""
@@ -36,11 +45,11 @@ def __getitem__(self, idx):
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(0, len(self.retain), (1,), generator=self.generator).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
             item["retain"] = self.retain[idx]
             if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(0, len(self.forget), (1,), generator=self.generator).item()
                 item["forget"] = self.forget[forget_idx]
-        return item
+        return item
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index a2f81c8d..6eb02717 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,6 +1,8 @@
+import torch
 import hydra
 from omegaconf import DictConfig
 from data import get_data, get_collators
+from data.unlearn import ForgetRetainDataset
 from model import get_model
 from trainer import load_trainer
 from evals import get_evaluators
@@ -23,7 +25,11 @@ def main(cfg: DictConfig):
     # Load Dataset
     data_cfg = cfg.data
     data = get_data(
-        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
+        data_cfg,
+        mode=mode,
+        tokenizer=tokenizer,
+        template_args=template_args,
+        seed=cfg.trainer.args.seed,
     )
 
     # Load collator
@@ -56,6 +62,13 @@ def main(cfg: DictConfig):
         template_args=template_args,
     )
 
+    # Set rank-specific seed for ForgetRetainDataset after trainer initialization
+    train_dataset = data.get("train", None)
+    if isinstance(train_dataset, ForgetRetainDataset):
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        rank_seed = cfg.trainer.args.seed + rank
+        train_dataset.set_rank_seed(rank_seed)
+
    if trainer_args.do_train:
        trainer.train()
        trainer.save_state()
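
The following is a minimal sketch, not part of the patch, illustrating the intended effect of the per-rank seeding added above. It assumes plain Python lists can stand in for the forget/retain splits (ForgetRetainDataset only needs len() and integer indexing), that the script is run from src/ so the data package is importable as in train.py, and that base_seed stands in for cfg.trainer.args.seed.

import torch

from data.unlearn import ForgetRetainDataset

# Stand-in datasets for illustration only; the real pipeline passes tokenized datasets.
forget = [{"id": f"forget-{i}"} for i in range(4)]
retain = [{"id": f"retain-{i}"} for i in range(100)]

base_seed = 42  # stands in for cfg.trainer.args.seed

# Simulate two ranks: each rank holds its own dataset copy and seeds its
# generator with base_seed + rank, mirroring the block added to train.py.
rank0 = ForgetRetainDataset(forget, retain, anchor="forget")
rank1 = ForgetRetainDataset(forget, retain, anchor="forget")
rank0.set_rank_seed(base_seed + 0)
rank1.set_rank_seed(base_seed + 1)

# Both ranks see the same anchored forget item at index 0, but the paired
# retain item is drawn from each rank's own generator, so the two ranks
# typically end up with different retain samples.
item0 = rank0[0]
item1 = rank1[0]
print(item0["forget"], item0["retain"])
print(item1["forget"], item1["retain"])

The per-dataset torch.Generator keeps this sampling independent of the global torch seed, so seeding each rank's copy differently changes only the unanchored draws while the anchored iteration order stays governed by the trainer's sampler.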