19 changes: 10 additions & 9 deletions docs/source/Instruction/ray的支持.md
@@ -2,15 +2,16 @@

SWIFT already supports using Ray for multi-GPU or multi-node training. The support status for Ray in existing features is as follows:

| Feature | Ray Support | Example | Assignable Roles |
|----------|-------|--------------------------------------------------------------------------------|-----------------|
| pt/sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default |
| dpo | ❎ | | |
| grpo | ❎ | | |
| ppo | ❎ | | |
| megatron | ❎ | | |
| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm |
| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm |
| Feature | Ray Support | Example | Assignable Roles |
|----------|-------|---------------------------------------------------------------------------------------|------------------|
| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh | pt:default |
| sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default |
| dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref |
| grpo | ❎ | | |
| ppo | ❎ | | |
| megatron | ❎ | | |
| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm |
| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm |

## Technical Details

23 changes: 12 additions & 11 deletions docs/source_en/Instruction/Ray.md
@@ -2,17 +2,18 @@

SWIFT already supports using Ray for multi-GPU or multi-node training. The support status for Ray in existing features is as follows:

| Feature | Ray Support | Example | Assignable Roles |
|----------|-------------|--------------------------------------------------------------------------------|------------------|
| pt/sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray | default |
| dpo | ❎ | | |
| grpo | ❎ | | |
| ppo | ❎ | | |
| megatron | ❎ | | |
| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm |
| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm |

## Technical Details
| Feature | Ray Support | Example | Assignable Roles |
|----------|-------------|---------------------------------------------------------------------------------------|------------------|
| pt | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/pt.sh | pt:default |
| sft | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/sft.sh | sft:default |
| dpo | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/ray/dpo.sh | rlhf:default/ref |
| grpo | ❎ | | |
| ppo | ❎ | | |
| megatron | ❎ | | |
| sampling | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/distill | sampler/prm/orm |
| distill | ✅ | https://github.com/modelscope/ms-swift/tree/main/examples/sampler/sample | sampler/prm/orm |

## Technical Details

Before describing the parameter settings, some technical details need to be explained. Because SWIFT internally reuses many existing implementations from transformers and trl, decomposing training into separate Ray roles, as veRL or ROLL do, is impractical; such a decomposition would also center the codebase around Ray and leave non-Ray scenarios poorly supported.
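
To make the alternative concrete: the approach taken in this PR keeps the training classes as ordinary Python classes and attaches Ray placement as metadata via the `RayHelper.worker` / `RayHelper.function` decorators (see the code changes below). The following is only a toy sketch of that tagging pattern, not SWIFT's actual `RayHelper` implementation, whose internals are outside this diff:

```python
from typing import Callable, List


# Toy stand-ins for RayHelper.worker / RayHelper.function: they only record which
# device group should host a class or execute a method. The real RayHelper
# presumably also handles dispatch to Ray workers; that part is not shown here.
def worker(group: List[str]) -> Callable:
    def deco(cls):
        cls._ray_groups = group  # groups whose workers instantiate this class
        return cls
    return deco


def function(group: str) -> Callable:
    def deco(fn):
        fn._ray_group = group  # group whose workers execute this method
        return fn
    return deco


@worker(group=['rlhf:default', 'ref'])
class RLHFJob:

    @function(group='ref')
    def prepare_ref_model(self):
        return 'reference model prepared'

    @function(group='rlhf:default')
    def run(self):
        return 'training step'


print(RLHFJob._ray_groups)                   # ['rlhf:default', 'ref']
print(RLHFJob.prepare_ref_model._ray_group)  # ref
```

In the toy version the decorators are pure metadata, so the class keeps working without any Ray cluster; the real code presumably degrades similarly when `use_ray` is not set, which is the point of this design.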

1 change: 1 addition & 0 deletions examples/train/multi-node/ray/dpo.sh
@@ -0,0 +1 @@
swift rlhf --config dpo.yaml
39 changes: 39 additions & 0 deletions examples/train/multi-node/ray/dpo.yaml
@@ -0,0 +1,39 @@
rlhf_type: dpo
model: Qwen/Qwen2.5-VL-3B-Instruct
ref_model: Qwen/Qwen2.5-VL-3B-Instruct
train_type: full
dataset: swift/RLAIF-V-Dataset#1000
load_from_cache_file: true
split_dataset_ratio: 0.01
torch_dtype: bfloat16
num_train_epochs: 1
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
learning_rate: 1e-4
gradient_accumulation_steps: 2
eval_steps: 1
save_steps: 1
save_total_limit: 2
logging_steps: 5
max_length: 2048
output_dir: output
warmup_ratio: 0.05
dataloader_num_workers: 4
rpo_alpha: 0.1
dataset_num_proc: 4

use_ray: true

# The number of ranks assigned to rlhf:default and ref must be equal
device_groups:
  nproc_per_node: 4
  sample_group:
    device: GPU
    ranks: list(range(0, 2))
    workers:
      - rlhf:default
  rm_group:
    device: GPU
    ranks: list(range(2, 4))
    workers:
      - ref
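
For reference, a rough reading of the `device_groups` block above (an illustration of the intended placement, not SWIFT's parsing code; in particular, whether the `ranks` strings are evaluated exactly like this is an assumption):

```python
# Hypothetical interpretation of the dpo.yaml device_groups section above.
device_groups = {
    'nproc_per_node': 4,
    'sample_group': {'device': 'GPU', 'ranks': 'list(range(0, 2))', 'workers': ['rlhf:default']},
    'rm_group': {'device': 'GPU', 'ranks': 'list(range(2, 4))', 'workers': ['ref']},
}

for name, spec in device_groups.items():
    if not isinstance(spec, dict):
        continue  # skip scalar settings such as nproc_per_node
    ranks = eval(spec['ranks'])  # the YAML stores a Python expression as a string
    print(f"{name}: {spec['workers']} -> {spec['device']} ranks {ranks}")

# sample_group: ['rlhf:default'] -> GPU ranks [0, 1]   (the trainable DPO policy)
# rm_group: ['ref'] -> GPU ranks [2, 3]                (the frozen reference model)
```

With `nproc_per_node: 4`, the policy and the reference model each occupy two of the four GPUs on the node, which also satisfies the equal-rank-count comment in the config.
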
1 change: 1 addition & 0 deletions examples/train/multi-node/ray/pt.sh
@@ -0,0 +1 @@
swift pt --config pt.yaml
33 changes: 33 additions & 0 deletions examples/train/multi-node/ray/pt.yaml
@@ -0,0 +1,33 @@
model: Qwen/Qwen2.5-7B
train_type: full
dataset: swift/chinese-c4#2000
torch_dtype: bfloat16
streaming: true
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
learning_rate: 1e-5
gradient_accumulation_steps: 2
packing: true
eval_steps: 500
save_steps: 500
save_total_limit: 2
logging_steps: 5
deepspeed: zero3
max_length: 8192
max_steps: 10000
warmup_ratio: 0.05
dataloader_num_workers: 4
dataset_num_proc: 8
save_only_model: true
output_dir: output/Qwen2.5-7B
attn_impl: flash_attn

use_ray: true

device_groups:
  nproc_per_node: 4
  default:
    device: GPU
    ranks: list(range(0, 4))
    workers:
      - pt:default
2 changes: 1 addition & 1 deletion examples/train/multi-node/ray/sft.yaml
@@ -33,4 +33,4 @@ device_groups:
    device: GPU
    ranks: list(range(0, 4))
    workers:
      - default
      - sft:default
4 changes: 3 additions & 1 deletion swift/cli/pt.py
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import pt_main

if __name__ == '__main__':
    from swift.ray import try_init_ray
    try_init_ray()
    from swift.llm import pt_main
    pt_main()
4 changes: 3 additions & 1 deletion swift/cli/rlhf.py
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import rlhf_main

if __name__ == '__main__':
    from swift.ray import try_init_ray
    try_init_ray()
    from swift.llm import rlhf_main
    rlhf_main()
2 changes: 2 additions & 0 deletions swift/llm/train/pt.py
@@ -1,13 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Optional, Union

from swift.ray import RayHelper
from swift.utils import get_logger
from ..argument import TrainArguments
from .sft import SwiftSft

logger = get_logger()


@RayHelper.worker(group=['pt:default'])
class SwiftPt(SwiftSft):
    args_class = TrainArguments
    args: args_class
115 changes: 91 additions & 24 deletions swift/llm/train/rlhf.py
@@ -5,6 +5,8 @@

from swift.llm import safe_snapshot_download
from swift.plugin import Tuner, extra_tuners
from swift.ray import RayHelper
from swift.trainers import TrainerFactory
from swift.tuners import Swift
from swift.utils import get_logger, get_model_parameter_info
from swift.utils.utils import disable_deepspeed_zero3
@@ -16,10 +18,20 @@
logger = get_logger()


@RayHelper.worker(group=['rlhf:default', 'ref', 'reward', 'value', 'teacher'])
class SwiftRLHF(SwiftSft):
    args_class = RLHFArguments
    args: args_class

    def __init__(self, args: RLHFArguments):
        self.model = None
        self.callbacks = []
        super().__init__(args)
        self.reward_model = []
        if self.args.rlhf_type == 'grpo':
            self.reward_template = []
Comment on lines +29 to +32
critical

The initialization of self.reward_model and self.reward_template happens after super().__init__(args) is called. The super().__init__ call chain eventually triggers remote calls to methods like _prepare_reward_model, which attempt to use these attributes on the worker instances. Since they are not yet initialized on the workers when these remote methods are called, this will lead to an AttributeError.

To fix this, these attributes should be initialized before the super().__init__(args) call. Also, note that self.args is not available before super().__init__, so you should use the args from the __init__ method signature for the conditional initialization of self.reward_template.

Suggested change
        # original lines
        super().__init__(args)
        self.reward_model = []
        if self.args.rlhf_type == 'grpo':
            self.reward_template = []
        # suggested replacement
        self.reward_model = []
        if args.rlhf_type == 'grpo':
            self.reward_template = []
        super().__init__(args)
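
The ordering problem is plain Python and easy to reproduce without Ray or SWIFT (hypothetical class names, for illustration only):

```python
# Minimal repro of the reviewer's point: the base __init__ calls a method that
# touches an attribute the subclass only assigns after super().__init__ returns.
class Base:
    def __init__(self):
        self.prepare()  # runs during construction, before the subclass finishes its setup


class Child(Base):
    def __init__(self):
        super().__init__()       # triggers prepare() first ...
        self.reward_model = []   # ... so this assignment comes too late

    def prepare(self):
        self.reward_model.append('rm')


Child()  # AttributeError: 'Child' object has no attribute 'reward_model'
```

Moving the attribute assignments above the `super().__init__(args)` call, as suggested, removes the window in which the attribute does not yet exist.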

        self._prepare_trainer()

    @staticmethod
    def _get_model_task_type(model_dir):
        task_type = None
@@ -48,6 +60,47 @@ def _get_model_task_type(model_dir):
task_type = 'seq_cls'
return task_type, num_labels

    @RayHelper.function(group='ref')
    def _prepare_ref_model(self, key, origin_key, model_type, model_revision):
        result = self._prepare_single_model(key, origin_key, model_type, model_revision)

        if result is not None:
            self.ref_model = result[0]

    @RayHelper.function(group='value')
    def _prepare_value_model(self, key, origin_key, model_type, model_revision):
        result = self._prepare_single_model(key, origin_key, model_type, model_revision)

        if result is not None:
            self.value_model = result[0]

    @RayHelper.function(group='teacher')
    def _prepare_teacher_model(self, key, origin_key, model_type, model_revision):
        result = self._prepare_single_model(key, origin_key, model_type, model_revision)

        if result is not None:
            self.teacher_model = result[0]

    def _prepare_reward_model(self, reward_model_path, key, origin_key, model_type, model_revision):
        rms = self.args.reward_model if isinstance(self.args.reward_model, list) else [self.args.reward_model]
        self.args.reward_model = reward_model_path  # Temporarily set for prepare_single_model
        result = self._prepare_single_model(key, origin_key, model_type, model_revision)

        if result is not None:
            model, processor = result
            self.reward_model.append(model)

            if self.args.rlhf_type == 'grpo':
                reward_template = self.args.get_template(processor, processor.model_meta.template)
                if reward_template.use_model:
                    reward_template.model = model
                self.reward_template.append(reward_template)
        self.args.reward_model = rms  # Restore original value

        if self.args.rlhf_type != 'grpo' and self.reward_model:
            assert len(self.reward_model) <= 1
            self.reward_model = self.reward_model[0]

def _prepare_single_model(self, key, origin_key, model_type, model_revision):
from swift.llm.infer.utils import prepare_adapter
args = self.args
@@ -116,10 +169,7 @@ def _prepare_model_tokenizer(self):
model_type = model_type[0] if model_type else None
model_revision = model_revision[0] if model_revision else None

result = self._prepare_single_model(model_key, key, model_type, model_revision)
if result is not None:
model, _ = result
setattr(self, f'{key}_model', model)
getattr(self, f'_prepare_{key}_model')(model_key, key, model_type, model_revision)
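# The call above dispatches to the role-specific _prepare_{ref,value,teacher}_model methods defined earlier; each is a RayHelper.function, intended to run on that group's workers.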

# Handle reward model(s)
self.reward_model = None
@@ -130,26 +180,9 @@ def _prepare_model_tokenizer(self):
rm_revisions = args.reward_model_revision if args.reward_model_revision else [None] * num_rms
assert len(rms) == len(rm_types) == len(rm_revisions)

self.reward_model = []
if args.rlhf_type == 'grpo':
self.reward_template = []

for reward_model_path, rm_type, rm_revision in zip(rms, rm_types, rm_revisions):
args.reward_model = reward_model_path # Temporarily set for prepare_single_model
result = self._prepare_single_model('reward', None, rm_type, rm_revision)
if result is not None:
model, processor = result
self.reward_model.append(model)

if args.rlhf_type == 'grpo':
reward_template = self.args.get_template(processor, processor.model_meta.template)
if reward_template.use_model:
reward_template.model = model
self.reward_template.append(reward_template)
args.reward_model = rms # Restore original value
if args.rlhf_type != 'grpo' and self.reward_model:
assert len(self.reward_model) <= 1
self.reward_model = self.reward_model[0]
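# Each reward model is now prepared through a RayHelper.function bound to its own group name (reward_0, reward_1, ...).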
for rm_idx, (reward_model_path, rm_type, rm_revision) in enumerate(zip(rms, rm_types, rm_revisions)):
_prepare_reward_model = RayHelper.function(group=f'reward_{rm_idx}')(self._prepare_reward_model)
_prepare_reward_model(reward_model_path, 'reward', None, rm_type, rm_revision)

super()._prepare_model_tokenizer()

@@ -222,6 +255,40 @@ def _get_trainer_kwargs(self):
trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
return trainer_kwargs

    @RayHelper.function(group='default')
    def _add_adapter_to_model(self, train_dataset):
        # Some tuners require train_dataset and data_collator for preparation: LoRA-GA
        self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
        logger.info(f'model: {self.model}')
        model_parameter_info = get_model_parameter_info(self.model)
        self.train_msg['model_parameter_info'] = model_parameter_info
        logger.info(f'model_parameter_info: {model_parameter_info}')

    def _prepare_trainer(self):
        args = self.args
        train_dataset, val_dataset = self._prepare_dataset()
        args.save_args()

        data_collator = self._get_data_collator()
        self._add_adapter_to_model(train_dataset)

        trainer_cls = TrainerFactory.get_trainer_cls(args)
        self.args.training_args.ref_model = self.args.ref_model
        self.trainer = trainer_cls(
            model=self.model,
            args=self.args.training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=self.callbacks,
            template=self.template,
            **self._get_trainer_kwargs(),
        )

    @RayHelper.function(group='default')
    def run(self):
        return self.train(self.trainer)


def rlhf_main(args: Optional[Union[List[str], RLHFArguments]] = None):
    return SwiftRLHF(args).main()