diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 470b4b12039f..63547972124d 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -130,6 +130,7 @@
     PADDLE_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_NAME,
     PREFIX_CHECKPOINT_DIR,
+    PREFIX_HF_CHECKPOINT_DIR,
     PREFIX_WEIGHTS_NAME,
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
     SAFE_PEFT_WEIGHTS_INDEX_NAME,
@@ -3053,28 +3054,30 @@ def _sorted_checkpoints(
     def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
             return
 
+        for checkpoint_prefix in [PREFIX_CHECKPOINT_DIR, PREFIX_HF_CHECKPOINT_DIR]:
+            # Check if we should delete older checkpoint(s)
+            checkpoints_sorted = self._sorted_checkpoints(
+                use_mtime=use_mtime, checkpoint_prefix=checkpoint_prefix, output_dir=output_dir
+            )
+            if len(checkpoints_sorted) <= self.args.save_total_limit:
+                continue
-        # Check if we should delete older checkpoint(s)
-        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
-        if len(checkpoints_sorted) <= self.args.save_total_limit:
-            return
-
-        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
-        # we don't do to allow resuming.
-        save_total_limit = self.args.save_total_limit
-        if (
-            self.state.best_model_checkpoint is not None
-            and self.args.save_total_limit == 1
-            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
-        ):
-            save_total_limit = 2
-
-        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
-        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
-        for checkpoint in checkpoints_to_be_deleted:
-            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
-            # ignore_errors for shared disks between train nodes.
-            shutil.rmtree(checkpoint, ignore_errors=True)
+            # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+            # we don't do to allow resuming.
+            save_total_limit = self.args.save_total_limit
+            if (
+                self.state.best_model_checkpoint is not None
+                and self.args.save_total_limit == 1
+                and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+            ):
+                save_total_limit = 2
+
+            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+            checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+            for checkpoint in checkpoints_to_be_deleted:
+                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                # ignore_errors for shared disks between train nodes.
+                shutil.rmtree(checkpoint, ignore_errors=True)
 
     def _save(
         self,
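Note: `_rotate_checkpoints` now applies the same `save_total_limit` independently to the regular `checkpoint-*` directories and the new `hf_checkpoint-*` directories (`continue` rather than `return` inside the loop, so the second prefix is still rotated when the first one is under the limit). For reference, a minimal standalone sketch of this per-prefix policy; `rotate` is illustrative only and not part of the Trainer API, which goes through `_sorted_checkpoints` and the best-model guard above.

# Standalone sketch of the per-prefix rotation policy (illustrative names).
import os
import re
import shutil


def rotate(output_dir: str, prefix: str, limit: int) -> None:
    """Keep at most `limit` directories named `{prefix}-{step}`; delete the oldest."""
    pattern = re.compile(rf"^{re.escape(prefix)}-(\d+)$")
    matches = []
    for name in os.listdir(output_dir):
        m = pattern.match(name)
        if m is not None:
            matches.append((int(m.group(1)), os.path.join(output_dir, name)))
    matches.sort()  # oldest (smallest step) first
    for _, path in matches[: max(0, len(matches) - limit)]:
        shutil.rmtree(path, ignore_errors=True)  # tolerate shared disks between nodes


# e.g. rotate("output", "checkpoint", 3) and rotate("output", "hf_checkpoint", 3)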
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index 190fd2adecd1..aefeeb8725da 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -157,6 +157,7 @@ class TrainerControl:
     should_save: bool = False
     should_evaluate: bool = False
     should_log: bool = False
+    should_save_hf: bool = False
 
     def _new_training(self):
        """Internal method that resets the variable for a new training."""
@@ -171,6 +172,7 @@ def _new_step(self):
         self.should_save = False
         self.should_evaluate = False
         self.should_log = False
+        self.should_save_hf = False
 
 
 class TrainerCallback:
@@ -306,6 +308,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
         """
         pass
 
+    def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after a Hugging Face checkpoint save.
+        """
+        pass
+
 
 class CallbackHandler(TrainerCallback):
     """Internal class that just calls the list of callbacks in order."""
@@ -386,6 +394,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         control.should_log = False
         control.should_evaluate = False
         control.should_save = False
+        control.should_save_hf = False
         return self.call_event("on_step_begin", args, state, control)
 
     def on_load_data_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, inputs: Dict):
@@ -418,6 +427,10 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC
     def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_prediction_step", args, state, control)
 
+    def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_save_hf = False
+        return self.call_event("on_save_hf", args, state, control)
+
     def call_event(self, event, args, state, control, **kwargs):
         for callback in self.callbacks:
             result = getattr(callback, event)(
@@ -474,6 +487,14 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
 
+        # Save hf
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_hf_steps > 0
+            and state.global_step % args.save_hf_steps == 0
+        ):
+            control.should_save_hf = True
+
         return control
 
     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
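With the new `should_save_hf` flag and `on_save_hf` event, user callbacks can hook the Hugging Face-format saves the same way they hook regular saves. A hedged sketch of such a callback follows; it assumes `TrainerCallback` is importable from `paddlenlp.trainer` as for existing callbacks, and the class name is hypothetical.

# Hypothetical user callback reacting to the new `on_save_hf` event.
from paddlenlp.trainer import TrainerCallback


class LogHFSaveCallback(TrainerCallback):
    def on_save_hf(self, args, state, control, **kwargs):
        # Fired by CallbackHandler.on_save_hf right after an HF-format save;
        # `control.should_save_hf` has already been reset to False at this point.
        print(f"hf checkpoint written at global_step={state.global_step}")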
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 0b9fa9ea5c16..94bb836a0139 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -26,16 +26,18 @@
 import math
 import os
 import random
+import re
 import threading
 import time
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import paddle
 import paddle.distributed as dist
+from paddle import Tensor
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
@@ -44,6 +46,8 @@
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
+from safetensors import safe_open
+from safetensors.paddle import save_file
 
 from paddlenlp.ops import Topology
 
@@ -1445,3 +1449,201 @@ def buffer_params():
             continue
         param_list.append(param)
     optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+
+
+def _parse_size(size_str: str) -> int:
+    """Parses a size string like '100MB', '2GB' into the number of bytes."""
+    size_str = size_str.upper().strip()
+    match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str)
+    if not match:
+        raise ValueError(f"Could not parse size string: '{size_str}'")
+
+    num_str, unit = match.groups()
+    num = float(num_str)
+
+    if unit == "B" or unit is None:
+        return int(num)
+    elif unit == "KB":
+        return int(num * 1024)
+    elif unit == "MB":
+        return int(num * 1024**2)
+    elif unit == "GB":
+        return int(num * 1024**3)
+    elif unit == "TB":
+        return int(num * 1024**4)
+    else:
+        # This case should not be reached due to regex
+        raise ValueError(f"Unknown unit: '{unit}'")
+
+
+def save_full_param(
+    itr: Iterator[tuple[str, Tensor]],
+    save_dir: str,
+    rank: int,
+    moe_sharding_world_size: int,
+    max_shard_size: str = "2GB",
+    num_saver_ranks: int = 8,
+) -> int:
+    """
+    Saves model weights from an iterator into shards, supporting max shard size
+    and a limited number of saver ranks. Returns the total number of bytes seen.
+
+    Only ranks less than `num_saver_ranks` will perform disk I/O. All other ranks
+    will iterate through the data to maintain synchronization but will not save.
+    The parameter distribution logic is based on `num_saver_ranks`, ensuring all
+    parameters are handled by a designated saver rank.
+
+    Args:
+        itr (Iterator): An iterator that yields (param_key, param_tensor).
+        save_dir (str): The directory where shard files will be saved.
+        rank (int): The rank of the current process.
+        moe_sharding_world_size (int): The total number of processes.
+        max_shard_size (str): The maximum size for each shard file, e.g., "500MB", "2GB".
+        num_saver_ranks (int): The number of ranks (starting from 0) that will save files.
+    """
+
+    # 1. Non-saver ranks simply consume the iterator to stay in sync.
+    if rank >= num_saver_ranks:
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...")
+        for _ in itr:
+            pass
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.")
+        return 0
+
+    max_shard_size_bytes = _parse_size(max_shard_size)
+    logger.info(
+        f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. "
+        f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB"
+    )
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    current_shard_state_dict = {}
+    current_shard_size_bytes = 0
+    sub_shard_index = 0
+
+    def _save_current_shard():
+        nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes
+        if not current_shard_state_dict:
+            return
+
+        # Filename includes the main shard number (rank) and the sub-shard index
+        cur_rank = paddle.distributed.get_rank()
+        shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors"
+        save_path = os.path.join(save_dir, shard_filename)
+
+        logger.info(
+            f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... "
+            f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, "
+            f"Params: {len(current_shard_state_dict)}, "
+            f"Path: {save_path}"
+        )
+
+        save_file(current_shard_state_dict, save_path)
+
+        # Reset for the next shard
+        sub_shard_index += 1
+        current_shard_state_dict = {}
+        current_shard_size_bytes = 0
+
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...")
+
+    total_size = 0
+
+    for i, (param_key, param) in enumerate(itr):
+        param_size_bytes = int(param.numel()) * param.element_size()
+        total_size += param_size_bytes
+        if i % num_saver_ranks == rank:
+            if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes):
+                _save_current_shard()
+
+            current_shard_state_dict[param_key] = param
+            current_shard_size_bytes += param_size_bytes
+
+            if current_shard_size_bytes >= max_shard_size_bytes:
+                _save_current_shard()
+    _save_current_shard()
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.")
+    return total_size
+
+
+def replace_name_and_gen_index(path, cur_rank_total_size):
+    index_mapping = {}
+    cur_rank = paddle.distributed.get_rank()
+    safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")]
+    files_num = len(safetensor_files)
+    all_files_num = []
+    paddle.distributed.all_gather_object(all_files_num, files_num)
+    total_files_num = sum(all_files_num)
+
+    all_sizes = []
+    paddle.distributed.all_gather_object(all_sizes, cur_rank_total_size)
+    total_size = sum(all_sizes)
+
+    start_idx = []
+    acc = 1
+    for files_num in all_files_num:
+        start_idx.append(acc)
+        acc += files_num
+
+    env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
+    env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
+    assert env_local_rank >= 0
+
+    cur_file_index = start_idx[cur_rank] // env_local_size
+    total_files_num = total_files_num // env_local_size
+
+    total_size = total_size // env_local_size
+
+    index_mapping = {}
+    if env_local_rank == 0:
+        for file in safetensor_files:
+            cur_file_index += 1
+            file_path = os.path.join(path, file)
+            new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
+            with safe_open(file_path, framework="np") as f:
+                for key in f.keys():
+                    index_mapping[key] = new_file_name
+            new_file_path = os.path.join(path, new_file_name)
+            os.rename(file_path, new_file_path)
+
+    index_mapping_list = []
+    paddle.distributed.all_gather_object(index_mapping_list, index_mapping)
+    index_mapping = {}
+    for mapping in index_mapping_list:
+        index_mapping.update(mapping)
+
+    if env_local_rank == 0:
+        index_file_name = "model.safetensors.index.json"
+        index_infos = {}
+        index_infos["metadata"] = {}
+        index_infos["metadata"]["total_size"] = total_size
+        index_infos["weight_map"] = dict(sorted(index_mapping.items()))
+        with open(os.path.join(path, index_file_name), "w") as f:
+            json.dump(index_infos, f, indent=4)
+
+
+def save_hf_checkpoint(
+    model,
+    aoa_config,
+    h_group,
+    v_group,
+    num_splits,
+    shard_idx,
+    path,
+):
+    itr = model.full(
+        aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx
+    )
+    num_saver_ranks = h_group.nranks * v_group.nranks
+    rank = h_group.rank + v_group.rank * h_group.nranks
+    total_saved_size = save_full_param(
+        itr=itr,
+        save_dir=path,
+        rank=rank,
+        moe_sharding_world_size=num_saver_ranks,
+        max_shard_size="16GB",
+        num_saver_ranks=num_saver_ranks,
+    )
+    paddle.distributed.barrier()
+    replace_name_and_gen_index(path, total_saved_size)
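For orientation, `_parse_size("2GB")` returns `2 * 1024**3`, and `save_full_param` assigns parameters round-robin to saver ranks (`i % num_saver_ranks == rank`) and cuts a new shard whenever the accumulated byte size would exceed the limit. A pure-Python sketch of that splitting policy follows; `plan_shards` and the toy sizes are illustrative, with no Paddle or safetensors calls involved.

# Illustration of the greedy, size-bounded sharding policy used in save_full_param.
def plan_shards(param_sizes, rank, num_saver_ranks, max_shard_size_bytes):
    shards, current, current_bytes = [], [], 0
    for i, (name, size) in enumerate(param_sizes):
        if i % num_saver_ranks != rank:
            continue  # another saver rank owns this parameter
        if current and current_bytes + size > max_shard_size_bytes:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
        if current_bytes >= max_shard_size_bytes:
            shards.append(current)
            current, current_bytes = [], 0
    if current:
        shards.append(current)
    return shards


params = [("w0", 6), ("w1", 3), ("w2", 5), ("w3", 2), ("w4", 7), ("w5", 1)]
print(plan_shards(params, rank=0, num_saver_ranks=2, max_shard_size_bytes=8))
# rank 0 owns w0, w2, w4 -> [['w0'], ['w2'], ['w4']] under an 8-byte limit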
" + f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, " + f"Params: {len(current_shard_state_dict)}, " + f"Path: {save_path}" + ) + + save_file(current_shard_state_dict, save_path) + + # Reset for the next shard + sub_shard_index += 1 + current_shard_state_dict = {} + current_shard_size_bytes = 0 + + logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...") + + total_size = 0 + + for i, (param_key, param) in enumerate(itr): + param_size_bytes = param.numel() * param.element_size() + total_size += param_size_bytes.item() + if i % num_saver_ranks == rank: + if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes): + _save_current_shard() + + current_shard_state_dict[param_key] = param + current_shard_size_bytes += param_size_bytes + + if current_shard_size_bytes >= max_shard_size_bytes: + _save_current_shard() + _save_current_shard() + logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.") + return total_size + + +def replace_name_and_gen_index(path, cur_rank_total_size): + index_mapping = {} + cur_rank = paddle.distributed.get_rank() + safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")] + files_num = len(safetensor_files) + all_files_num = [] + paddle.distributed.all_gather_object(all_files_num, files_num) + total_files_num = sum(all_files_num) + + all_sizes = [] + paddle.distributed.all_gather_object(all_sizes, cur_rank_total_size) + total_size = sum(all_sizes) + + start_idx = [] + acc = 1 + for files_num in all_files_num: + start_idx.append(acc) + acc += files_num + + env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) + env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8)) + assert env_local_rank >= 0 + + cur_file_index = start_idx[cur_rank] // env_local_size + total_files_num = total_files_num // env_local_size + + total_size = total_size // env_local_size + + index_mapping = {} + if env_local_rank == 0: + for file in safetensor_files: + cur_file_index += 1 + file_path = os.path.join(path, file) + new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors" + with safe_open(file_path, framework="np") as f: + for key in f.keys(): + index_mapping[key] = new_file_name + new_file_path = os.path.join(path, new_file_name) + os.rename(file_path, new_file_path) + + index_mapping_list = [] + paddle.distributed.all_gather_object(index_mapping_list, index_mapping) + index_mapping = {} + for mapping in index_mapping_list: + index_mapping.update(mapping) + + if env_local_rank == 0: + index_file_name = "model.safetensors.index.json" + index_infos = {} + index_infos["metadata"] = {} + index_infos["metadata"]["total_size"] = total_size + index_infos["weight_map"] = dict(sorted(index_mapping.items())) + with open(os.path.join(path, index_file_name), "w") as f: + json.dump(index_infos, f, indent=4) + + +def save_hf_checkpoint( + model, + aoa_config, + h_group, + v_group, + num_splits, + shard_idx, + path, +): + itr = model.full( + aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx + ) + num_saver_ranks = h_group.nranks * v_group.nranks + rank = h_group.rank + v_group.rank * h_group.nranks + total_saved_size = save_full_param( + itr=itr, + save_dir=path, + rank=rank, + moe_sharding_world_size=num_saver_ranks, + max_shard_size="16GB", + num_saver_ranks=num_saver_ranks, + ) + paddle.distributed.barrier() + replace_name_and_gen_index(path, 
diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py
index 0489d6b6d6cc..eea1c51317fe 100644
--- a/paddlenlp/utils/env.py
+++ b/paddlenlp/utils/env.py
@@ -166,3 +166,6 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
 
 USE_FAST_TOKENIZER: bool = _get_bool_env("USE_FAST_TOKENIZER", "false")
 PREFILL_USE_SAGE_ATTN: bool = _get_bool_env("PREFILL_USE_SAGE_ATTN", "false")
+
+# hf checkpoint dir name
+PREFIX_HF_CHECKPOINT_DIR = "hf_checkpoint"
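End to end, `save_hf_checkpoint` leaves a directory of `model-XXXXX-of-YYYYY.safetensors` shards plus a `model.safetensors.index.json` mapping each weight name to its shard. A sketch of reading that layout back follows; `ckpt_dir` is a hypothetical path (the patch does not show how the `hf_checkpoint-{step}` directory name is built), and only standard `json`/`safetensors` calls are used.

import json
import os

from safetensors import safe_open

ckpt_dir = "output/hf_checkpoint-500"  # hypothetical directory name
with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
    index = json.load(f)

weight_map = index["weight_map"]  # param name -> shard file name
some_key = next(iter(weight_map))
with safe_open(os.path.join(ckpt_dir, weight_map[some_key]), framework="np") as f:
    tensor = f.get_tensor(some_key)  # numpy array for that parameter
print(some_key, tensor.shape, index["metadata"]["total_size"])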