diff --git "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" index 9e825528be..ec69f92de1 100644 --- "a/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" +++ "b/docs/source/Instruction/ray\347\232\204\346\224\257\346\214\201.md" @@ -2,15 +2,16 @@ SWIFT已经支持使用ray来进行多卡或多节点训练。已有功能中对ray的支持情况如下: -| 功能 | 支持ray | 例子 | 可分配角色 | -|----------|-------|--------------------------------------------------------------------------------|-----------------| -| pt/sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default | -| dpo | ❎ | | | -| grpo | ❎ | | | -| ppo | ❎ | | | -| megatron | ❎ | | | -| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | -| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | +| 功能 | 支持ray | 例子 | 可分配角色 | +|----------|-------|---------------------------------------------------------------------------------------|------------------| +| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh | pt:default | +| sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default | +| dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref | +| grpo | ❎ | | | +| ppo | ❎ | | | +| megatron | ❎ | | | +| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | +| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | ## 技术细节 diff --git a/docs/source_en/Instruction/Ray.md b/docs/source_en/Instruction/Ray.md index cdda7957c3..755dd4d33d 100644 --- a/docs/source_en/Instruction/Ray.md +++ b/docs/source_en/Instruction/Ray.md @@ -2,17 +2,18 @@ SWIFT already supports using Ray for multi-GPU or multi-node training. 
The support status for Ray in existing features is as follows: -| Feature | Ray Support | Example | Assignable Roles | |----------|-------------|--------------------------------------------------------------------------------|------------------| -| pt/sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default | -| dpo | ❎ | | | -| grpo | ❎ | | | -| ppo | ❎ | | | -| megatron | ❎ | | | -| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | -| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | - -## Technical Details +| Feature | Ray Support | Example | Assignable Roles | |----------|-------------|---------------------------------------------------------------------------------------|------------------| +| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh | pt:default | +| sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default | +| dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref | +| grpo | ❎ | | | +| ppo | ❎ | | | +| megatron | ❎ | | | +| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm | +| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm | + +## Technical Details Before describing parameter settings, it's necessary to first explain the technical details. Since SWIFT currently uses many existing implementations from transformers and trl internally, decomposing into different Ray roles like veRL or ROLL is impractical, and decomposition would center around Ray, resulting in poor support for non-Ray scenarios.
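The "Assignable Roles" column above corresponds to the Ray worker groups declared in the code changes below: a pipeline class registers itself for one or more groups with `RayHelper.worker`, and individual methods are pinned via `RayHelper.function` to the group that owns the model they touch. A minimal sketch of that pattern (illustrative only; the real classes are `SwiftSft`, `SwiftPt` and `SwiftRLHF` further down in this diff, and the class name and stub method bodies here are made up for the example):

```python
# Illustrative sketch of the role-group pattern used in this PR; not SWIFT source.
from swift.ray import RayHelper


@RayHelper.worker(group=['rlhf:default', 'ref'])  # this class may serve either role
class ExamplePipeline:

    @RayHelper.function(group='default')
    def prepare_policy(self):
        # runs only on the ranks that device_groups assigns to rlhf:default
        ...

    @RayHelper.function(group='ref')
    def prepare_ref(self):
        # runs only on the ranks that device_groups assigns to ref
        ...
```

The `device_groups` section of the `dpo.yaml` example added below then binds `rlhf:default` and `ref` to disjoint GPU ranks.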
diff --git a/examples/train/multi-node/ray/dpo.sh b/examples/train/multi-node/ray/dpo.sh new file mode 100644 index 0000000000..fbabae97aa --- /dev/null +++ b/examples/train/multi-node/ray/dpo.sh @@ -0,0 +1 @@ +swift rlhf --config dpo.yaml diff --git a/examples/train/multi-node/ray/dpo.yaml b/examples/train/multi-node/ray/dpo.yaml new file mode 100644 index 0000000000..d2baed47b2 --- /dev/null +++ b/examples/train/multi-node/ray/dpo.yaml @@ -0,0 +1,39 @@ +rlhf_type: dpo +model: Qwen/Qwen2.5-VL-3B-Instruct +ref_model: Qwen/Qwen2.5-VL-3B-Instruct +train_type: full +dataset: swift/RLAIF-V-Dataset#1000 +load_from_cache_file: true +split_dataset_ratio: 0.01 +torch_dtype: bfloat16 +num_train_epochs: 1 +per_device_train_batch_size: 4 +per_device_eval_batch_size: 4 +learning_rate: 1e-4 +gradient_accumulation_steps: 2 +eval_steps: 1 +save_steps: 1 +save_total_limit: 2 +logging_steps: 5 +max_length: 2048 +output_dir: output +warmup_ratio: 0.05 +dataloader_num_workers: 4 +rpo_alpha: 0.1 +dataset_num_proc: 4 + +use_ray: true + +# The number of ranks assigned to rlhf:default and ref must be equal +device_groups: + nproc_per_node: 4 + sample_group: + device: GPU + ranks: list(range(0, 2)) + workers: + - rlhf:default + rm_group: + device: GPU + ranks: list(range(2, 4)) + workers: + - ref diff --git a/examples/train/multi-node/ray/pt.sh b/examples/train/multi-node/ray/pt.sh new file mode 100644 index 0000000000..9aefab2447 --- /dev/null +++ b/examples/train/multi-node/ray/pt.sh @@ -0,0 +1 @@ +swift pt --config pt.yaml diff --git a/examples/train/multi-node/ray/pt.yaml b/examples/train/multi-node/ray/pt.yaml new file mode 100644 index 0000000000..79bcf0d7ed --- /dev/null +++ b/examples/train/multi-node/ray/pt.yaml @@ -0,0 +1,33 @@ +model: Qwen/Qwen2.5-7B +train_type: full +dataset: swift/chinese-c4#2000 +torch_dtype: bfloat16 +streaming: true +per_device_train_batch_size: 1 +per_device_eval_batch_size: 1 +learning_rate: 1e-5 +gradient_accumulation_steps: 2 +packing: true +eval_steps: 500 +save_steps: 500 +save_total_limit: 2 +logging_steps: 5 +deepspeed: zero3 +max_length: 8192 +max_steps: 10000 +warmup_ratio: 0.05 +dataloader_num_workers: 4 +dataset_num_proc: 8 +save_only_model: true +output_dir: output/Qwen2.5-7B +attn_impl: flash_attn + +use_ray: true + +device_groups: + nproc_per_node: 4 + default: + device: GPU + ranks: list(range(0, 4)) + workers: + - pt:default diff --git a/examples/train/multi-node/ray/sft.yaml b/examples/train/multi-node/ray/sft.yaml index f338ff379c..8dd53c5190 100644 --- a/examples/train/multi-node/ray/sft.yaml +++ b/examples/train/multi-node/ray/sft.yaml @@ -33,4 +33,4 @@ device_groups: device: GPU ranks: list(range(0, 4)) workers: - - default + - sft:default diff --git a/swift/cli/pt.py b/swift/cli/pt.py index 1ca2aabd8a..bec06023c7 100644 --- a/swift/cli/pt.py +++ b/swift/cli/pt.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from swift.llm import pt_main if __name__ == '__main__': + from swift.ray import try_init_ray + try_init_ray() + from swift.llm import pt_main pt_main() diff --git a/swift/cli/rlhf.py b/swift/cli/rlhf.py index 4f0fd6a0ab..0419df5c86 100644 --- a/swift/cli/rlhf.py +++ b/swift/cli/rlhf.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import rlhf_main if __name__ == '__main__': + from swift.ray import try_init_ray + try_init_ray() + from swift.llm import rlhf_main rlhf_main() diff --git a/swift/llm/train/pt.py b/swift/llm/train/pt.py index c7b1858756..83984c08be 100644 --- a/swift/llm/train/pt.py +++ b/swift/llm/train/pt.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List, Optional, Union +from swift.ray import RayHelper from swift.utils import get_logger from ..argument import TrainArguments from .sft import SwiftSft @@ -8,6 +9,7 @@ logger = get_logger() +@RayHelper.worker(group=['pt:default']) class SwiftPt(SwiftSft): args_class = TrainArguments args: args_class diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 832b458ee5..9ffb9d1051 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -5,6 +5,8 @@ from swift.llm import safe_snapshot_download from swift.plugin import Tuner, extra_tuners +from swift.ray import RayHelper +from swift.trainers import TrainerFactory from swift.tuners import Swift from swift.utils import get_logger, get_model_parameter_info from swift.utils.utils import disable_deepspeed_zero3 @@ -16,10 +18,20 @@ logger = get_logger() +@RayHelper.worker(group=['rlhf:default', 'ref', 'reward', 'value', 'teacher']) class SwiftRLHF(SwiftSft): args_class = RLHFArguments args: args_class + def __init__(self, args: RLHFArguments): + self.model = None + self.callbacks = [] + super().__init__(args) + self.reward_model = [] + if self.args.rlhf_type == 'grpo': + self.reward_template = [] + self._prepare_trainer() + @staticmethod def _get_model_task_type(model_dir): task_type = None @@ -48,6 +60,47 @@ def _get_model_task_type(model_dir): task_type = 'seq_cls' return task_type, num_labels + @RayHelper.function(group='ref') + def _prepare_ref_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.ref_model = result[0] + + @RayHelper.function(group='value') + def _prepare_value_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.value_model = result[0] + + @RayHelper.function(group='teacher') + def _prepare_teacher_model(self, key, origin_key, model_type, model_revision): + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + self.teacher_model = result[0] + + def _prepare_reward_model(self, reward_model_path, key, origin_key, model_type, model_revision): + rms = self.args.reward_model if isinstance(self.args.reward_model, list) else [self.args.reward_model] + self.args.reward_model = reward_model_path # Temporarily set for prepare_single_model + result = self._prepare_single_model(key, origin_key, model_type, model_revision) + + if result is not None: + model, processor = result + self.reward_model.append(model) + + if self.args.rlhf_type == 'grpo': + reward_template = self.args.get_template(processor, processor.model_meta.template) + if reward_template.use_model: + reward_template.model = model + self.reward_template.append(reward_template) + self.args.reward_model = rms # Restore original value + + if self.args.rlhf_type != 'grpo' and self.reward_model: + assert len(self.reward_model) <= 1 + self.reward_model = self.reward_model[0] + def _prepare_single_model(self, key, origin_key, model_type, model_revision): from 
swift.llm.infer.utils import prepare_adapter args = self.args @@ -116,10 +169,7 @@ def _prepare_model_tokenizer(self): model_type = model_type[0] if model_type else None model_revision = model_revision[0] if model_revision else None - result = self._prepare_single_model(model_key, key, model_type, model_revision) - if result is not None: - model, _ = result - setattr(self, f'{key}_model', model) + getattr(self, f'_prepare_{key}_model')(model_key, key, model_type, model_revision) # Handle reward model(s) self.reward_model = None @@ -130,26 +180,9 @@ def _prepare_model_tokenizer(self): rm_revisions = args.reward_model_revision if args.reward_model_revision else [None] * num_rms assert len(rms) == len(rm_types) == len(rm_revisions) - self.reward_model = [] - if args.rlhf_type == 'grpo': - self.reward_template = [] - - for reward_model_path, rm_type, rm_revision in zip(rms, rm_types, rm_revisions): - args.reward_model = reward_model_path # Temporarily set for prepare_single_model - result = self._prepare_single_model('reward', None, rm_type, rm_revision) - if result is not None: - model, processor = result - self.reward_model.append(model) - - if args.rlhf_type == 'grpo': - reward_template = self.args.get_template(processor, processor.model_meta.template) - if reward_template.use_model: - reward_template.model = model - self.reward_template.append(reward_template) - args.reward_model = rms # Restore original value - if args.rlhf_type != 'grpo' and self.reward_model: - assert len(self.reward_model) <= 1 - self.reward_model = self.reward_model[0] + for rm_idx, (reward_model_path, rm_type, rm_revision) in enumerate(zip(rms, rm_types, rm_revisions)): + _prepare_reward_model = RayHelper.function(group=f'reward_{rm_idx}')(self._prepare_reward_model) + _prepare_reward_model(reward_model_path, 'reward', None, rm_type, rm_revision) super()._prepare_model_tokenizer() @@ -222,6 +255,40 @@ def _get_trainer_kwargs(self): trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed return trainer_kwargs + @RayHelper.function(group='default') + def _add_adapter_to_model(self, train_dataset): + # Some tuners require train_dataset and data_collator for preparation: LoRA-GA + self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) + logger.info(f'model: {self.model}') + model_parameter_info = get_model_parameter_info(self.model) + self.train_msg['model_parameter_info'] = model_parameter_info + logger.info(f'model_parameter_info: {model_parameter_info}') + + def _prepare_trainer(self): + args = self.args + train_dataset, val_dataset = self._prepare_dataset() + args.save_args() + + data_collator = self._get_data_collator() + self._add_adapter_to_model(train_dataset) + + trainer_cls = TrainerFactory.get_trainer_cls(args) + self.args.training_args.ref_model = self.args.ref_model + self.trainer = trainer_cls( + model=self.model, + args=self.args.training_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=val_dataset, + callbacks=self.callbacks, + template=self.template, + **self._get_trainer_kwargs(), + ) + + @RayHelper.function(group='default') + def run(self): + return self.train(self.trainer) + def rlhf_main(args: Optional[Union[List[str], RLHFArguments]] = None): return SwiftRLHF(args).main() diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 502b895646..717f197af7 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -8,7 +8,7 @@ from swift.llm.dataset.loader import DatasetLoader from 
swift.plugin import extra_callbacks -from swift.ray import RayHelper +from swift.ray import RayHelper, RayMixin from swift.trainers import TrainerFactory from swift.utils import append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array from ..argument import TrainArguments @@ -20,20 +20,21 @@ logger = get_logger() -@RayHelper.worker(group=['default']) -class SwiftSft(SwiftPipeline, TunerMixin): +@RayHelper.worker(group=['sft:default']) +class SwiftSft(SwiftPipeline, TunerMixin, RayMixin): args_class = TrainArguments args: args_class def __init__(self, args: Optional[Union[List[str], TrainArguments]] = None) -> None: super().__init__(args) self.train_msg = {} + self._prepare_processor() self._prepare_model_tokenizer() self._prepare_template() + self._patch_model_to_template() self._prepare_callbacks() self._prepare_flash_ckpt() - @RayHelper.function(group='default') def _prepare_flash_ckpt(self): if self.args.use_flash_ckpt: try: @@ -48,10 +49,15 @@ def _prepare_generation_config(self): args.get_request_config(), self.tokenizer) logger.info(f'model.generation_config: {self.model.generation_config}') + def _prepare_processor(self, **kwargs): + """Prepare the processor only.""" + _, self.processor = self.args.get_model_processor(load_model=False, **kwargs) + @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): + """Prepare the model and tokenizer.""" args = self.args - self.model, self.processor = args.get_model_processor(**kwargs) + self.model, _ = args.get_model_processor(**kwargs) if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel sequence_parallel.prepare( @@ -65,17 +71,20 @@ def _prepare_model_tokenizer(self, **kwargs): self._prepare_generation_config() - @RayHelper.function(group='default') def _prepare_template(self) -> None: args = self.args template = args.get_template(self.processor) template.set_mode('train') - if template.use_model: - template.model = self.model if args.model_meta.is_multimodal and (args.padding_free or args.packing) and not template.support_padding_free: raise ValueError(f'Template `{args.template}` does not support padding free or packing.') self.template = template + @RayHelper.function(group='default') + def _patch_model_to_template(self): + """Some templates need the model for preprocessing, especially multimodal (mllm) models.""" + if self.template.use_model: + self.template.model = self.model + def _get_dataset(self): # The random shuffling of the training set occurs in the dataloader of the trainer. args = self.args @@ -119,7 +128,6 @@ def _get_cached_dataset(self): val_datasets.append(load_from_disk(val_path)) return train_datasets, val_datasets - @RayHelper.function(group='default') def _prepare_dataset(self): args = self.args # Defer encoding to the training phase diff --git a/swift/ray/__init__.py b/swift/ray/__init__.py index 61aad5097f..e978f3e3e2 100644 --- a/swift/ray/__init__.py +++ b/swift/ray/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
-from .base import RayHelper +from .base import RayHelper, RayMixin def try_init_ray(): diff --git a/swift/ray/base.py b/swift/ray/base.py index 44bfad191a..8b87a12242 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -3,6 +3,8 @@ import functools import inspect import os +from contextlib import contextmanager +from types import SimpleNamespace from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union import json @@ -22,6 +24,12 @@ def get_args(): return json.dumps(unknown) +class RayMixin: + + def call_inner_function(self, func, *args, **kwargs): + return getattr(self.trainer, func)(*args, **kwargs) + + class RayHelper: resource_manager: Optional[ResourceManager] = None @@ -35,6 +43,53 @@ class RayHelper: device_groups: Dict[str, Any] = None + _registry = None + + @staticmethod + def init_registry(): + if RayHelper._registry is not None: + return + + import ray + + @ray.remote + class WorkerRegistry: + + def __init__(self): + self.workers = {} + + def register_workers(self, group: str, worker_handles: List): + if group == 'sft:default': + group = ['default', 'sft:default'] + elif group == 'pt:default': + group = ['default', 'pt:default'] + elif group == 'rlhf:default': + group = ['default', 'rlhf:default'] + else: + group = [group] + for _group in group: + self.workers[_group] = worker_handles + + def get_workers(self, group: Optional[str] = None): + if group: + return self.workers.get(group, []) + return self.workers + + def clear(self): + self.workers.clear() + + try: + RayHelper._registry = ray.get_actor('swift_worker_registry') + except ValueError: + try: + RayHelper._registry = WorkerRegistry.options( + name='swift_worker_registry', + lifetime='detached', + ).remote() + except ValueError: + RayHelper._registry = ray.get_actor('swift_worker_registry') + assert RayHelper._registry is not None + @staticmethod def initialize(device_groups: Dict[str, Any]): """Initialize RayHelper. @@ -53,6 +108,7 @@ def initialize(device_groups: Dict[str, Any]): if RayHelper.resource_manager is None: # Resource manager initialize only once in the pipeline process. RayHelper.resource_manager = ResourceManager(device_groups) + RayHelper.init_registry() @staticmethod def teardown(): @@ -60,6 +116,39 @@ def teardown(): RayHelper.resource_manager.destroy_placement_group() RayHelper.resource_manager = None + if RayHelper._registry is not None: + import ray + try: + ray.get(RayHelper._registry.clear.remote()) + ray.kill(RayHelper._registry) + except: # noqa + pass + RayHelper._registry = None + + @staticmethod + @contextmanager + def patch_init(): + if RayHelper.ray_inited() and not RayHelper.is_default(): + from transformers import Trainer + init_method = Trainer.__init__ + + @functools.wraps(init_method) + def new_init(self, *args, **kwargs): + from transformers import Trainer, enable_full_determinism, set_seed + self: Trainer + self.processing_class = kwargs['processing_class'] + args = kwargs['args'] + self.args = args + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = SimpleNamespace(tp_size=1) + self.create_accelerator_and_postprocess() + self.args._setup_devices + + Trainer.__init__ = new_init + yield + if RayHelper.ray_inited() and not RayHelper.is_default(): + Trainer.__init__ = init_method + @staticmethod def is_called_from_init(): """If some function called from __init__. 
@@ -84,6 +173,12 @@ def ray_inited(): return False return ray.is_initialized() + @staticmethod + def is_default(): + ray_groups = os.environ.get('RAY_SWIFT_GROUP', '').split(',') + default_names = ['default', 'sft:default', 'rlhf:default', 'pt:default'] + return any(name in ray_groups for name in default_names) + @staticmethod def is_worker(): import ray @@ -121,6 +216,8 @@ def new_init(self, *args, **kwargs): @staticmethod def collect_func(method: Union[Literal['none', 'flatten'], Callable], result): + if not result: + return result if isinstance(result[0], tuple): output = [] for i in range(len(result[0])): @@ -143,7 +240,7 @@ def collect_func(method: Union[Literal['none', 'flatten'], Callable], result): @staticmethod def function(group: str, dispatch: Union[Literal['slice', 'all'], Callable] = 'all', - execute: Literal['first', 'all'] = 'all', + execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten'], Callable] = 'none'): """Remote execution function. @@ -168,17 +265,22 @@ def decorator(func: Callable[..., T]) -> Callable[..., T]: def wrapper(self, *args, **kwargs) -> T: if not RayHelper.ray_inited(): return func(self, *args, **kwargs) + RayHelper.init_registry() if RayHelper.is_worker(): if not hasattr(self, 'group'): # pass through env - self.group = os.environ['RAY_SWIFT_GROUP'].split(',') + groups = os.environ['RAY_SWIFT_GROUP'].split(',') + if 'sft:default' in groups or 'rlhf:default' in groups or 'pt:default' in groups: + groups.append('default') + self.group = groups if group not in self.group: if RayHelper.is_called_from_init(): # Functions in init of different group, do nothing return None else: - # Should not happen - raise ValueError() + result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, + **kwargs) + return RayHelper.collect_func(collect, result) else: return func(self, *args, **kwargs) else: @@ -197,14 +299,53 @@ def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs import ray return ray.get(RayHelper.execute_all_async(group, dispatch, execute, method_name, *args, **kwargs)) + @staticmethod + def get_workers(group, execute): + import ray + if group in RayHelper.worker_instance: + workers = RayHelper.worker_instance[group] + else: + workers = ray.get(RayHelper._registry.get_workers.remote(group)) + if execute == 'first': + return [workers[0]] + elif execute == 'all': + return workers + elif execute == 'peer': + return workers[RayHelper.get_peer_index(len(workers))] + else: + raise ValueError(f'Unsupported execute method: {execute}') + + @staticmethod + def get_peer_index(target_size): + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + + k, m = divmod(target_size, world_size) + start_idx = rank * k + min(rank, m) + end_idx = (rank + 1) * k + min(rank + 1, m) + if target_size < world_size: + start_idx = rank % target_size + end_idx = start_idx + 1 + + return slice(start_idx, end_idx) + @staticmethod def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): - workers = RayHelper.worker_instance[group] + workers = RayHelper.get_workers(group, execute) length = len(workers) + + def remote_func(worker, *args, **kwargs): + if hasattr(worker, method_name): + remote_call = getattr(worker, method_name) + return remote_call.remote(*args, **kwargs) + else: + remote_call = getattr(worker, 'call_inner_function') + return remote_call.remote(method_name, *args, **kwargs) + if execute == 'first': - return getattr(workers[0], 
method_name).remote(*args, **kwargs) + return remote_func(workers[0], *args, **kwargs) elif dispatch == 'all': - return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers] + return [remote_func(worker, *args, **kwargs) for worker in workers] elif dispatch == 'slice': result = [] @@ -221,17 +362,15 @@ def dispatch_func(arg, n): sliced_args = tuple(arg[i] for arg in args) sliced_kwargs = {k: v[i] for k, v in kwargs.items()} if (sliced_args and sliced_args[0]) or (kwargs and list(kwargs.values())): - # skip empty input - remote_call = getattr(workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + result.append(remote_func(workers[i], *sliced_args, **sliced_kwargs)) + return result elif isinstance(dispatch, Callable): # dispatch is Callable result = [] for i in range(length): sliced_args, sliced_kwargs = dispatch(length, i, *args, **kwargs) - remote_call = getattr(workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + result.append(remote_func(workers[i], *sliced_args, **sliced_kwargs)) return result else: raise ValueError(f'Invalid dispatch method: {dispatch}') @@ -265,7 +404,8 @@ def _create_workers(group: Union[str, List[str]], *args, **kwargs): _config = config break - assert _config is not None + if _config is None: + continue local_groups = _config['workers'] VISIBLE_ENV_MAPPING = { @@ -368,3 +508,4 @@ def get_node_address(): for g in local_groups: RayHelper.worker_instance[g] = workers + ray.get(RayHelper._registry.register_workers.remote(g, workers)) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 4565521411..e7926d7e15 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -41,6 +41,7 @@ from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template, get_llm_model from swift.llm.utils import update_generation_config_eos_token from swift.plugin import MeanMetric, compute_acc, extra_tuners, get_loss_func, get_metric +from swift.ray import RayHelper from swift.tuners import SwiftModel from swift.utils import get_current_device, get_logger, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker from ..llm.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward @@ -74,7 +75,7 @@ def __init__(self, logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.') self.compute_loss_func = None # Compatible with the older version of transformers - if args.check_model and hasattr(model, 'model_dir'): + if model is not None and args.check_model and hasattr(model, 'model_dir'): with ms_logger_context(logging.CRITICAL), self._patch_timeout(): check_local_model_is_latest( model.model_dir, user_agent={ @@ -99,13 +100,11 @@ def _get_mean_metric(): self.template = template self.hub = get_hub() - self.model_meta = model.model_meta - kwargs.update(self.create_loss_and_metric(args)) trainer_parameters = inspect.signature(Trainer.__init__).parameters tokenizer_key = 'processing_class' if 'processing_class' in trainer_parameters else 'tokenizer' kwargs[tokenizer_key] = template.tokenizer - with self.hub.patch_hub(): + with self.hub.patch_hub(), RayHelper.patch_init(): super().__init__( model=model, args=args, @@ -117,21 +116,27 @@ def _get_mean_metric(): optimizers=optimizers, **kwargs) - if get_function(model.__class__.forward) is not get_function(model.forward): - self.label_names = find_labels(model) - self.can_return_loss = can_return_loss(model) - self.label_names = self.label_names or 
['labels'] + self._prepare_model_info(model) + if not getattr(self, 'label_names', []): + self.label_names = ['labels'] self.start_time = time.time() self._fix_gradient_checkpointing() self._patch_tasks() - update_generation_config_eos_token(self.model.generation_config, self.template) - if getattr(self.model, 'origin_generation_config', None): - self.model.origin_generation_config.eos_token_id = self.model.generation_config.eos_token_id if self.args.resume_only_model and self.args.ignore_data_skip: # The weights have already been loaded outside the trainer, # so reading train_state is skipped here. self.args.resume_from_checkpoint = None + @RayHelper.function(group='default') + def _prepare_model_info(self, model): + self.model_meta = model.model_meta + if get_function(model.__class__.forward) is not get_function(model.forward): + self.label_names = find_labels(model) + self.can_return_loss = can_return_loss(model) + update_generation_config_eos_token(self.model.generation_config, self.template) + if getattr(self.model, 'origin_generation_config', None): + self.model.origin_generation_config.eos_token_id = self.model.generation_config.eos_token_id + @contextmanager def _patch_timeout(self): from modelscope.hub.api import HubApi @@ -178,11 +183,14 @@ def deepspeed_load_checkpoint(*args, **kwargs): finally: trainer.deepspeed_load_checkpoint = origin_deepspeed_load_checkpoint - def get_use_logits_to_keep(self, default_value: bool = True): + def get_use_logits_to_keep(self, default_value: bool = True, model=None): use_logits_to_keep = self.args.use_logits_to_keep + if model is None: + model = self.model + model = unwrap_model(model) if use_logits_to_keep is None: - base_model = self.template.get_base_model(self.model) - use_logits_to_keep = (not self.model.model_meta.is_multimodal + base_model = self.template.get_base_model(model) + use_logits_to_keep = (not model.model_meta.is_multimodal and 'logits_to_keep' in inspect.signature(base_model.forward).parameters and default_value) logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}') @@ -611,6 +619,7 @@ def clip_grad_norm_(self, parameters, *args, **kwargs): finally: Accelerator.clip_grad_norm_ = origin_clip_grad_norm_ + @RayHelper.function(group='default') def _patch_tasks(self): if isinstance(self.model, PeftModel): model = self.model.model diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index b84a7051eb..c178d6cb2e 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -1,18 +1,21 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import warnings +from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from accelerate.utils import gather_object -from peft import PeftModel +from torch import autocast from transformers import PreTrainedModel +from transformers.modeling_utils import unwrap_model from transformers.utils.versions import require_version from trl import DPOTrainer as HFDPOTrainer from trl.trainer.dpo_config import DPOConfig -from trl.trainer.utils import RunningMoments +from trl.trainer.utils import RunningMoments, pad_to_length from swift.llm import to_device +from swift.ray import RayHelper from swift.utils import get_logger from ..mixin import DataLoaderMixin, SwiftMixin from .rlhf_mixin import RLHFTrainerMixin @@ -27,6 +30,17 @@ def new_gather_function(tensor): return torch.concat(to_device(tensor_list, tensor.device), dim=0) +def dispatch_ref_batch(n, i, *args, **kwargs): + mini_batch = {} + batch = args[0] + for key in batch: + tensor = batch[key] + batch_size = tensor.shape[0] + k, m = divmod(batch_size, n) + mini_batch[key] = tensor[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] + return (mini_batch, ), {} + + class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): def __init__(self, @@ -60,13 +74,11 @@ def __init__(self, self.precompute_ref_log_probs = args.precompute_ref_log_probs self.f_divergence_type = args.f_divergence_type self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} - self.is_peft_model = isinstance(model, PeftModel) self.ref_adapter_name = args.ref_adapter_name self.model_adapter_name = None self.reference_free = args.reference_free self.use_weighting = False - super().__init__(model, ref_model, *_args, **kwargs) if 'bco_pair' in loss_types: @@ -84,10 +96,11 @@ def concatenated_forward( ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch = batch.copy() - use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1) + use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1, model) if use_logits_to_keep: self.prepare_logits_to_keep(batch) - if self.aux_loss_enabled: + aux_loss_enabled = unwrap_model(model).model_info.is_moe_model and self.args.router_aux_loss_coef > 0 + if aux_loss_enabled: batch['output_router_logits'] = True labels = batch.pop('labels', None) if self.is_encoder_decoder: @@ -172,7 +185,7 @@ def concatenated_forward( output['rejected_logps'] = all_logps[num_examples:] output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean() output['mean_rejected_logits'] = mean_all_logits[num_examples:][loss_mask[num_examples:]].mean() - if self.aux_loss_enabled: + if aux_loss_enabled: output['aux_loss'] = outputs.aux_loss return output @@ -183,3 +196,72 @@ def training_step(self, model, inputs, *args, **kwargs): def prediction_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): return super().prediction_step(model, inputs, *args, **kwargs) + + @RayHelper.function(group='default') + def _compute_log_probs(self, batch): + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) + with torch.no_grad(), compte_ref_context_manager, self.null_ref_context(): + ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) + return 
ref_model_output['chosen_logps'], ref_model_output['rejected_logps'] + + @RayHelper.function(group='ref', execute='peer', dispatch=dispatch_ref_batch, collect='flatten') + def _compute_ref_log_probs(self, batch): + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) + with torch.no_grad(), compte_ref_context_manager: + ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) + return ref_model_output['chosen_logps'], ref_model_output['rejected_logps'] + + def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: + if self.args.ref_model is None: + return self._compute_log_probs(batch) + else: + return self._compute_ref_log_probs(batch) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + + # if ref_output in batch use that otherwise use the reference model + if 'ref_output' in batch: + ref_output = batch['ref_output'] + else: + if self.args.ref_model is None: + with self.null_ref_context(): + ref_output = model.generate( + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + else: + ref_output = self.generate_from_ref(batch) + + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) + + return policy_output_decoded, ref_output_decoded + + @RayHelper.function(group='ref', execute='peer', dispatch=dispatch_ref_batch, collect='flatten') + def generate_from_ref(self, batch): + return self.ref_model.generate( + input_ids=batch['prompt_input_ids'], + attention_mask=batch['prompt_attention_mask'], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 428799da62..7f22db412d 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -7,11 +7,14 @@ import torch import torch.nn as nn +from peft import PeftModel from torch.utils.data import DataLoader from transformers import PreTrainedModel from trl.models.utils import prepare_deepspeed from trl.trainer.utils import selective_log_softmax +from swift.ray import RayHelper + class RLHFTrainerMixin: @@ -20,37 +23,48 @@ def __init__(self, ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, *_args, **kwargs): - from trl.trainer import disable_dropout_in_model - from swift.llm import HfConfigFactory self.ref_model = ref_model self._stored_metrics = defaultdict(lambda: defaultdict(list)) args = kwargs['args'] self.beta = getattr(args, 'beta', 0.0) - if getattr(args, 'disable_dropout', False): - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - 
self.is_encoder_decoder = kwargs['template'].is_encoder_decoder self._peft_has_been_casted_to_bf16 = False self.generate_during_eval = getattr(args, 'generate_during_eval', False) - if self.is_encoder_decoder: - self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') - self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + + self._prepare_model(args, model) # not use self.is_vision_model = False self.label_pad_token_id = -100 self.use_dpo_data_collator = True + self.aux_loss_coef = args.router_aux_loss_coef super().__init__(model, *_args, **kwargs) + self._prepare_ref_model(args, ref_model) + self.padding_value = self.tokenizer.pad_token_id + + @RayHelper.function(group='default') + def _prepare_model(self, args, model): + from trl.trainer import disable_dropout_in_model + from swift.llm import HfConfigFactory + if getattr(args, 'disable_dropout', False): + if model is not None: + disable_dropout_in_model(model) + self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0 - self.aux_loss_coef = args.router_aux_loss_coef + self.is_peft_model = isinstance(model, PeftModel) + if self.is_encoder_decoder: + self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') + self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + + @RayHelper.function(group='ref') + def _prepare_ref_model(self, args, ref_model): + from trl.trainer import disable_dropout_in_model if ref_model is not None: + if getattr(args, 'disable_dropout', False): + disable_dropout_in_model(ref_model) if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + self.ref_model = prepare_deepspeed(ref_model, self.accelerator) else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - self.padding_value = self.tokenizer.pad_token_id + self.ref_model = self.accelerator.prepare_model(ref_model, evaluation_mode=True) def create_loss_and_metric(self, args): return {}
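The `ref` dispatch path in `dpo_trainer.py` above (`execute='peer'` combined with `dispatch_ref_batch`) relies on the same even-split arithmetic in two places: `dispatch_ref_batch` cuts a batch into `n` near-equal contiguous mini-batches, one per `ref` worker, while `RayHelper.get_peer_index` maps each policy rank onto its share of those workers and wraps around when there are fewer `ref` workers than policy ranks. A standalone sketch of that arithmetic (the helper name `split_range` is invented for this illustration and is not part of SWIFT):

```python
# Standalone illustration of the divmod-based splitting used by
# dispatch_ref_batch and RayHelper.get_peer_index in this diff; not SWIFT source.

def split_range(total: int, n: int, i: int) -> slice:
    """Return the i-th of n near-equal contiguous slices covering range(total)."""
    k, m = divmod(total, n)
    return slice(i * k + min(i, m), (i + 1) * k + min(i + 1, m))

# dispatch_ref_batch: a batch of 10 samples spread over 3 ref workers -> 4/3/3
print([split_range(10, 3, i) for i in range(3)])
# [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]

# get_peer_index: 4 policy ranks (RANK/WORLD_SIZE) sharing 2 ref workers
world_size, num_ref_workers = 4, 2
for rank in range(world_size):
    peers = split_range(num_ref_workers, world_size, rank)
    if num_ref_workers < world_size:
        # fewer ref workers than policy ranks: wrap around so every rank gets a peer
        peers = slice(rank % num_ref_workers, rank % num_ref_workers + 1)
    print(f'rank {rank} -> ref worker slice {peers}')
```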