From ccc0edd17af0f00d6657b030872967d490602624 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 29 Oct 2025 20:19:22 +0800 Subject: [PATCH 01/16] wip --- swift/llm/train/rlhf.py | 77 +++++++++++++++------- swift/trainers/mixin.py | 23 ++++--- swift/trainers/rlhf_trainer/dpo_trainer.py | 63 ++++++++++++++++-- swift/trainers/rlhf_trainer/rlhf_mixin.py | 51 +++++++++----- 4 files changed, 160 insertions(+), 54 deletions(-) diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 832b458ee5..e822514966 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -5,6 +5,7 @@ from swift.llm import safe_snapshot_download from swift.plugin import Tuner, extra_tuners +from swift.ray import RayHelper from swift.tuners import Swift from swift.utils import get_logger, get_model_parameter_info from swift.utils.utils import disable_deepspeed_zero3 @@ -16,10 +17,17 @@ logger = get_logger() +@RayHelper.worker(group=['default', 'ref', 'reward', 'value', 'teacher']) class SwiftRLHF(SwiftSft): args_class = RLHFArguments args: args_class + def __init__(self, args: RLHFArguments): + super().__init__(args) + self.reward_model = [] + if self.args.rlhf_type == 'grpo': + self.reward_template = [] + @staticmethod def _get_model_task_type(model_dir): task_type = None @@ -48,6 +56,47 @@ def _get_model_task_type(model_dir): task_type = 'seq_cls' return task_type, num_labels + @RayHelper.function(group='ref') + def _prepare_ref_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.ref_model = result[0] + + @RayHelper.function(group='value') + def _prepare_value_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.value_model = result[0] + + @RayHelper.function(group='teacher') + def _prepare_teacher_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.teacher_model = result[0] + + def _prepare_reward_model(self, reward_model_path, key, origin_key, model_type, model_revision): + rms = self.args.reward_model if isinstance(self.args.reward_model, list) else [self.args.reward_model] + self.args.reward_model = reward_model_path # Temporarily set for prepare_single_model + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + model, processor = result + self.reward_model.append(model) + + if self.args.rlhf_type == 'grpo': + reward_template = self.args.get_template(processor, processor.model_meta.template) + if reward_template.use_model: + reward_template.model = model + self.reward_template.append(reward_template) + self.args.reward_model = rms # Restore original value + + if self.args.rlhf_type != 'grpo' and self.reward_model: + assert len(self.reward_model) <= 1 + self.reward_model = self.reward_model[0] + def _prepare_single_model(self, key, origin_key, model_type, model_revision): from swift.llm.infer.utils import prepare_adapter args = self.args @@ -116,10 +165,7 @@ def _prepare_model_tokenizer(self): model_type = model_type[0] if model_type else None model_revision = model_revision[0] if model_revision else None - result = self._prepare_single_model(model_key, key, model_type, model_revision) - if result is not None: - model, _ = result - setattr(self, f'{key}_model', model) + 
getattr(self, f'_prepare_{key}_model')(model_key, key, model_type, model_revision) # Handle reward model(s) self.reward_model = None @@ -130,26 +176,9 @@ def _prepare_model_tokenizer(self): rm_revisions = args.reward_model_revision if args.reward_model_revision else [None] * num_rms assert len(rms) == len(rm_types) == len(rm_revisions) - self.reward_model = [] - if args.rlhf_type == 'grpo': - self.reward_template = [] - - for reward_model_path, rm_type, rm_revision in zip(rms, rm_types, rm_revisions): - args.reward_model = reward_model_path # Temporarily set for prepare_single_model - result = self._prepare_single_model('reward', None, rm_type, rm_revision) - if result is not None: - model, processor = result - self.reward_model.append(model) - - if args.rlhf_type == 'grpo': - reward_template = self.args.get_template(processor, processor.model_meta.template) - if reward_template.use_model: - reward_template.model = model - self.reward_template.append(reward_template) - args.reward_model = rms # Restore original value - if args.rlhf_type != 'grpo' and self.reward_model: - assert len(self.reward_model) <= 1 - self.reward_model = self.reward_model[0] + for rm_idx, (reward_model_path, rm_type, rm_revision) in enumerate(zip(rms, rm_types, rm_revisions)): + _prepare_reward_model = RayHelper.function(group=f'reward_{rm_idx}')(self._prepare_reward_model) + _prepare_reward_model(reward_model_path, 'reward', None, rm_type, rm_revision) super()._prepare_model_tokenizer() diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 4565521411..80e7dba28d 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -46,6 +46,7 @@ from ..llm.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward from .arguments import TrainingArguments from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model +from swift.ray import RayHelper try: from trl import AutoModelForCausalLMWithValueHead @@ -55,6 +56,7 @@ logger = get_logger() +@RayHelper.worker(group=['default']) class SwiftMixin: FLASH_CKPT_WAIT_TIMEOUT = 1800 @@ -74,7 +76,7 @@ def __init__(self, logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.') self.compute_loss_func = None # Compatible with the older version of transformers - if args.check_model and hasattr(model, 'model_dir'): + if model is not None and args.check_model and hasattr(model, 'model_dir'): with ms_logger_context(logging.CRITICAL), self._patch_timeout(): check_local_model_is_latest( model.model_dir, user_agent={ @@ -99,8 +101,6 @@ def _get_mean_metric(): self.template = template self.hub = get_hub() - self.model_meta = model.model_meta - kwargs.update(self.create_loss_and_metric(args)) trainer_parameters = inspect.signature(Trainer.__init__).parameters tokenizer_key = 'processing_class' if 'processing_class' in trainer_parameters else 'tokenizer' @@ -117,21 +117,26 @@ def _get_mean_metric(): optimizers=optimizers, **kwargs) - if get_function(model.__class__.forward) is not get_function(model.forward): - self.label_names = find_labels(model) - self.can_return_loss = can_return_loss(model) + self._prepare_model_info(model) self.label_names = self.label_names or ['labels'] self.start_time = time.time() self._fix_gradient_checkpointing() self._patch_tasks() - update_generation_config_eos_token(self.model.generation_config, self.template) - if getattr(self.model, 'origin_generation_config', None): - self.model.origin_generation_config.eos_token_id = 
self.model.generation_config.eos_token_id if self.args.resume_only_model and self.args.ignore_data_skip: # The weights have already been loaded outside the trainer, # so reading train_state is skipped here. self.args.resume_from_checkpoint = None + @RayHelper.function(group='default') + def _prepare_model_info(self, model): + self.model_meta = model.model_meta + if get_function(model.__class__.forward) is not get_function(model.forward): + self.label_names = find_labels(model) + self.can_return_loss = can_return_loss(model) + update_generation_config_eos_token(self.model.generation_config, self.template) + if getattr(self.model, 'origin_generation_config', None): + self.model.origin_generation_config.eos_token_id = self.model.generation_config.eos_token_id + @contextmanager def _patch_timeout(self): from modelscope.hub.api import HubApi diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index b84a7051eb..f7d61b20d4 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -1,18 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import warnings +from contextlib import nullcontext from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from accelerate.utils import gather_object -from peft import PeftModel from transformers import PreTrainedModel from transformers.utils.versions import require_version from trl import DPOTrainer as HFDPOTrainer from trl.trainer.dpo_config import DPOConfig -from trl.trainer.utils import RunningMoments - +from trl.trainer.utils import RunningMoments, pad_to_length +from torch import autocast from swift.llm import to_device +from swift.ray import RayHelper from swift.utils import get_logger from ..mixin import DataLoaderMixin, SwiftMixin from .rlhf_mixin import RLHFTrainerMixin @@ -27,6 +28,7 @@ def new_gather_function(tensor): return torch.concat(to_device(tensor_list, tensor.device), dim=0) +@RayHelper.worker(group=['default', 'ref']) class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): def __init__(self, @@ -60,7 +62,6 @@ def __init__(self, self.precompute_ref_log_probs = args.precompute_ref_log_probs self.f_divergence_type = args.f_divergence_type self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} - self.is_peft_model = isinstance(model, PeftModel) self.ref_adapter_name = args.ref_adapter_name self.model_adapter_name = None @@ -176,10 +177,64 @@ def concatenated_forward( output['aux_loss'] = outputs.aux_loss return output + @RayHelper.function(group='default') def training_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().training_step(model, inputs, *args, **kwargs) + @RayHelper.function(group='default') def prediction_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().prediction_step(model, inputs, *args, **kwargs) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. 
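A minimal sketch of the conditional-autocast pattern used in the added lines below, assuming only torch and the standard library; `peft_casted_to_bf16` is a stand-in for the trainer's `_peft_has_been_casted_to_bf16` flag:

    from contextlib import nullcontext
    from torch import autocast

    def generation_context(device_type: str, peft_casted_to_bf16: bool):
        # Under peft + bf16 some hidden states are silently upcast to full
        # precision, so generate() must run inside an amp autocast region;
        # otherwise a no-op context suffices.
        return autocast(device_type) if peft_casted_to_bf16 else nullcontext()
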
+ generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + + # if ref_output in batch use that otherwise use the reference model + if "ref_output" in batch: + ref_output = batch["ref_output"] + else: + if self.args.ref_model is None: + with self.null_ref_context(): + ref_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + else: + ref_output = self.generate_from_ref(batch) + + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) + + return policy_output_decoded, ref_output_decoded + + @RayHelper.function(group='ref') + def generate_from_ref(self, batch): + return self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) \ No newline at end of file diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 428799da62..8efc5a2122 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -7,12 +7,16 @@ import torch import torch.nn as nn +from peft import PeftModel from torch.utils.data import DataLoader from transformers import PreTrainedModel from trl.models.utils import prepare_deepspeed from trl.trainer.utils import selective_log_softmax +from swift.ray import RayHelper + +@RayHelper.worker(group=['default', 'ref', 'reward', 'value', 'teacher']) class RLHFTrainerMixin: def __init__(self, @@ -20,38 +24,50 @@ def __init__(self, ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, *_args, **kwargs): - from trl.trainer import disable_dropout_in_model - from swift.llm import HfConfigFactory self.ref_model = ref_model self._stored_metrics = defaultdict(lambda: defaultdict(list)) args = kwargs['args'] self.beta = getattr(args, 'beta', 0.0) - if getattr(args, 'disable_dropout', False): - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - self.is_encoder_decoder = kwargs['template'].is_encoder_decoder self._peft_has_been_casted_to_bf16 = False self.generate_during_eval = getattr(args, 'generate_during_eval', False) - if self.is_encoder_decoder: - self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') - self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + + self._prepare_model(args, model) + self._prepare_ref_model(args, ref_model) + # not use self.is_vision_model = False self.label_pad_token_id = -100 self.use_dpo_data_collator = True super().__init__(model, *_args, **kwargs) - self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0 self.aux_loss_coef = args.router_aux_loss_coef - if ref_model is not None: - if 
self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - self.padding_value = self.tokenizer.pad_token_id + @RayHelper.function(group='default') + def _prepare_model(self, args, model): + from trl.trainer import disable_dropout_in_model + from swift.llm import HfConfigFactory + if getattr(args, 'disable_dropout', False): + if model is not None: + disable_dropout_in_model(model) + + self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0 + self.is_peft_model = isinstance(model, PeftModel) + if self.is_encoder_decoder: + self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') + self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + + @RayHelper.function(group='ref') + def _prepare_ref_model(self, args, ref_model): + from trl.trainer import disable_dropout_in_model + if getattr(args, 'disable_dropout', False): + if ref_model is not None: + disable_dropout_in_model(ref_model) + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + def create_loss_and_metric(self, args): return {} @@ -113,6 +129,7 @@ def _patch_concatenated_forward(): with _patch_concatenated_forward(): return super().concatenated_forward(model, model_kwargs) + @RayHelper.function(group='default') def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): res = super().compute_loss(model, inputs, return_outputs=return_outputs) # compat transformers>=4.46.* From 3d585fd6be4bfeea544360acc9ac15413c74bb65 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 17:05:21 +0800 Subject: [PATCH 02/16] wip --- swift/trainers/rlhf_trainer/dpo_trainer.py | 40 +++++++++++++++++++++- swift/trainers/rlhf_trainer/rlhf_mixin.py | 2 +- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index f7d61b20d4..96281e9065 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
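The two preparation paths for the frozen reference model above reduce to a small helper; a sketch using the trl and accelerate calls already imported in that file, with `trainer` standing in for the mixin instance:

    from trl.models.utils import prepare_deepspeed

    def prepare_frozen_ref(trainer, ref_model):
        # DeepSpeed needs its own engine wrapping; plain accelerate prepares
        # the model in evaluation mode so it is never touched by the
        # optimizer step.
        if trainer.is_deepspeed_enabled:
            return prepare_deepspeed(ref_model, trainer.accelerator)
        return trainer.accelerator.prepare_model(ref_model, evaluation_mode=True)
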
import warnings -from contextlib import nullcontext +from contextlib import nullcontext, contextmanager from typing import Dict, List, Optional, Tuple, Union import torch @@ -187,6 +187,30 @@ def prediction_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().prediction_step(model, inputs, *args, **kwargs) + @RayHelper.function(group='default') + def _compute_log_probs(self, batch): + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), compte_ref_context_manager, self.null_ref_context(): + ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) + return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + + @RayHelper.function(group='ref') + def _compute_ref_log_probs(self, batch): + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), compte_ref_context_manager: + ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) + return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + + def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: + if self.args.ref_model is None: + return self._compute_log_probs(batch) + else: + return self._compute_ref_log_probs(batch) + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" @@ -229,6 +253,20 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) return policy_output_decoded, ref_output_decoded + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + @RayHelper.function(group='ref') def generate_from_ref(self, batch): return self.ref_model.generate( diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 8efc5a2122..1d8077bf77 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -39,9 +39,9 @@ def __init__(self, self.is_vision_model = False self.label_pad_token_id = -100 self.use_dpo_data_collator = True - super().__init__(model, *_args, **kwargs) self.aux_loss_coef = args.router_aux_loss_coef self.padding_value = self.tokenizer.pad_token_id + super().__init__(model, *_args, **kwargs) @RayHelper.function(group='default') def _prepare_model(self, args, model): From 2cf6194236f84ce6589772ececc8615e6a8edc91 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 17:06:18 +0800 Subject: [PATCH 03/16] fix --- swift/cli/rlhf.py | 4 +++- swift/llm/train/rlhf.py | 4 ++++ swift/llm/train/sft.py | 15 ++++++++------- swift/ray/base.py | 3 ++- swift/ray/resource_manager.py | 2 +- swift/trainers/rlhf_trainer/dpo_trainer.py | 2 +- 6 files changed, 19 insertions(+), 11 deletions(-) diff --git a/swift/cli/rlhf.py b/swift/cli/rlhf.py index 4f0fd6a0ab..0419df5c86 100644 --- a/swift/cli/rlhf.py +++ 
b/swift/cli/rlhf.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from swift.llm import rlhf_main if __name__ == '__main__': + from swift.ray import try_init_ray + try_init_ray() + from swift.llm import rlhf_main rlhf_main() diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index e822514966..937a275736 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -197,6 +197,7 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t args.training_args.ref_adapter_name = 'ref_adapter' return model + @RayHelper.function(group='default') def _prepare_template(self) -> None: args = self.args super()._prepare_template() @@ -250,6 +251,9 @@ def _get_trainer_kwargs(self): if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed: trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed return trainer_kwargs + + def run(self): + return super().run() def rlhf_main(args: Optional[Union[List[str], RLHFArguments]] = None): diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 502b895646..dafd7c0056 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -173,18 +173,13 @@ def _post_process_datasets(self, datasets: List) -> List: datasets[i] = dataset self._show_dataset(*datasets) return datasets - + @RayHelper.function(group='default') - def run(self): - args = self.args - train_dataset, val_dataset = self._prepare_dataset() - + def tune_model(self, train_dataset): if args.task_type == 'seq_cls': args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None) logger.info(f'args.problem_type: {args.problem_type}') args.save_args() - - data_collator = self._get_data_collator() # Some tuners require train_dataset and data_collator for preparation: LoRA-GA self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) logger.info(f'model: {self.model}') @@ -192,6 +187,12 @@ def run(self): self.train_msg['model_parameter_info'] = model_parameter_info logger.info(f'model_parameter_info: {model_parameter_info}') + @RayHelper.function(group='default') + def run(self): + args = self.args + train_dataset, val_dataset = self._prepare_dataset() + + data_collator = self._get_data_collator() trainer_cls = TrainerFactory.get_trainer_cls(args) trainer = trainer_cls( model=self.model, diff --git a/swift/ray/base.py b/swift/ray/base.py index 44bfad191a..6a7805a346 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -265,7 +265,8 @@ def _create_workers(group: Union[str, List[str]], *args, **kwargs): _config = config break - assert _config is not None + if _config is None: + continue local_groups = _config['workers'] VISIBLE_ENV_MAPPING = { diff --git a/swift/ray/resource_manager.py b/swift/ray/resource_manager.py index 639d7d0754..02a770d8dd 100644 --- a/swift/ray/resource_manager.py +++ b/swift/ray/resource_manager.py @@ -46,7 +46,7 @@ def __init__(self, groups: Dict[str, Any]): ranks = list(range(last_rank + 1, last_rank + 1 + ranks)) except Exception: # noqa if isinstance(ranks, str): - ranks = ast.literal_eval(ranks) + ranks = eval(ranks) finally: all_ranks.extend(ranks) group['ranks'] = ranks diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index f7d61b20d4..67d3c869eb 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -67,7 +67,7 @@ def __init__(self, self.model_adapter_name = None self.reference_free = 
args.reference_free self.use_weighting = False - + import ray; ray.util.pdb.set_trace() super().__init__(model, ref_model, *_args, **kwargs) if 'bco_pair' in loss_types: From 9156320fcae8108a1cceea8196349a41ddfc0747 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 17:11:01 +0800 Subject: [PATCH 04/16] revert files --- swift/llm/train/sft.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index dafd7c0056..502b895646 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -173,13 +173,18 @@ def _post_process_datasets(self, datasets: List) -> List: datasets[i] = dataset self._show_dataset(*datasets) return datasets - + @RayHelper.function(group='default') - def tune_model(self, train_dataset): + def run(self): + args = self.args + train_dataset, val_dataset = self._prepare_dataset() + if args.task_type == 'seq_cls': args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None) logger.info(f'args.problem_type: {args.problem_type}') args.save_args() + + data_collator = self._get_data_collator() # Some tuners require train_dataset and data_collator for preparation: LoRA-GA self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) logger.info(f'model: {self.model}') @@ -187,12 +192,6 @@ def tune_model(self, train_dataset): self.train_msg['model_parameter_info'] = model_parameter_info logger.info(f'model_parameter_info: {model_parameter_info}') - @RayHelper.function(group='default') - def run(self): - args = self.args - train_dataset, val_dataset = self._prepare_dataset() - - data_collator = self._get_data_collator() trainer_cls = TrainerFactory.get_trainer_cls(args) trainer = trainer_cls( model=self.model, From 3a9e5754213a07a53e7965b5977688c5dff91551 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 17:41:05 +0800 Subject: [PATCH 05/16] wip --- swift/llm/train/rlhf.py | 3 --- swift/trainers/mixin.py | 2 -- swift/trainers/rlhf_trainer/dpo_trainer.py | 1 - swift/trainers/rlhf_trainer/rlhf_mixin.py | 14 ++++++-------- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 937a275736..ab68cd45ef 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -251,9 +251,6 @@ def _get_trainer_kwargs(self): if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed: trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed return trainer_kwargs - - def run(self): - return super().run() def rlhf_main(args: Optional[Union[List[str], RLHFArguments]] = None): diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 80e7dba28d..2dd168d039 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -56,7 +56,6 @@ logger = get_logger() -@RayHelper.worker(group=['default']) class SwiftMixin: FLASH_CKPT_WAIT_TIMEOUT = 1800 @@ -127,7 +126,6 @@ def _get_mean_metric(): # so reading train_state is skipped here. 
self.args.resume_from_checkpoint = None - @RayHelper.function(group='default') def _prepare_model_info(self, model): self.model_meta = model.model_meta if get_function(model.__class__.forward) is not get_function(model.forward): diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 14a2ff74c0..e73297f32a 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -28,7 +28,6 @@ def new_gather_function(tensor): return torch.concat(to_device(tensor_list, tensor.device), dim=0) -@RayHelper.worker(group=['default', 'ref']) class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): def __init__(self, diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 1d8077bf77..8468f94309 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -16,7 +16,6 @@ from swift.ray import RayHelper -@RayHelper.worker(group=['default', 'ref', 'reward', 'value', 'teacher']) class RLHFTrainerMixin: def __init__(self, @@ -60,13 +59,13 @@ def _prepare_model(self, args, model): @RayHelper.function(group='ref') def _prepare_ref_model(self, args, ref_model): from trl.trainer import disable_dropout_in_model - if getattr(args, 'disable_dropout', False): - if ref_model is not None: + if ref_model is not None: + if getattr(args, 'disable_dropout', False): disable_dropout_in_model(ref_model) - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(ref_model, evaluation_mode=True) def create_loss_and_metric(self, args): return {} @@ -129,7 +128,6 @@ def _patch_concatenated_forward(): with _patch_concatenated_forward(): return super().concatenated_forward(model, model_kwargs) - @RayHelper.function(group='default') def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): res = super().compute_loss(model, inputs, return_outputs=return_outputs) # compat transformers>=4.46.* From c4439ed85ddf11cfd5c4f1f0bd6dd74d27ae08f8 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 20:18:03 +0800 Subject: [PATCH 06/16] wip --- swift/ray/base.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/swift/ray/base.py b/swift/ray/base.py index 6a7805a346..140841279a 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -35,6 +35,43 @@ class RayHelper: device_groups: Dict[str, Any] = None + _registry = None + + @staticmethod + def _init_registry(): + if RayHelper._registry is not None: + return + + import ray + + @ray.remote + class WorkerRegistry: + def __init__(self): + self.workers = {} + + def register_workers(self, group: str, worker_handles: List): + self.workers[group] = worker_handles + + def get_workers(self, group: Optional[str] = None): + if group: + return self.workers.get(group, []) + return self.workers + + def get_all_workers(self): + return self.workers + + def clear(self): + self.workers.clear() + + try: + RayHelper._registry = ray.get_actor("swift_worker_registry") + except ValueError: + RayHelper._registry = WorkerRegistry.options( + name="swift_worker_registry", + lifetime="detached", + namespace="default" + 
).remote() + @staticmethod def initialize(device_groups: Dict[str, Any]): """Initialize RayHelper. @@ -53,6 +90,7 @@ def initialize(device_groups: Dict[str, Any]): if RayHelper.resource_manager is None: # Resource manager initialize only once in the pipeline process. RayHelper.resource_manager = ResourceManager(device_groups) + RayHelper._init_registry() @staticmethod def teardown(): @@ -60,6 +98,15 @@ def teardown(): RayHelper.resource_manager.destroy_placement_group() RayHelper.resource_manager = None + if RayHelper._registry is not None: + import ray + try: + ray.get(RayHelper._registry.clear.remote()) + ray.kill(RayHelper._registry) + except: # noqa + pass + RayHelper._registry = None + @staticmethod def is_called_from_init(): """If some function called from __init__. @@ -199,7 +246,12 @@ def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs @staticmethod def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): - workers = RayHelper.worker_instance[group] + import ray + if group in RayHelper.worker_instance: + workers = RayHelper.worker_instance[group] + else: + workers = ray.get(RayHelper._registry.get_workers.remote(group)) + length = len(workers) if execute == 'first': return getattr(workers[0], method_name).remote(*args, **kwargs) @@ -369,3 +421,4 @@ def get_node_address(): for g in local_groups: RayHelper.worker_instance[g] = workers + ray.get(RayHelper._registry.register_workers.remote(g, workers)) \ No newline at end of file From 53f03272338d2f54119db5ab0ecd403b2fc201aa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 30 Oct 2025 23:05:30 +0800 Subject: [PATCH 07/16] wip --- swift/llm/train/rlhf.py | 36 +++++++++++++++++++++- swift/llm/train/sft.py | 14 +++++++-- swift/trainers/rlhf_trainer/dpo_trainer.py | 1 - 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index ab68cd45ef..4d639881ab 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -13,6 +13,7 @@ from ..model import HfConfigFactory from .kto import prepare_kto_dataset from .sft import SwiftSft +from ...trainers import TrainerFactory logger = get_logger() @@ -27,6 +28,7 @@ def __init__(self, args: RLHFArguments): self.reward_model = [] if self.args.rlhf_type == 'grpo': self.reward_template = [] + self._prepare_trainer() @staticmethod def _get_model_task_type(model_dir): @@ -197,7 +199,6 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t args.training_args.ref_adapter_name = 'ref_adapter' return model - @RayHelper.function(group='default') def _prepare_template(self) -> None: args = self.args super()._prepare_template() @@ -252,6 +253,39 @@ def _get_trainer_kwargs(self): trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed return trainer_kwargs + @RayHelper.function(group='default') + def _prepare_model(self, train_dataset): + # Some tuners require train_dataset and data_collator for preparation: LoRA-GA + self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) + logger.info(f'model: {self.model}') + model_parameter_info = get_model_parameter_info(self.model) + self.train_msg['model_parameter_info'] = model_parameter_info + logger.info(f'model_parameter_info: {model_parameter_info}') + + def _prepare_trainer(self): + args = self.args + train_dataset, val_dataset = self._prepare_dataset() + args.save_args() + + data_collator = self._get_data_collator() + 
self._prepare_model(train_dataset) + + trainer_cls = TrainerFactory.get_trainer_cls(args) + self.trainer = trainer_cls( + model=self.model, + args=self.args.training_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=val_dataset, + callbacks=self.callbacks, + template=self.template, + **self._get_trainer_kwargs(), + ) + + @RayHelper.function(group='default') + def run(self): + return self.train(self.trainer) + def rlhf_main(args: Optional[Union[List[str], RLHFArguments]] = None): return SwiftRLHF(args).main() diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 502b895646..5afac4e44e 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -28,8 +28,10 @@ class SwiftSft(SwiftPipeline, TunerMixin): def __init__(self, args: Optional[Union[List[str], TrainArguments]] = None) -> None: super().__init__(args) self.train_msg = {} + self._prepare_processor() self._prepare_model_tokenizer() self._prepare_template() + self._patch_model_to_template() self._prepare_callbacks() self._prepare_flash_ckpt() @@ -48,6 +50,10 @@ def _prepare_generation_config(self): args.get_request_config(), self.tokenizer) logger.info(f'model.generation_config: {self.model.generation_config}') + def _prepare_processor(self, **kwargs): + args = self.args + _, self.processor = args.get_model_processor(load_model=False, **kwargs) + @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): args = self.args @@ -65,17 +71,19 @@ def _prepare_model_tokenizer(self, **kwargs): self._prepare_generation_config() - @RayHelper.function(group='default') def _prepare_template(self) -> None: args = self.args template = args.get_template(self.processor) template.set_mode('train') - if template.use_model: - template.model = self.model if args.model_meta.is_multimodal and (args.padding_free or args.packing) and not template.support_padding_free: raise ValueError(f'Template `{args.template}` does not support padding free or packing.') self.template = template + @RayHelper.function(group='default') + def _patch_model_to_template(self): + if self.template.use_model: + self.template.model = self.model + def _get_dataset(self): # The random shuffling of the training set occurs in the dataloader of the trainer. 
args = self.args diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index e73297f32a..949cc978a4 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -66,7 +66,6 @@ def __init__(self, self.model_adapter_name = None self.reference_free = args.reference_free self.use_weighting = False - import ray; ray.util.pdb.set_trace() super().__init__(model, ref_model, *_args, **kwargs) if 'bco_pair' in loss_types: From c303e1f03939194f9c403df50e12dd2388231688 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 31 Oct 2025 21:39:40 +0800 Subject: [PATCH 08/16] fix --- swift/llm/train/sft.py | 4 +-- swift/ray/base.py | 30 ++++++++++++++++++++-- swift/trainers/mixin.py | 2 ++ swift/trainers/rlhf_trainer/dpo_trainer.py | 14 ---------- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 5afac4e44e..363d21481c 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -35,7 +35,6 @@ def __init__(self, args: Optional[Union[List[str], TrainArguments]] = None) -> N self._prepare_callbacks() self._prepare_flash_ckpt() - @RayHelper.function(group='default') def _prepare_flash_ckpt(self): if self.args.use_flash_ckpt: try: @@ -57,7 +56,7 @@ def _prepare_processor(self, **kwargs): @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): args = self.args - self.model, self.processor = args.get_model_processor(**kwargs) + self.model, _ = args.get_model_processor(**kwargs) if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel sequence_parallel.prepare( @@ -127,7 +126,6 @@ def _get_cached_dataset(self): val_datasets.append(load_from_disk(val_path)) return train_datasets, val_datasets - @RayHelper.function(group='default') def _prepare_dataset(self): args = self.args # Defer encoding to the training phase diff --git a/swift/ray/base.py b/swift/ray/base.py index 140841279a..1d4e23e83f 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -190,7 +190,7 @@ def collect_func(method: Union[Literal['none', 'flatten'], Callable], result): @staticmethod def function(group: str, dispatch: Union[Literal['slice', 'all'], Callable] = 'all', - execute: Literal['first', 'all'] = 'all', + execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten'], Callable] = 'none'): """Remote execution function. 
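The peer dispatch added in the next hunk assigns each trainer rank a contiguous slice of the target group's workers; the same arithmetic as `RayHelper.get_peer_index`, extracted into a self-contained sketch:

    import os

    def peer_slice(target_size: int) -> slice:
        # Split target_size workers as evenly as possible across the
        # WORLD_SIZE caller ranks.
        rank = int(os.environ.get('RANK', '0'))
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        k, m = divmod(target_size, world_size)
        start = rank * k + min(rank, m)
        end = (rank + 1) * k + min(rank + 1, m)
        if target_size < world_size:
            # Fewer workers than ranks: wrap around, one worker per rank.
            start = rank % target_size
            end = start + 1
        return slice(start, end)

    # Example: 4 'ref' workers shared by 2 trainer ranks gives
    # RANK=0 -> slice(0, 2) and RANK=1 -> slice(2, 4).
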
@@ -245,13 +245,39 @@ def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs return ray.get(RayHelper.execute_all_async(group, dispatch, execute, method_name, *args, **kwargs)) @staticmethod - def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): + def get_workers(group, dispatch): import ray if group in RayHelper.worker_instance: workers = RayHelper.worker_instance[group] else: workers = ray.get(RayHelper._registry.get_workers.remote(group)) + if dispatch == 'first': + return [workers[0]] + elif dispatch == 'all': + return workers + elif dispatch == 'peer': + return workers[RayHelper.get_peer_index(len(workers))] + else: + raise ValueError(f'Unsupported dispatch method: {dispatch}') + + @staticmethod + def get_peer_index(target_size): + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + k, m = divmod(target_size, world_size) + start_idx = rank * k + min(rank, m) + end_idx = (rank + 1) * k + min(rank + 1, m) + if target_size < world_size: + start_idx = rank % target_size + end_idx = start_idx + 1 + + return slice(start_idx, end_idx) + + @staticmethod + def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): + import ray + workers = RayHelper.get_workers(group, dispatch) length = len(workers) if execute == 'first': return getattr(workers[0], method_name).remote(*args, **kwargs) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 2dd168d039..2e4f303045 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -126,6 +126,7 @@ def _get_mean_metric(): # so reading train_state is skipped here. self.args.resume_from_checkpoint = None + @RayHelper.function(group='default') def _prepare_model_info(self, model): self.model_meta = model.model_meta if get_function(model.__class__.forward) is not get_function(model.forward): @@ -614,6 +615,7 @@ def clip_grad_norm_(self, parameters, *args, **kwargs): finally: Accelerator.clip_grad_norm_ = origin_clip_grad_norm_ + @RayHelper.function(group='default') def _patch_tasks(self): if isinstance(self.model, PeftModel): model = self.model.model diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 949cc978a4..3fd84056bd 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -251,20 +251,6 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) return policy_output_decoded, ref_output_decoded - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.set_adapter(self.model_adapter_name or "default") - @RayHelper.function(group='ref') def generate_from_ref(self, batch): return self.ref_model.generate( From c9e1cb60670617975bec48f33f41b3d4681aa5fc Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 31 Oct 2025 23:54:22 +0800 Subject: [PATCH 09/16] fix --- swift/llm/train/rlhf.py | 6 ++ swift/ray/base.py | 68 ++++++++++++++++------ swift/trainers/mixin.py | 13 +++-- swift/trainers/rlhf_trainer/dpo_trainer.py | 20 ++++--- swift/trainers/rlhf_trainer/rlhf_mixin.py | 5 +- 5 files changed, 78 insertions(+), 34 deletions(-) 
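Among other fixes, this patch makes the registry bootstrap race-safe: two processes can both miss `ray.get_actor` and then both attempt to create the named actor. The get-or-create shape, as a standalone sketch that assumes ray is already initialized:

    import ray

    def get_or_create_registry(name, actor_cls):
        try:
            return ray.get_actor(name)
        except ValueError:
            try:
                # 'detached' keeps the actor alive independently of the job
                # that created it.
                return actor_cls.options(name=name, lifetime='detached').remote()
            except ValueError:
                # Lost the creation race; fetch the winner's actor instead.
                return ray.get_actor(name)
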
diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 4d639881ab..5fa36ca1f7 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -24,6 +24,8 @@ class SwiftRLHF(SwiftSft): args: args_class def __init__(self, args: RLHFArguments): + self.model = None + self.callbacks = [] super().__init__(args) self.reward_model = [] if self.args.rlhf_type == 'grpo': @@ -271,6 +273,7 @@ def _prepare_trainer(self): self._prepare_model(train_dataset) trainer_cls = TrainerFactory.get_trainer_cls(args) + self.args.training_args.ref_model = self.args.ref_model self.trainer = trainer_cls( model=self.model, args=self.args.training_args, @@ -281,6 +284,9 @@ def _prepare_trainer(self): template=self.template, **self._get_trainer_kwargs(), ) + + def call_trainer(self, func, *args, **kwargs): + return getattr(self.trainer, func)(*args, **kwargs) @RayHelper.function(group='default') def run(self): diff --git a/swift/ray/base.py b/swift/ray/base.py index 1d4e23e83f..e47f922ec9 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -2,6 +2,7 @@ import argparse import functools import inspect +from contextlib import contextmanager import os from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union @@ -38,7 +39,7 @@ class RayHelper: _registry = None @staticmethod - def _init_registry(): + def init_registry(): if RayHelper._registry is not None: return @@ -66,11 +67,14 @@ def clear(self): try: RayHelper._registry = ray.get_actor("swift_worker_registry") except ValueError: - RayHelper._registry = WorkerRegistry.options( - name="swift_worker_registry", - lifetime="detached", - namespace="default" - ).remote() + try: + RayHelper._registry = WorkerRegistry.options( + name="swift_worker_registry", + lifetime="detached", + ).remote() + except ValueError: + RayHelper._registry = ray.get_actor("swift_worker_registry") + assert RayHelper._registry is not None @staticmethod def initialize(device_groups: Dict[str, Any]): @@ -90,7 +94,7 @@ def initialize(device_groups: Dict[str, Any]): if RayHelper.resource_manager is None: # Resource manager initialize only once in the pipeline process. RayHelper.resource_manager = ResourceManager(device_groups) - RayHelper._init_registry() + RayHelper.init_registry() @staticmethod def teardown(): @@ -106,7 +110,26 @@ def teardown(): except: # noqa pass RayHelper._registry = None + + @contextmanager + def patch_init(): + if not RayHelper.is_default(): + from transformers import Trainer + init_method = Trainer.__init__ + @functools.wraps(init_method) + def new_init(self, *args, **kwargs): + from accelerate import Accelerator + self.processing_class = kwargs['processing_class'] + self.args = kwargs['args'] + self.accelerator = Accelerator(kwargs['args']) + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + Trainer.__init__ = new_init + yield + if not RayHelper.is_default(): + Trainer.__init__ = init_method + @staticmethod def is_called_from_init(): """If some function called from __init__. 
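`patch_init` above follows the usual temporary-monkeypatch shape: swap `Trainer.__init__` for a stub on non-default workers, yield, then restore it. Reduced to a generic skeleton, with `cls`, `attr`, and `stub` as placeholder names and the one behavioral caveat noted inline:

    import functools
    from contextlib import contextmanager

    @contextmanager
    def swapped_method(cls, attr, stub):
        original = getattr(cls, attr)
        setattr(cls, attr, functools.wraps(original)(stub))
        try:
            yield
        finally:
            # try/finally restores even when the body raises; patch_init
            # above restores only on the success path.
            setattr(cls, attr, original)
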
@@ -130,6 +153,10 @@ def ray_inited(): # not installed, not inited return False return ray.is_initialized() + + @staticmethod + def is_default(): + return 'default' in os.environ.get('RAY_SWIFT_GROUP', '').split(',') @staticmethod def is_worker(): @@ -215,6 +242,7 @@ def decorator(func: Callable[..., T]) -> Callable[..., T]: def wrapper(self, *args, **kwargs) -> T: if not RayHelper.ray_inited(): return func(self, *args, **kwargs) + RayHelper.init_registry() if RayHelper.is_worker(): if not hasattr(self, 'group'): # pass through env @@ -224,8 +252,8 @@ def wrapper(self, *args, **kwargs) -> T: # Functions in init of different group, do nothing return None else: - # Should not happen - raise ValueError() + result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, **kwargs) + return RayHelper.collect_func(collect, result) else: return func(self, *args, **kwargs) else: @@ -245,20 +273,20 @@ def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs return ray.get(RayHelper.execute_all_async(group, dispatch, execute, method_name, *args, **kwargs)) @staticmethod - def get_workers(group, dispatch): + def get_workers(group, execute): import ray if group in RayHelper.worker_instance: workers = RayHelper.worker_instance[group] else: workers = ray.get(RayHelper._registry.get_workers.remote(group)) - if dispatch == 'first': + if execute == 'first': return [workers[0]] - elif dispatch == 'all': + elif execute == 'all': return workers - elif dispatch == 'peer': + elif execute == 'peer': return workers[RayHelper.get_peer_index(len(workers))] else: - raise ValueError(f'Unsupported dispatch method: {dispatch}') + raise ValueError(f'Unsupported execute method: {execute}') @staticmethod def get_peer_index(target_size): @@ -277,7 +305,7 @@ def get_peer_index(target_size): @staticmethod def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): import ray - workers = RayHelper.get_workers(group, dispatch) + workers = RayHelper.get_workers(group, execute) length = len(workers) if execute == 'first': return getattr(workers[0], method_name).remote(*args, **kwargs) @@ -300,8 +328,14 @@ def dispatch_func(arg, n): sliced_kwargs = {k: v[i] for k, v in kwargs.items()} if (sliced_args and sliced_args[0]) or (kwargs and list(kwargs.values())): # skip empty input - remote_call = getattr(workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + # import ray; ray.util.pdb.set_trace() + if hasattr(workers[i], method_name): + remote_call = getattr(workers[i], method_name) + result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + else: + remote_call = getattr(workers[i], 'call_trainer') + result.append(remote_call.remote(method_name, *sliced_args, **sliced_kwargs)) + return result elif isinstance(dispatch, Callable): # dispatch is Callable diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 2e4f303045..c8a45732bf 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -104,7 +104,7 @@ def _get_mean_metric(): trainer_parameters = inspect.signature(Trainer.__init__).parameters tokenizer_key = 'processing_class' if 'processing_class' in trainer_parameters else 'tokenizer' kwargs[tokenizer_key] = template.tokenizer - with self.hub.patch_hub(): + with self.hub.patch_hub(), RayHelper.patch_init(): super().__init__( model=model, args=args, @@ -117,7 +117,8 @@ def _get_mean_metric(): **kwargs) self._prepare_model_info(model) - self.label_names = self.label_names or ['labels'] + if not 
getattr(self, 'label_names', []): + self.label_names = ['labels'] self.start_time = time.time() self._fix_gradient_checkpointing() self._patch_tasks() @@ -182,11 +183,13 @@ def deepspeed_load_checkpoint(*args, **kwargs): finally: trainer.deepspeed_load_checkpoint = origin_deepspeed_load_checkpoint - def get_use_logits_to_keep(self, default_value: bool = True): + def get_use_logits_to_keep(self, default_value: bool = True, model = None): use_logits_to_keep = self.args.use_logits_to_keep + if model is None: + model = self.model if use_logits_to_keep is None: - base_model = self.template.get_base_model(self.model) - use_logits_to_keep = (not self.model.model_meta.is_multimodal + base_model = self.template.get_base_model(model) + use_logits_to_keep = (not model.model_meta.is_multimodal and 'logits_to_keep' in inspect.signature(base_model.forward).parameters and default_value) logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}') diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 3fd84056bd..5a7fc6478c 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -83,10 +83,11 @@ def concatenated_forward( ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch = batch.copy() - use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1) + use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1, model) if use_logits_to_keep: self.prepare_logits_to_keep(batch) - if self.aux_loss_enabled: + aux_loss_enabled = model.model_info.is_moe_model and self.args.router_aux_loss_coef > 0 + if aux_loss_enabled: batch['output_router_logits'] = True labels = batch.pop('labels', None) if self.is_encoder_decoder: @@ -171,7 +172,8 @@ def concatenated_forward( output['rejected_logps'] = all_logps[num_examples:] output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean() output['mean_rejected_logits'] = mean_all_logits[num_examples:][loss_mask[num_examples:]].mean() - if self.aux_loss_enabled: + aux_loss_enabled = model.model_info.is_moe_model and self.args.router_aux_loss_coef > 0 + if aux_loss_enabled: output['aux_loss'] = outputs.aux_loss return output @@ -186,7 +188,7 @@ def prediction_step(self, model, inputs, *args, **kwargs): return super().prediction_step(model, inputs, *args, **kwargs) @RayHelper.function(group='default') - def _compute_log_probs(self, batch): + def compute_log_probs_model(self, batch): compte_ref_context_manager = ( autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() ) @@ -194,8 +196,8 @@ def _compute_log_probs(self, batch): ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] - @RayHelper.function(group='ref') - def _compute_ref_log_probs(self, batch): + @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') + def compute_log_probs_ref_model(self, batch): compte_ref_context_manager = ( autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() ) @@ -205,9 +207,9 @@ def _compute_ref_log_probs(self, batch): def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: if self.args.ref_model is None: - return self._compute_log_probs(batch) + return self.compute_log_probs_model(batch) else: - return 
self._compute_ref_log_probs(batch) + return self.compute_log_probs_ref_model(batch) def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" @@ -251,7 +253,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) return policy_output_decoded, ref_output_decoded - @RayHelper.function(group='ref') + @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') def generate_from_ref(self, batch): return self.ref_model.generate( input_ids=batch["prompt_input_ids"], diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 8468f94309..7f22db412d 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -32,15 +32,14 @@ def __init__(self, self.generate_during_eval = getattr(args, 'generate_during_eval', False) self._prepare_model(args, model) - self._prepare_ref_model(args, ref_model) - # not use self.is_vision_model = False self.label_pad_token_id = -100 self.use_dpo_data_collator = True self.aux_loss_coef = args.router_aux_loss_coef - self.padding_value = self.tokenizer.pad_token_id super().__init__(model, *_args, **kwargs) + self._prepare_ref_model(args, ref_model) + self.padding_value = self.tokenizer.pad_token_id @RayHelper.function(group='default') def _prepare_model(self, args, model): From 01b4f66a434361d0053ff169b70ce88297f5b23c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 10:52:22 +0800 Subject: [PATCH 10/16] lint --- swift/llm/train/rlhf.py | 9 ++-- swift/llm/train/sft.py | 10 +++-- swift/ray/__init__.py | 2 +- swift/ray/base.py | 44 +++++++++++-------- swift/trainers/mixin.py | 4 +- swift/trainers/rlhf_trainer/dpo_trainer.py | 51 +++++++++------------- 6 files changed, 58 insertions(+), 62 deletions(-) diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 5fa36ca1f7..9afaeb3727 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -6,6 +6,7 @@ from swift.llm import safe_snapshot_download from swift.plugin import Tuner, extra_tuners from swift.ray import RayHelper +from swift.trainers import TrainerFactory from swift.tuners import Swift from swift.utils import get_logger, get_model_parameter_info from swift.utils.utils import disable_deepspeed_zero3 @@ -13,7 +14,6 @@ from ..model import HfConfigFactory from .kto import prepare_kto_dataset from .sft import SwiftSft -from ...trainers import TrainerFactory logger = get_logger() @@ -256,7 +256,7 @@ def _get_trainer_kwargs(self): return trainer_kwargs @RayHelper.function(group='default') - def _prepare_model(self, train_dataset): + def _add_adapter_to_model(self, train_dataset): # Some tuners require train_dataset and data_collator for preparation: LoRA-GA self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) logger.info(f'model: {self.model}') @@ -270,7 +270,7 @@ def _prepare_trainer(self): args.save_args() data_collator = self._get_data_collator() - self._prepare_model(train_dataset) + self._add_adapter_to_model(train_dataset) trainer_cls = TrainerFactory.get_trainer_cls(args) self.args.training_args.ref_model = self.args.ref_model @@ -284,9 +284,6 @@ def _prepare_trainer(self): template=self.template, **self._get_trainer_kwargs(), ) - - def call_trainer(self, func, *args, **kwargs): - return getattr(self.trainer, func)(*args, **kwargs) 
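`compute_ref_log_probs` in the dpo trainer above now routes by configuration: with no separate reference model it reuses the policy weights under `null_ref_context` on the 'default' group, otherwise it calls into the 'ref' group. A condensed sketch of that routing, with method names as in the patch:

    def compute_ref_log_probs(trainer, batch):
        # No standalone reference model: the policy model with adapters
        # disabled serves as the reference ('default' group).
        if trainer.args.ref_model is None:
            return trainer.compute_log_probs_model(batch)
        # Dedicated reference model: executed on the 'ref' worker group,
        # sliced per rank and flattened back (execute='peer').
        return trainer.compute_log_probs_ref_model(batch)
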
@RayHelper.function(group='default') def run(self): diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 363d21481c..29baddb466 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -8,7 +8,7 @@ from swift.llm.dataset.loader import DatasetLoader from swift.plugin import extra_callbacks -from swift.ray import RayHelper +from swift.ray import RayHelper, RayMixin from swift.trainers import TrainerFactory from swift.utils import append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array from ..argument import TrainArguments @@ -21,7 +21,7 @@ @RayHelper.worker(group=['default']) -class SwiftSft(SwiftPipeline, TunerMixin): +class SwiftSft(SwiftPipeline, TunerMixin, RayMixin): args_class = TrainArguments args: args_class @@ -50,11 +50,12 @@ def _prepare_generation_config(self): logger.info(f'model.generation_config: {self.model.generation_config}') def _prepare_processor(self, **kwargs): - args = self.args - _, self.processor = args.get_model_processor(load_model=False, **kwargs) + """Prepare processor only.""" + _, self.processor = self.args.get_model_processor(load_model=False, **kwargs) @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): + """Prepare model and tokenizer.""" args = self.args self.model, _ = args.get_model_processor(**kwargs) if args.sequence_parallel_size > 1: @@ -80,6 +81,7 @@ def _prepare_template(self) -> None: @RayHelper.function(group='default') def _patch_model_to_template(self): + """Some template need model to do preprocess, especially for mllm models.""" if self.template.use_model: self.template.model = self.model diff --git a/swift/ray/__init__.py b/swift/ray/__init__.py index 61aad5097f..e978f3e3e2 100644 --- a/swift/ray/__init__.py +++ b/swift/ray/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .base import RayHelper +from .base import RayHelper, RayMixin def try_init_ray(): diff --git a/swift/ray/base.py b/swift/ray/base.py index e47f922ec9..ec4093b9b4 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -2,8 +2,8 @@ import argparse import functools import inspect -from contextlib import contextmanager import os +from contextlib import contextmanager from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union import json @@ -23,6 +23,12 @@ def get_args(): return json.dumps(unknown) +class RayMixin: + + def call_inner_function(self, func, *args, **kwargs): + return getattr(self.trainer, func)(*args, **kwargs) + + class RayHelper: resource_manager: Optional[ResourceManager] = None @@ -47,6 +53,7 @@ def init_registry(): @ray.remote class WorkerRegistry: + def __init__(self): self.workers = {} @@ -58,22 +65,19 @@ def get_workers(self, group: Optional[str] = None): return self.workers.get(group, []) return self.workers - def get_all_workers(self): - return self.workers - def clear(self): self.workers.clear() try: - RayHelper._registry = ray.get_actor("swift_worker_registry") + RayHelper._registry = ray.get_actor('swift_worker_registry') except ValueError: try: RayHelper._registry = WorkerRegistry.options( - name="swift_worker_registry", - lifetime="detached", + name='swift_worker_registry', + lifetime='detached', ).remote() except ValueError: - RayHelper._registry = ray.get_actor("swift_worker_registry") + RayHelper._registry = ray.get_actor('swift_worker_registry') assert RayHelper._registry is not None @staticmethod @@ -107,10 +111,11 @@ def teardown(): try: ray.get(RayHelper._registry.clear.remote()) ray.kill(RayHelper._registry) - except: # noqa + except: # noqa pass RayHelper._registry = None - + + @staticmethod @contextmanager def patch_init(): if not RayHelper.is_default(): @@ -123,13 +128,14 @@ def new_init(self, *args, **kwargs): self.processing_class = kwargs['processing_class'] self.args = kwargs['args'] self.accelerator = Accelerator(kwargs['args']) - self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + self.is_deepspeed_enabled = getattr(self.accelerator.state, 'deepspeed_plugin', None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, 'fsdp_plugin', None) is not None + Trainer.__init__ = new_init yield if not RayHelper.is_default(): Trainer.__init__ = init_method - + @staticmethod def is_called_from_init(): """If some function called from __init__. 
@@ -153,7 +159,7 @@ def ray_inited(): # not installed, not inited return False return ray.is_initialized() - + @staticmethod def is_default(): return 'default' in os.environ.get('RAY_SWIFT_GROUP', '').split(',') @@ -252,7 +258,8 @@ def wrapper(self, *args, **kwargs) -> T: # Functions in init of different group, do nothing return None else: - result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, **kwargs) + result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, + **kwargs) return RayHelper.collect_func(collect, result) else: return func(self, *args, **kwargs) @@ -328,14 +335,13 @@ def dispatch_func(arg, n): sliced_kwargs = {k: v[i] for k, v in kwargs.items()} if (sliced_args and sliced_args[0]) or (kwargs and list(kwargs.values())): # skip empty input - # import ray; ray.util.pdb.set_trace() if hasattr(workers[i], method_name): remote_call = getattr(workers[i], method_name) result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) else: - remote_call = getattr(workers[i], 'call_trainer') + remote_call = getattr(workers[i], 'call_inner_function') result.append(remote_call.remote(method_name, *sliced_args, **sliced_kwargs)) - + return result elif isinstance(dispatch, Callable): # dispatch is Callable @@ -481,4 +487,4 @@ def get_node_address(): for g in local_groups: RayHelper.worker_instance[g] = workers - ray.get(RayHelper._registry.register_workers.remote(g, workers)) \ No newline at end of file + ray.get(RayHelper._registry.register_workers.remote(g, workers)) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index c8a45732bf..50d50188a7 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -41,12 +41,12 @@ from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template, get_llm_model from swift.llm.utils import update_generation_config_eos_token from swift.plugin import MeanMetric, compute_acc, extra_tuners, get_loss_func, get_metric +from swift.ray import RayHelper from swift.tuners import SwiftModel from swift.utils import get_current_device, get_logger, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker from ..llm.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward from .arguments import TrainingArguments from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model -from swift.ray import RayHelper try: from trl import AutoModelForCausalLMWithValueHead @@ -183,7 +183,7 @@ def deepspeed_load_checkpoint(*args, **kwargs): finally: trainer.deepspeed_load_checkpoint = origin_deepspeed_load_checkpoint - def get_use_logits_to_keep(self, default_value: bool = True, model = None): + def get_use_logits_to_keep(self, default_value: bool = True, model=None): use_logits_to_keep = self.args.use_logits_to_keep if model is None: model = self.model diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 5a7fc6478c..cb0282da63 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -1,17 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import warnings -from contextlib import nullcontext, contextmanager +from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from accelerate.utils import gather_object +from torch import autocast from transformers import PreTrainedModel from transformers.utils.versions import require_version from trl import DPOTrainer as HFDPOTrainer from trl.trainer.dpo_config import DPOConfig from trl.trainer.utils import RunningMoments, pad_to_length -from torch import autocast + from swift.llm import to_device from swift.ray import RayHelper from swift.utils import get_logger @@ -172,72 +173,62 @@ def concatenated_forward( output['rejected_logps'] = all_logps[num_examples:] output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean() output['mean_rejected_logits'] = mean_all_logits[num_examples:][loss_mask[num_examples:]].mean() - aux_loss_enabled = model.model_info.is_moe_model and self.args.router_aux_loss_coef > 0 if aux_loss_enabled: output['aux_loss'] = outputs.aux_loss return output - @RayHelper.function(group='default') def training_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().training_step(model, inputs, *args, **kwargs) - @RayHelper.function(group='default') def prediction_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().prediction_step(model, inputs, *args, **kwargs) @RayHelper.function(group='default') - def compute_log_probs_model(self, batch): + def _compute_log_probs(self, batch): compte_ref_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) with torch.no_grad(), compte_ref_context_manager, self.null_ref_context(): ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) - return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + return ref_model_output['chosen_logps'], ref_model_output['rejected_logps'] @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') - def compute_log_probs_ref_model(self, batch): + def _compute_ref_log_probs_ref(self, batch): compte_ref_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) with torch.no_grad(), compte_ref_context_manager: ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) - return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + return ref_model_output['chosen_logps'], ref_model_output['rejected_logps'] def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: if self.args.ref_model is None: - return self.compute_log_probs_model(batch) + return self._compute_log_probs(batch) else: - return self.compute_log_probs_ref_model(batch) + return self._compute_ref_log_probs_ref(batch) def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden 
states are silently casted to full precision. generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) with generate_context_manager: policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], max_length=self.max_length, do_sample=True, pad_token_id=self.padding_value, ) # if ref_output in batch use that otherwise use the reference model - if "ref_output" in batch: - ref_output = batch["ref_output"] + if 'ref_output' in batch: + ref_output = batch['ref_output'] else: if self.args.ref_model is None: with self.null_ref_context(): ref_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], max_length=self.max_length, do_sample=True, pad_token_id=self.padding_value, @@ -256,9 +247,9 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') def generate_from_ref(self, batch): return self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], max_length=self.max_length, do_sample=True, pad_token_id=self.padding_value, - ) \ No newline at end of file + ) From 789b3c3f8ed7ba55b0fc2c0bd0bbef4e06c7bb73 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 13:52:25 +0800 Subject: [PATCH 11/16] fix --- swift/ray/base.py | 27 ++++++++++++---------- swift/trainers/mixin.py | 3 +++ swift/trainers/rlhf_trainer/dpo_trainer.py | 23 +++++++++++++----- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/swift/ray/base.py b/swift/ray/base.py index ec4093b9b4..fbca1f84ea 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -130,6 +130,8 @@ def new_init(self, *args, **kwargs): self.accelerator = Accelerator(kwargs['args']) self.is_deepspeed_enabled = getattr(self.accelerator.state, 'deepspeed_plugin', None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, 'fsdp_plugin', None) is not None + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() Trainer.__init__ = new_init yield @@ -311,13 +313,21 @@ def get_peer_index(target_size): @staticmethod def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): - import ray workers = RayHelper.get_workers(group, execute) length = len(workers) + + def remote_func(worker, *args, **kwargs): + if hasattr(worker, method_name): + remote_call = getattr(worker, method_name) + return remote_call.remote(*args, **kwargs) + else: + remote_call = getattr(worker, 'call_inner_function') + return remote_call.remote(method_name, *args, **kwargs) + if execute == 'first': - return getattr(workers[0], method_name).remote(*args, **kwargs) + return remote_func(workers[0], *args, **kwargs) elif dispatch == 'all': - return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers] + return [remote_func(worker, *args, **kwargs) for worker in workers] elif dispatch == 'slice': result = [] @@ -334,13 +344,7 @@ def 
dispatch_func(arg, n): sliced_args = tuple(arg[i] for arg in args) sliced_kwargs = {k: v[i] for k, v in kwargs.items()} if (sliced_args and sliced_args[0]) or (kwargs and list(kwargs.values())): - # skip empty input - if hasattr(workers[i], method_name): - remote_call = getattr(workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) - else: - remote_call = getattr(workers[i], 'call_inner_function') - result.append(remote_call.remote(method_name, *sliced_args, **sliced_kwargs)) + result.append(remote_func(workers[i], *sliced_args, **sliced_kwargs)) return result elif isinstance(dispatch, Callable): @@ -348,8 +352,7 @@ def dispatch_func(arg, n): result = [] for i in range(length): sliced_args, sliced_kwargs = dispatch(length, i, *args, **kwargs) - remote_call = getattr(workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + result.append(remote_func(workers[i], *sliced_args, **sliced_kwargs)) return result else: raise ValueError(f'Invalid dispatch method: {dispatch}') diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 50d50188a7..99afb0fd30 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -104,6 +104,8 @@ def _get_mean_metric(): trainer_parameters = inspect.signature(Trainer.__init__).parameters tokenizer_key = 'processing_class' if 'processing_class' in trainer_parameters else 'tokenizer' kwargs[tokenizer_key] = template.tokenizer + # if 'ref' in os.environ.get('RAY_SWIFT_GROUP', ''): + import ray; ray.util.pdb.set_trace() with self.hub.patch_hub(), RayHelper.patch_init(): super().__init__( model=model, @@ -187,6 +189,7 @@ def get_use_logits_to_keep(self, default_value: bool = True, model=None): use_logits_to_keep = self.args.use_logits_to_keep if model is None: model = self.model + model = unwrap_model(model) if use_logits_to_keep is None: base_model = self.template.get_base_model(model) use_logits_to_keep = (not model.model_meta.is_multimodal diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index cb0282da63..31ac1fe969 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -12,7 +12,7 @@ from trl import DPOTrainer as HFDPOTrainer from trl.trainer.dpo_config import DPOConfig from trl.trainer.utils import RunningMoments, pad_to_length - +from transformers.modeling_utils import unwrap_model from swift.llm import to_device from swift.ray import RayHelper from swift.utils import get_logger @@ -29,6 +29,17 @@ def new_gather_function(tensor): return torch.concat(to_device(tensor_list, tensor.device), dim=0) +def dispatch_ref_batch(n, i, *args, **kwargs): + mini_batch = {} + batch = args[0] + for key in batch: + tensor = batch[key] + batch_size = tensor.shape[0] + k, m = divmod(batch_size, n) + mini_batch[key] = tensor[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] + return (mini_batch, ), {} + + class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): def __init__(self, @@ -87,7 +98,7 @@ def concatenated_forward( use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1, model) if use_logits_to_keep: self.prepare_logits_to_keep(batch) - aux_loss_enabled = model.model_info.is_moe_model and self.args.router_aux_loss_coef > 0 + aux_loss_enabled = unwrap_model(model).model_info.is_moe_model and self.args.router_aux_loss_coef > 0 if aux_loss_enabled: batch['output_router_logits'] = True labels = batch.pop('labels', None) @@ -193,8 
+204,8 @@ def _compute_log_probs(self, batch): ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) return ref_model_output['chosen_logps'], ref_model_output['rejected_logps'] - @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') - def _compute_ref_log_probs_ref(self, batch): + @RayHelper.function(group='ref', execute='peer', dispatch=dispatch_ref_batch, collect='flatten') + def _compute_ref_log_probs(self, batch): compte_ref_context_manager = ( autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) with torch.no_grad(), compte_ref_context_manager: @@ -205,7 +216,7 @@ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: if self.args.ref_model is None: return self._compute_log_probs(batch) else: - return self._compute_ref_log_probs_ref(batch) + return self._compute_ref_log_probs(batch) def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: generate_context_manager = ( @@ -244,7 +255,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) return policy_output_decoded, ref_output_decoded - @RayHelper.function(group='ref', execute='peer', dispatch='slice', collect='flatten') + @RayHelper.function(group='ref', execute='peer', dispatch=dispatch_ref_batch, collect='flatten') def generate_from_ref(self, batch): return self.ref_model.generate( input_ids=batch['prompt_input_ids'], From 0b46447b44bf001b320bebf6514e5e5fb41cfb5d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 16:52:56 +0800 Subject: [PATCH 12/16] fix --- examples/train/multi-node/ray/sft.yaml | 2 +- examples/train/rlhf/dpo/dpo.yaml | 42 ++++++++++++++++++++++++++ swift/llm/train/rlhf.py | 2 +- swift/llm/train/sft.py | 2 +- swift/ray/base.py | 41 +++++++++++++++++-------- swift/trainers/mixin.py | 2 -- 6 files changed, 74 insertions(+), 17 deletions(-) create mode 100644 examples/train/rlhf/dpo/dpo.yaml diff --git a/examples/train/multi-node/ray/sft.yaml b/examples/train/multi-node/ray/sft.yaml index f338ff379c..8dd53c5190 100644 --- a/examples/train/multi-node/ray/sft.yaml +++ b/examples/train/multi-node/ray/sft.yaml @@ -33,4 +33,4 @@ device_groups: device: GPU ranks: list(range(0, 4)) workers: - - default + - sft:default diff --git a/examples/train/rlhf/dpo/dpo.yaml b/examples/train/rlhf/dpo/dpo.yaml new file mode 100644 index 0000000000..9d5ec6b178 --- /dev/null +++ b/examples/train/rlhf/dpo/dpo.yaml @@ -0,0 +1,42 @@ +rlhf_type: dpo +model: Qwen/Qwen2.5-VL-3B-Instruct +ref_model: Qwen/Qwen2.5-VL-3B-Instruct +train_type: full +dataset: swift/RLAIF-V-Dataset#1000 +load_from_cache_file: true +split_dataset_ratio: 0.01 +torch_dtype: bfloat16 +num_train_epochs: 1 +per_device_train_batch_size: 4 +per_device_eval_batch_size: 4 +learning_rate: 1e-4 +lora_rank: 8 +lora_alpha: 32 +target_modules: all-linear +gradient_accumulation_steps: 16 +# deepspeed: zero3 +eval_steps: 50 +save_steps: 50 +save_total_limit: 2 +logging_steps: 5 +max_length: 2048 +output_dir: output +warmup_ratio: 0.05 +dataloader_num_workers: 4 +rpo_alpha: 0.1 +dataset_num_proc: 4 + +use_ray: true + +device_groups: + nproc_per_node: 4 + sample_group: + device: GPU + ranks: list(range(0, 2)) + workers: + - rlhf:default + rm_group: + device: GPU + ranks: list(range(2, 4)) + workers: + - ref \ No newline at end of file diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 9afaeb3727..9ffb9d1051 100644 --- a/swift/llm/train/rlhf.py +++ 
b/swift/llm/train/rlhf.py @@ -18,7 +18,7 @@ logger = get_logger() -@RayHelper.worker(group=['default', 'ref', 'reward', 'value', 'teacher']) +@RayHelper.worker(group=['rlhf:default', 'ref', 'reward', 'value', 'teacher']) class SwiftRLHF(SwiftSft): args_class = RLHFArguments args: args_class diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 29baddb466..717f197af7 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -20,7 +20,7 @@ logger = get_logger() -@RayHelper.worker(group=['default']) +@RayHelper.worker(group=['sft:default']) class SwiftSft(SwiftPipeline, TunerMixin, RayMixin): args_class = TrainArguments args: args_class diff --git a/swift/ray/base.py b/swift/ray/base.py index fbca1f84ea..1682c8667b 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -3,6 +3,7 @@ import functools import inspect import os +from types import SimpleNamespace from contextlib import contextmanager from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union @@ -58,7 +59,14 @@ def __init__(self): self.workers = {} def register_workers(self, group: str, worker_handles: List): - self.workers[group] = worker_handles + if group == 'sft:default': + group = ['default', 'sft:default'] + elif group == 'rlhf:default': + group = ['default', 'rlhf:default'] + else: + group = [group] + for _group in group: + self.workers[_group] = worker_handles def get_workers(self, group: Optional[str] = None): if group: @@ -118,24 +126,25 @@ def teardown(): @staticmethod @contextmanager def patch_init(): - if not RayHelper.is_default(): + if RayHelper.ray_inited() and not RayHelper.is_default(): from transformers import Trainer init_method = Trainer.__init__ @functools.wraps(init_method) def new_init(self, *args, **kwargs): - from accelerate import Accelerator + from transformers import Trainer, enable_full_determinism, set_seed + self: Trainer self.processing_class = kwargs['processing_class'] - self.args = kwargs['args'] - self.accelerator = Accelerator(kwargs['args']) - self.is_deepspeed_enabled = getattr(self.accelerator.state, 'deepspeed_plugin', None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, 'fsdp_plugin', None) is not None - if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: - self.propagate_args_to_deepspeed() + args = kwargs['args'] + self.args = args + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = SimpleNamespace(tp_size=1) + self.create_accelerator_and_postprocess() + self.args._setup_devices Trainer.__init__ = new_init yield - if not RayHelper.is_default(): + if RayHelper.ray_inited() and not RayHelper.is_default(): Trainer.__init__ = init_method @staticmethod @@ -164,7 +173,9 @@ def ray_inited(): @staticmethod def is_default(): - return 'default' in os.environ.get('RAY_SWIFT_GROUP', '').split(',') + ray_groups = os.environ.get('RAY_SWIFT_GROUP', '').split(',') + default_names = ['default', 'sft:default', 'rlhf:default'] + return any(name in ray_groups for name in default_names) @staticmethod def is_worker(): @@ -203,6 +214,8 @@ def new_init(self, *args, **kwargs): @staticmethod def collect_func(method: Union[Literal['none', 'flatten'], Callable], result): + if not result: + return result if isinstance(result[0], tuple): output = [] for i in range(len(result[0])): @@ -254,7 +267,11 @@ def wrapper(self, *args, **kwargs) -> T: if RayHelper.is_worker(): if not hasattr(self, 'group'): # pass through env - self.group = 
os.environ['RAY_SWIFT_GROUP'].split(',') + default_names = ['default', 'sft:default', 'rlhf:default'] + groups = os.environ['RAY_SWIFT_GROUP'].split(',') + if 'sft:default' in groups or 'rlhf:default' in groups: + groups.append('default') + self.group = groups if group not in self.group: if RayHelper.is_called_from_init(): # Functions in init of different group, do nothing diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 99afb0fd30..e7926d7e15 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -104,8 +104,6 @@ def _get_mean_metric(): trainer_parameters = inspect.signature(Trainer.__init__).parameters tokenizer_key = 'processing_class' if 'processing_class' in trainer_parameters else 'tokenizer' kwargs[tokenizer_key] = template.tokenizer - # if 'ref' in os.environ.get('RAY_SWIFT_GROUP', ''): - import ray; ray.util.pdb.set_trace() with self.hub.patch_hub(), RayHelper.patch_init(): super().__init__( model=model, From 1e98804f8a4cf2c042b84e534aca59dc56e68ee3 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 17:07:25 +0800 Subject: [PATCH 13/16] wip --- ...ay\347\232\204\346\224\257\346\214\201.md" | 19 ++++++++++--------- docs/source_en/Instruction/Ray.md | 19 ++++++++++--------- examples/train/multi-node/ray/dpo.sh | 1 + .../{rlhf/dpo => multi-node/ray}/dpo.yaml | 4 ++-- swift/llm/train/pt.py | 2 ++ swift/ray/base.py | 9 +++++---- swift/trainers/rlhf_trainer/dpo_trainer.py | 3 ++- 7 files changed, 32 insertions(+), 25 deletions(-) create mode 100644 examples/train/multi-node/ray/dpo.sh rename examples/train/{rlhf/dpo => multi-node/ray}/dpo.yaml (93%) diff --git "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" index 9e825528be..a2f6fa17dd 100644 --- "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" +++ "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" @@ -2,15 +2,16 @@ SWIFT已经支持使用ray来进行多卡或多节点训练。已有功能中对ray的支持情况如下: -| 功能 | 支持ray | 例子 | 可分配角色 | -|----------|-------|--------------------------------------------------------------------------------|-----------------| -| pt/sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default | -| dpo | ❎ | | | -| grpo | ❎ | | | -| ppo | ❎ | | | -| megatron | ❎ | | | -| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | -| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | +| 功能 | 支持ray | 例子 | 可分配角色 | +|----------|-------|---------------------------------------------------------------------------------------|------------------| +| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | pt:default | +| sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default | +| dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref | +| grpo | ❎ | | | +| ppo | ❎ | | | +| megatron | ❎ | | | +| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | +| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | ## 技术细节 diff --git a/docs/source_en/Instruction/Ray.md b/docs/source_en/Instruction/Ray.md index cdda7957c3..41a8123b30 100644 --- a/docs/source_en/Instruction/Ray.md +++ 
b/docs/source_en/Instruction/Ray.md
@@ -2,15 +2,16 @@
 
 SWIFT already supports using Ray for multi-GPU or multi-node training. The support status for Ray in existing features is as follows:
 
-| Feature  | Ray Support | Example                                                                         | Assignable Roles |
-|----------|-------------|--------------------------------------------------------------------------------|------------------|
-| pt/sft   | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default          |
-| dpo      | ❎           |                                                                                 |                  |
-| grpo     | ❎           |                                                                                 |                  |
-| ppo      | ❎           |                                                                                 |                  |
-| megatron | ❎           |                                                                                 |                  |
-| sampling | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill      | sampler/prm/orm  |
-| distill  | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample       | sampler/prm/orm  |
+| Feature  | Ray Support | Example                                                                                 | Assignable Roles |
+|----------|-------------|---------------------------------------------------------------------------------------|------------------|
+| pt       | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh  | pt:default       |
+| sft      | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh  | sft:default      |
+| dpo      | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh  | rlhf:default/ref |
+| grpo     | ❎           |                                                                                         |                  |
+| ppo      | ❎           |                                                                                         |                  |
+| megatron | ❎           |                                                                                         |                  |
+| sampling | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill               | sampler/prm/orm  |
+| distill  | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample                | sampler/prm/orm  |
 
 ## Technical Details
 
diff --git a/examples/train/multi-node/ray/dpo.sh b/examples/train/multi-node/ray/dpo.sh
new file mode 100644
index 0000000000..fbabae97aa
--- /dev/null
+++ b/examples/train/multi-node/ray/dpo.sh
@@ -0,0 +1 @@
+swift rlhf --config dpo.yaml
diff --git a/examples/train/rlhf/dpo/dpo.yaml b/examples/train/multi-node/ray/dpo.yaml
similarity index 93%
rename from examples/train/rlhf/dpo/dpo.yaml
rename to examples/train/multi-node/ray/dpo.yaml
index 9d5ec6b178..3d70db4faa 100644
--- a/examples/train/rlhf/dpo/dpo.yaml
+++ b/examples/train/multi-node/ray/dpo.yaml
@@ -14,7 +14,6 @@ lora_rank: 8
 lora_alpha: 32
 target_modules: all-linear
 gradient_accumulation_steps: 16
-# deepspeed: zero3
 eval_steps: 50
 save_steps: 50
 save_total_limit: 2
@@ -28,6 +27,7 @@ dataset_num_proc: 4
 
 use_ray: true
 
+# The ranks of rlhf:default and ref must be equal
 device_groups:
   nproc_per_node: 4
   sample_group:
@@ -39,4 +39,4 @@ device_groups:
     device: GPU
     ranks: list(range(2, 4))
     workers:
-      - ref
\ No newline at end of file
+      - ref
diff --git a/swift/llm/train/pt.py b/swift/llm/train/pt.py
index c7b1858756..af8ee3b62c 100644
--- a/swift/llm/train/pt.py
+++ b/swift/llm/train/pt.py
@@ -4,10 +4,12 @@
 
 from swift.utils import get_logger
 from ..argument import TrainArguments
 from .sft import SwiftSft
+from swift.ray import RayHelper
 
 logger = get_logger()
 
 
+@RayHelper.worker(group=['pt:default'])
 class SwiftPt(SwiftSft):
     args_class = TrainArguments
     args: args_class
diff --git a/swift/ray/base.py b/swift/ray/base.py
index 1682c8667b..8b87a12242 100644
--- a/swift/ray/base.py
+++ b/swift/ray/base.py
@@ -3,8 +3,8 @@
 import functools
 import inspect
 import os
-from types import SimpleNamespace
 from contextlib import contextmanager
+from types import SimpleNamespace
 from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union
 
 import json
@@ -61,6 +61,8 @@ def __init__(self):
     def register_workers(self, group: str, worker_handles: List):
         if group == 'sft:default':
             group = ['default', 'sft:default']
+        elif group == 'pt:default':
+            group = ['default', 'pt:default']
         elif group == 'rlhf:default':
             group = ['default', 'rlhf:default']
         else:
@@ -174,7 +176,7 @@ def ray_inited():
     @staticmethod
     def is_default():
         ray_groups = os.environ.get('RAY_SWIFT_GROUP', '').split(',')
-        default_names = ['default', 'sft:default', 'rlhf:default']
+        default_names = ['default', 'sft:default', 'rlhf:default', 'pt:default']
         return any(name in ray_groups for name in default_names)
 
     @staticmethod
@@ -267,9 +269,8 @@ def wrapper(self, *args, **kwargs) -> T:
             if RayHelper.is_worker():
                 if not hasattr(self, 'group'):
                     # pass through env
-                    default_names = ['default', 'sft:default', 'rlhf:default']
                     groups = os.environ['RAY_SWIFT_GROUP'].split(',')
-                    if 'sft:default' in groups or 'rlhf:default' in groups:
+                    if 'sft:default' in groups or 'rlhf:default' in groups or 'pt:default' in groups:
                         groups.append('default')
                     self.group = groups
             if group not in self.group:
diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py
index 31ac1fe969..c178d6cb2e 100644
--- a/swift/trainers/rlhf_trainer/dpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -8,11 +8,12 @@
 import torch.nn as nn
 from accelerate.utils import gather_object
 from torch import autocast
 from transformers import PreTrainedModel
+from transformers.modeling_utils import unwrap_model
 from transformers.utils.versions import require_version
 from trl import DPOTrainer as HFDPOTrainer
 from trl.trainer.dpo_config import DPOConfig
 from trl.trainer.utils import RunningMoments, pad_to_length
-from transformers.modeling_utils import unwrap_model
+
 from swift.llm import to_device
 from swift.ray import RayHelper
 from swift.utils import get_logger

From 746e5338edd29e3b56b1cf892ffd96d5550b6253 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sat, 1 Nov 2025 17:43:14 +0800
Subject: [PATCH 14/16] fix

---
 examples/train/multi-node/ray/dpo.yaml |  9 +++----
 examples/train/multi-node/ray/pt.sh    |  1 +
 examples/train/multi-node/ray/pt.yaml  | 33 ++++++++++++++++++++++++++
 swift/cli/pt.py                        |  4 +++-
 4 files changed, 40 insertions(+), 7 deletions(-)
 create mode 100644 examples/train/multi-node/ray/pt.sh
 create mode 100644 examples/train/multi-node/ray/pt.yaml

diff --git a/examples/train/multi-node/ray/dpo.yaml b/examples/train/multi-node/ray/dpo.yaml
index 3d70db4faa..d2baed47b2 100644
--- a/examples/train/multi-node/ray/dpo.yaml
+++ b/examples/train/multi-node/ray/dpo.yaml
@@ -10,12 +10,9 @@ num_train_epochs: 1
 per_device_train_batch_size: 4
 per_device_eval_batch_size: 4
 learning_rate: 1e-4
-lora_rank: 8
-lora_alpha: 32
-target_modules: all-linear
-gradient_accumulation_steps: 16
-eval_steps: 50
-save_steps: 50
+gradient_accumulation_steps: 2
+eval_steps: 1
+save_steps: 1
 save_total_limit: 2
 logging_steps: 5
 max_length: 2048
diff --git a/examples/train/multi-node/ray/pt.sh b/examples/train/multi-node/ray/pt.sh
new file mode 100644
index 0000000000..9aefab2447
--- /dev/null
+++ b/examples/train/multi-node/ray/pt.sh
@@ -0,0 +1 @@
+swift pt --config pt.yaml
diff --git a/examples/train/multi-node/ray/pt.yaml b/examples/train/multi-node/ray/pt.yaml
new file mode 100644
index 0000000000..0179118b9a
--- /dev/null
+++ b/examples/train/multi-node/ray/pt.yaml
@@ -0,0 +1,33 @@
+model: Qwen/Qwen2.5-7B
+train_type: full
+dataset: swift/chinese-c4#2000
+torch_dtype: bfloat16
+streaming: true
+per_device_train_batch_size: 1
+per_device_eval_batch_size: 1
+learning_rate: 1e-5
+gradient_accumulation_steps: 2
+packing: true
+eval_steps: 500 +save_steps: 500 +save_total_limit: 2 +logging_steps: 5 +deepspeed: zero3 +max_length: 8192 +max_steps: 10000 +warmup_ratio: 0.05 +dataloader_num_workers: 4 +dataset_num_proc: 8 +save_only_model: true +output_dir: output/Qwen2.5-7B +attn_impl: flash_attn + +use_ray: true + +device_groups: + nproc_per_node: 4 + default: + device: GPU + ranks: list(range(0, 4)) + workers: + - pt:default \ No newline at end of file diff --git a/swift/cli/pt.py b/swift/cli/pt.py index 1ca2aabd8a..bec06023c7 100644 --- a/swift/cli/pt.py +++ b/swift/cli/pt.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from swift.llm import pt_main if __name__ == '__main__': + from swift.ray import try_init_ray + try_init_ray() + from swift.llm import pt_main pt_main() From afd608b06e7eb1d9609ff49c84c3cfce73ef1fd6 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 17:45:52 +0800 Subject: [PATCH 15/16] fix --- examples/train/multi-node/ray/pt.yaml | 2 +- swift/llm/train/pt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/train/multi-node/ray/pt.yaml b/examples/train/multi-node/ray/pt.yaml index 0179118b9a..79bcf0d7ed 100644 --- a/examples/train/multi-node/ray/pt.yaml +++ b/examples/train/multi-node/ray/pt.yaml @@ -30,4 +30,4 @@ device_groups: device: GPU ranks: list(range(0, 4)) workers: - - pt:default \ No newline at end of file + - pt:default diff --git a/swift/llm/train/pt.py b/swift/llm/train/pt.py index af8ee3b62c..83984c08be 100644 --- a/swift/llm/train/pt.py +++ b/swift/llm/train/pt.py @@ -1,10 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List, Optional, Union +from swift.ray import RayHelper from swift.utils import get_logger from ..argument import TrainArguments from .sft import SwiftSft -from swift.ray import RayHelper logger = get_logger() From 081fa47ee436fc97b44b0a2e0cb7853b96f8570d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 1 Nov 2025 17:49:39 +0800 Subject: [PATCH 16/16] fix doc --- .../Instruction/ray\347\232\204\346\224\257\346\214\201.md" | 2 +- docs/source_en/Instruction/Ray.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" index a2f6fa17dd..ec69f92de1 100644 --- "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" +++ "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" @@ -4,7 +4,7 @@ SWIFT已经支持使用ray来进行多卡或多节点训练。已有功能中对 | 功能 | 支持ray | 例子 | 可分配角色 | |----------|-------|---------------------------------------------------------------------------------------|------------------| -| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | pt:default | +| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh | pt:default | | sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default | | dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref | | grpo | ❎ | | | diff --git a/docs/source_en/Instruction/Ray.md b/docs/source_en/Instruction/Ray.md index 41a8123b30..755dd4d33d 100644 --- a/docs/source_en/Instruction/Ray.md +++ b/docs/source_en/Instruction/Ray.md @@ -4,7 +4,7 @@ SWIFT already supports using Ray for multi-GPU or multi-node training. 
The suppo
 
 | Feature  | Ray Support | Example                                                                                 | Assignable Roles |
 |----------|-------------|---------------------------------------------------------------------------------------|------------------|
-| pt       | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh  | pt:default       |
+| pt       | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh   | pt:default       |
 | sft      | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh  | sft:default      |
 | dpo      | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh  | rlhf:default/ref |
 | grpo     | ❎           |                                                                                         |                  |
@@ -13,7 +13,7 @@ SWIFT already supports using Ray for multi-GPU or multi-node training. The suppo
 | sampling | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill               | sampler/prm/orm  |
 | distill  | ✅           | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample                | sampler/prm/orm  |
 
-## Technical Details
+## Technical Details
 
 Before describing parameter settings, it's necessary to first explain the technical details. Since SWIFT currently uses many existing implementations from transformers and trl internally, decomposing into different Ray roles like veRL or ROLL is impractical, and decomposition would center around Ray, resulting in poor support for non-Ray scenarios.
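
A note on how the group names above fit together: `@RayHelper.worker(group=[...])` registers a class's instances under those group names (with `sft:default`/`rlhf:default`/`pt:default` also serving the plain `default` group, per PATCHes 12/13), and `@RayHelper.function(group=...)` routes each decorated call either to the local process or to the group's registered remote workers, falling back to `call_inner_function` for methods that live on the wrapped trainer. Below is a deliberately simplified, Ray-free sketch of that routing pattern; `REGISTRY`, `register_workers`, and `function` are illustrative stand-ins, not the actual `RayHelper` API.

```python
import functools
from typing import Callable, Dict, List

REGISTRY: Dict[str, List[object]] = {}  # group name -> worker handles

def register_workers(group: str, workers: List[object]) -> None:
    # '<cmd>:default' groups also serve the plain 'default' group, mirroring
    # the aliasing the series introduces for sft/rlhf/pt.
    names = [group] + (['default'] if group.endswith(':default') else [])
    for name in names:
        REGISTRY[name] = workers

def function(group: str):
    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            workers = REGISTRY.get(group, [])
            if not workers:  # no remote group registered -> run locally
                return func(self, *args, **kwargs)
            # fan the call out to every worker in the group, collect results
            return [getattr(w, func.__name__)(*args, **kwargs) for w in workers]
        return wrapper
    return decorator

class RefWorker:
    def compute_ref_log_probs(self, batch):  # stands in for a remote actor method
        return len(batch)

class Trainer:
    @function(group='ref')
    def compute_ref_log_probs(self, batch):
        return -1  # local fallback when no 'ref' workers exist

register_workers('ref', [RefWorker(), RefWorker()])
print(Trainer().compute_ref_log_probs([1, 2, 3]))  # -> [3, 3], one per worker
```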
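The `dispatch_ref_batch` helper introduced in PATCH 11 is what carries DPO batches to the `ref` group: it cuts every tensor in the batch into `n` near-equal contiguous row slices, one per worker, and `collect='flatten'` stitches the per-worker results back together. The function body below is taken from the patch (lightly commented here); the toy batch and the printed sizes are only illustrative.

```python
import torch

def dispatch_ref_batch(n, i, *args, **kwargs):
    """Return the i-th of n near-equal row slices of every tensor in the batch."""
    mini_batch = {}
    batch = args[0]
    for key in batch:
        tensor = batch[key]
        batch_size = tensor.shape[0]
        # k rows per worker; the first m workers receive one extra row
        k, m = divmod(batch_size, n)
        mini_batch[key] = tensor[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
    return (mini_batch, ), {}

# Toy check (not from the patch): 10 rows over 4 ref workers -> sizes 3, 3, 2, 2.
batch = {'prompt_input_ids': torch.arange(10).unsqueeze(1)}
for i in range(4):
    (mini, ), _ = dispatch_ref_batch(4, i, batch)
    print(i, mini['prompt_input_ids'].shape[0])
```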
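Finally, `RayHelper.patch_init` (reworked across PATCHes 11 and 12) temporarily replaces `transformers.Trainer.__init__` on non-default workers, so `ref`/`reward` workers skip the heavy trainer construction and keep only the pieces remote calls need. A condensed sketch of the idea, assuming a dummy `Trainer` rather than the real one; unlike the patch, it restores the original `__init__` in a `try/finally` for safety.

```python
import functools
from contextlib import contextmanager

class Trainer:
    """Dummy stand-in for transformers.Trainer."""

    def __init__(self, model=None, args=None):
        self.model = model
        self.args = args
        self.fully_initialized = True  # represents the heavy setup work

@contextmanager
def patch_init(is_default_worker: bool):
    if is_default_worker:
        yield  # the default group keeps the real __init__
        return
    original = Trainer.__init__

    @functools.wraps(original)
    def stub(self, model=None, args=None):
        self.args = args  # keep only what remote method calls need
        self.fully_initialized = False

    Trainer.__init__ = stub
    try:
        yield
    finally:
        Trainer.__init__ = original  # always restore the real __init__

with patch_init(is_default_worker=False):
    t = Trainer(args={'seed': 42})
print(t.fully_initialized)  # False: the heavy initialization was skipped
```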