45 changes: 24 additions & 21 deletions paddlenlp/trainer/trainer.py
@@ -130,6 +130,7 @@
PADDLE_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_NAME,
PREFIX_CHECKPOINT_DIR,
PREFIX_HF_CHECKPOINT_DIR,
PREFIX_WEIGHTS_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME,
@@ -3053,28 +3054,30 @@ def _sorted_checkpoints(
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
return
for checkpoint_prefix in [PREFIX_CHECKPOINT_DIR, PREFIX_HF_CHECKPOINT_DIR]:
# Check if we should delete older checkpoint(s)
checkpoints_sorted = self._sorted_checkpoints(
use_mtime=use_mtime, checkpoint_prefix=checkpoint_prefix, output_dir=output_dir
)
if len(checkpoints_sorted) <= self.args.save_total_limit:
return

# Check if we should delete older checkpoint(s)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
if len(checkpoints_sorted) <= self.args.save_total_limit:
return

# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
self.state.best_model_checkpoint is not None
and self.args.save_total_limit == 1
and checkpoints_sorted[-1] != self.state.best_model_checkpoint
):
save_total_limit = 2

number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
# ignore_errors for shared disks between train nodes.
shutil.rmtree(checkpoint, ignore_errors=True)
# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
self.state.best_model_checkpoint is not None
and self.args.save_total_limit == 1
and checkpoints_sorted[-1] != self.state.best_model_checkpoint
):
save_total_limit = 2

number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
# ignore_errors for shared disks between train nodes.
shutil.rmtree(checkpoint, ignore_errors=True)

def _save(
self,
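A minimal sketch (not part of this diff) of the per-prefix rotation the hunk above introduces: Paddle checkpoints (`checkpoint-*`) and Hugging Face checkpoints (`hf_checkpoint-*`) are now pruned independently, each keeping at most `save_total_limit` directories. The helper name `rotate` below is illustrative, and the best-model carve-out is omitted.

import re
import shutil
from pathlib import Path

def rotate(output_dir: str, prefix: str, save_total_limit: int) -> None:
    # Collect "<prefix>-<step>" directories and sort them by step number.
    pattern = re.compile(rf"^{re.escape(prefix)}-(\d+)$")
    numbered = []
    for path in Path(output_dir).glob(f"{prefix}-*"):
        match = pattern.match(path.name)
        if match is not None:
            numbered.append((int(match.group(1)), path))
    numbered.sort()
    # Delete everything except the newest `save_total_limit` directories.
    for _, path in numbered[: max(0, len(numbered) - save_total_limit)]:
        shutil.rmtree(path, ignore_errors=True)

# Roughly what _rotate_checkpoints now does once per prefix:
rotate("output", "checkpoint", save_total_limit=3)
rotate("output", "hf_checkpoint", save_total_limit=3)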
21 changes: 21 additions & 0 deletions paddlenlp/trainer/trainer_callback.py
@@ -157,6 +157,7 @@ class TrainerControl:
should_save: bool = False
should_evaluate: bool = False
should_log: bool = False
should_save_hf: bool = False

def _new_training(self):
"""Internal method that resets the variable for a new training."""
@@ -171,6 +172,7 @@ def _new_step(self):
self.should_save = False
self.should_evaluate = False
self.should_log = False
self.should_save_hf = False


class TrainerCallback:
@@ -306,6 +308,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
"""
pass

def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a huggingface checkpoint save.
"""
pass


class CallbackHandler(TrainerCallback):
"""Internal class that just calls the list of callbacks in order."""
@@ -386,6 +394,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
control.should_log = False
control.should_evaluate = False
control.should_save = False
control.should_save_hf = False
return self.call_event("on_step_begin", args, state, control)

def on_load_data_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, inputs: Dict):
@@ -418,6 +427,10 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_prediction_step", args, state, control)

def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_save_hf = False
return self.call_event("on_save_hf", args, state, control)

def call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(
@@ -474,6 +487,14 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
if state.global_step >= state.max_steps:
control.should_training_stop = True

# Save hf
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_hf_steps > 0
and state.global_step % args.save_hf_steps == 0
):
control.should_save_hf = True

return control

def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
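For context, a hedged sketch of how a user-defined callback could hook the new event. `TrainerCallback`, `on_save_hf`, and `should_save_hf` come from the diff above; `HFSaveLogger` is an invented name.

from paddlenlp.trainer.trainer_callback import TrainerCallback

class HFSaveLogger(TrainerCallback):
    def on_save_hf(self, args, state, control, **kwargs):
        # Fired after a Hugging Face-format checkpoint save; the handler has
        # already reset `control.should_save_hf` to False before this call.
        print(f"HF checkpoint written at global step {state.global_step}")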
204 changes: 203 additions & 1 deletion paddlenlp/trainer/trainer_utils.py
@@ -26,16 +26,18 @@
import math
import os
import random
import re
import threading
import time
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import paddle
import paddle.distributed as dist
from paddle import Tensor
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
@@ -44,6 +46,8 @@
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.io import IterableDataset
from paddle.optimizer.lr import LambdaDecay
from safetensors import safe_open
from safetensors.paddle import save_file

from paddlenlp.ops import Topology

@@ -1445,3 +1449,201 @@ def buffer_params():
continue
param_list.append(param)
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)


def _parse_size(size_str: str) -> int:
"""Parses a size string like '100MB', '2GB' into the number of bytes."""
size_str = size_str.upper().strip()
match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str)
if not match:
raise ValueError(f"Could not parse size string: '{size_str}'")

num_str, unit = match.groups()
num = float(num_str)

if unit == "B" or unit is None:
return int(num)
elif unit == "KB":
return int(num * 1024)
elif unit == "MB":
return int(num * 1024**2)
elif unit == "GB":
return int(num * 1024**3)
elif unit == "TB":
return int(num * 1024**4)
else:
# This case should not be reached due to regex
raise ValueError(f"Unknown unit: '{unit}'")


def save_full_param(
itr: Iterator[tuple[str, Tensor]],
save_dir: str,
rank: int,
moe_sharding_world_size: int,
max_shard_size: str = "2GB",
num_saver_ranks: int = 8,
) -> None:
"""
Saves model weights from an iterator into shards, supporting max shard size
and a limited number of saver ranks.

Only ranks less than `num_saver_ranks` will perform disk I/O. All other ranks
will iterate through the data to maintain synchronization but will not save.
The parameter distribution logic is based on `num_saver_ranks`, ensuring all
parameters are handled by a designated saver rank.

Args:
itr (Iterator): An iterator that yields (param_key, param_tensor).
save_dir (str): The directory where shard files will be saved.
rank (int): The rank of the current process.
moe_sharding_world_size (int): The total number of processes.
max_shard_size (str): The maximum size for each shard file, e.g., "500MB", "2GB".
num_saver_ranks (int): The number of ranks (starting from 0) that will save files.
"""

# 1. Non-saver ranks simply consume the iterator to stay in sync.
if rank >= num_saver_ranks:
logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...")
for _ in itr:
pass
logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.")
return

max_shard_size_bytes = _parse_size(max_shard_size)
logger.info(
f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. "
f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB"
)

os.makedirs(save_dir, exist_ok=True)

current_shard_state_dict = {}
current_shard_size_bytes = 0
sub_shard_index = 0

def _save_current_shard():
nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes
if not current_shard_state_dict:
return

# Filename includes the main shard number (rank) and the sub-shard index
cur_rank = paddle.distributed.get_rank()
shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors"
save_path = os.path.join(save_dir, shard_filename)

logger.info(
f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... "
f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, "
f"Params: {len(current_shard_state_dict)}, "
f"Path: {save_path}"
)

save_file(current_shard_state_dict, save_path)

# Reset for the next shard
sub_shard_index += 1
current_shard_state_dict = {}
current_shard_size_bytes = 0

logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...")

total_size = 0

for i, (param_key, param) in enumerate(itr):
param_size_bytes = param.numel() * param.element_size()
total_size += param_size_bytes.item()
if i % num_saver_ranks == rank:
if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes):
_save_current_shard()

current_shard_state_dict[param_key] = param
current_shard_size_bytes += param_size_bytes

if current_shard_size_bytes >= max_shard_size_bytes:
_save_current_shard()
_save_current_shard()
logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.")
return total_size


def replace_name_and_gen_index(path, cur_rank_total_size):
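    # On each node, local rank 0 renames that node's per-rank "shard_*.safetensors"
    # files to the standard "model-XXXXX-of-YYYYY.safetensors" layout and writes the
    # aggregated "model.safetensors.index.json" weight map.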
index_mapping = {}
cur_rank = paddle.distributed.get_rank()
safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")]
files_num = len(safetensor_files)
all_files_num = []
paddle.distributed.all_gather_object(all_files_num, files_num)
total_files_num = sum(all_files_num)

all_sizes = []
paddle.distributed.all_gather_object(all_sizes, cur_rank_total_size)
total_size = sum(all_sizes)

start_idx = []
acc = 1
for files_num in all_files_num:
start_idx.append(acc)
acc += files_num

env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
assert env_local_rank >= 0

cur_file_index = start_idx[cur_rank] // env_local_size
total_files_num = total_files_num // env_local_size

total_size = total_size // env_local_size

index_mapping = {}
if env_local_rank == 0:
for file in safetensor_files:
cur_file_index += 1
file_path = os.path.join(path, file)
new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
with safe_open(file_path, framework="np") as f:
for key in f.keys():
index_mapping[key] = new_file_name
new_file_path = os.path.join(path, new_file_name)
os.rename(file_path, new_file_path)

index_mapping_list = []
paddle.distributed.all_gather_object(index_mapping_list, index_mapping)
index_mapping = {}
for mapping in index_mapping_list:
index_mapping.update(mapping)

if env_local_rank == 0:
index_file_name = "model.safetensors.index.json"
index_infos = {}
index_infos["metadata"] = {}
index_infos["metadata"]["total_size"] = total_size
index_infos["weight_map"] = dict(sorted(index_mapping.items()))
with open(os.path.join(path, index_file_name), "w") as f:
json.dump(index_infos, f, indent=4)


def save_hf_checkpoint(
model,
aoa_config,
h_group,
v_group,
num_splits,
shard_idx,
path,
):
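    # Stream gathered full parameters from the distributed model, write them as
    # size-capped safetensors shards, synchronize all ranks, then rename the shards
    # and emit the weight-map index file.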
itr = model.full(
aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx
)
num_saver_ranks = h_group.nranks * v_group.nranks
rank = h_group.rank + v_group.rank * h_group.nranks
total_saved_size = save_full_param(
itr=itr,
save_dir=path,
rank=rank,
moe_sharding_world_size=num_saver_ranks,
max_shard_size="16GB",
num_saver_ranks=num_saver_ranks,
)
paddle.distributed.barrier()
replace_name_and_gen_index(path, total_saved_size)
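To make the sharding policy above concrete, here is an illustrative, Paddle-free sketch of the distribution rule used in `save_full_param`: parameter `i` is handled by saver rank `i % num_saver_ranks`, and a rank rolls over to a new shard once the configured `max_shard_size` would be exceeded. `plan_shards` and the sample sizes below are invented for the example.

def plan_shards(param_sizes, rank, num_saver_ranks, max_shard_size_bytes):
    shards, current, current_bytes = [], [], 0
    for i, (key, nbytes) in enumerate(param_sizes):
        if i % num_saver_ranks != rank:
            continue  # parameter is saved by another rank
        if current and current_bytes + nbytes > max_shard_size_bytes:
            shards.append(current)  # start a new shard file
            current, current_bytes = [], 0
        current.append(key)
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards

# Rank 0 of 2 saver ranks with a 1 GB cap owns parameters 0, 2 and 4.
sizes = [(f"layers.{i}.weight", 600 * 1024**2) for i in range(6)]
print(plan_shards(sizes, rank=0, num_saver_ranks=2, max_shard_size_bytes=1024**3))
# [['layers.0.weight'], ['layers.2.weight'], ['layers.4.weight']]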
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -413,6 +413,8 @@ class TrainingArguments:
Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
save_hf_steps (`int`, *optional*, defaults to -1):
Number of update steps between two Hugging Face checkpoint saves if `save_strategy="steps"`. Non-positive values disable these saves.
"""

output_dir: str = field(
@@ -1142,6 +1144,8 @@ class TrainingArguments:
},
)

save_hf_steps: int = field(default=-1, metadata={"help": "Save a Hugging Face checkpoint every X update steps."})

def __post_init__(self):
world_size = paddle.distributed.get_world_size()
if in_auto_parallel_align_mode():
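A hedged usage example of the new field; `save_hf_steps` is from this diff, while the surrounding values are placeholders.

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./output",
    save_strategy="steps",
    save_steps=1000,      # regular checkpoints every 1000 steps
    save_hf_steps=2000,   # additionally save a Hugging Face checkpoint every 2000 steps
)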
3 changes: 3 additions & 0 deletions paddlenlp/utils/env.py
@@ -166,3 +166,6 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:

USE_FAST_TOKENIZER: bool = _get_bool_env("USE_FAST_TOKENIZER", "false")
PREFILL_USE_SAGE_ATTN: bool = _get_bool_env("PREFILL_USE_SAGE_ATTN", "false")

# hf checkpoint dir name
PREFIX_HF_CHECKPOINT_DIR = "hf_checkpoint"