
Commit e4d28de

【FlexCheckpoint】support save hf (#11180)
* support save hf
* fix
1 parent 53b85cb commit e4d28de

File tree

5 files changed (+255, -22 lines)

paddlenlp/trainer/trainer.py

Lines changed: 24 additions & 21 deletions
@@ -130,6 +130,7 @@
     PADDLE_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_NAME,
     PREFIX_CHECKPOINT_DIR,
+    PREFIX_HF_CHECKPOINT_DIR,
     PREFIX_WEIGHTS_NAME,
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
     SAFE_PEFT_WEIGHTS_INDEX_NAME,
@@ -3059,28 +3060,30 @@ def _sorted_checkpoints(
     def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
             return
+        for checkpoint_prefix in [PREFIX_CHECKPOINT_DIR, PREFIX_HF_CHECKPOINT_DIR]:
+            # Check if we should delete older checkpoint(s)
+            checkpoints_sorted = self._sorted_checkpoints(
+                use_mtime=use_mtime, checkpoint_prefix=checkpoint_prefix, output_dir=output_dir
+            )
+            if len(checkpoints_sorted) <= self.args.save_total_limit:
+                return
 
-        # Check if we should delete older checkpoint(s)
-        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
-        if len(checkpoints_sorted) <= self.args.save_total_limit:
-            return
-
-        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
-        # we don't do to allow resuming.
-        save_total_limit = self.args.save_total_limit
-        if (
-            self.state.best_model_checkpoint is not None
-            and self.args.save_total_limit == 1
-            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
-        ):
-            save_total_limit = 2
-
-        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
-        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
-        for checkpoint in checkpoints_to_be_deleted:
-            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
-            # ignore_errors for shared disks between train nodes.
-            shutil.rmtree(checkpoint, ignore_errors=True)
+            # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+            # we don't do to allow resuming.
+            save_total_limit = self.args.save_total_limit
+            if (
+                self.state.best_model_checkpoint is not None
+                and self.args.save_total_limit == 1
+                and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+            ):
+                save_total_limit = 2
+
+            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+            checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+            for checkpoint in checkpoints_to_be_deleted:
+                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                # ignore_errors for shared disks between train nodes.
+                shutil.rmtree(checkpoint, ignore_errors=True)
 
     def _save(
         self,
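
With this change, checkpoint rotation sweeps both the regular `checkpoint-<global_step>` directories and the new `hf_checkpoint-<global_step>` directories under `output_dir`, each against the same `save_total_limit`. A minimal sketch of that prefix-based rotation, assuming the usual `<prefix>-<global_step>` directory naming (the `rotate` helper and the `./output` path are illustrative, not the trainer's actual code):

import os
import re
import shutil

def rotate(output_dir: str, prefix: str, save_total_limit: int) -> None:
    # Collect directories named "<prefix>-<step>", oldest (smallest step) first.
    pattern = re.compile(rf"^{re.escape(prefix)}-(\d+)$")
    ckpts = []
    for name in os.listdir(output_dir):
        m = pattern.match(name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            ckpts.append((int(m.group(1)), os.path.join(output_dir, name)))
    ckpts.sort()
    # Delete everything beyond the newest `save_total_limit` checkpoints.
    for _, path in ckpts[: max(0, len(ckpts) - save_total_limit)]:
        shutil.rmtree(path, ignore_errors=True)

# Both prefixes are rotated, mirroring the loop added above.
for prefix in ("checkpoint", "hf_checkpoint"):
    rotate("./output", prefix, save_total_limit=2)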

paddlenlp/trainer/trainer_callback.py

Lines changed: 21 additions & 0 deletions
@@ -157,6 +157,7 @@ class TrainerControl:
     should_save: bool = False
     should_evaluate: bool = False
     should_log: bool = False
+    should_save_hf: bool = False
 
     def _new_training(self):
         """Internal method that resets the variable for a new training."""
@@ -171,6 +172,7 @@ def _new_step(self):
         self.should_save = False
         self.should_evaluate = False
         self.should_log = False
+        self.should_save_hf = False
 
 
 class TrainerCallback:
@@ -306,6 +308,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
         """
         pass
 
+    def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after a huggingface checkpoint save.
+        """
+        pass
+
 
 class CallbackHandler(TrainerCallback):
     """Internal class that just calls the list of callbacks in order."""
@@ -386,6 +394,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         control.should_log = False
         control.should_evaluate = False
         control.should_save = False
+        control.should_save_hf = False
         return self.call_event("on_step_begin", args, state, control)
 
     def on_load_data_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, inputs: Dict):
@@ -418,6 +427,10 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC
     def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_prediction_step", args, state, control)
 
+    def on_save_hf(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_save_hf = False
+        return self.call_event("on_save_hf", args, state, control)
+
     def call_event(self, event, args, state, control, **kwargs):
         for callback in self.callbacks:
             result = getattr(callback, event)(
@@ -474,6 +487,14 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
 
+        # Save hf
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_hf_steps > 0
+            and state.global_step % args.save_hf_steps == 0
+        ):
+            control.should_save_hf = True
+
         return control
 
     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
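
Taken together, the flow is: on_step_end sets `control.should_save_hf = True` every `save_hf_steps` steps (when `save_strategy="steps"`), and the `on_save_hf` event is expected to fire after the huggingface-format checkpoint has been written. A hedged usage sketch of the new hook; the `HfSaveLogger` class and its message are made up, only the `on_save_hf` signature comes from this commit:

from paddlenlp.trainer.trainer_callback import TrainerCallback

class HfSaveLogger(TrainerCallback):
    """Hypothetical callback that reports every huggingface-format checkpoint save."""

    def on_save_hf(self, args, state, control, **kwargs):
        # Fired by the CallbackHandler after an hf checkpoint has been written.
        print(f"[step {state.global_step}] huggingface checkpoint written under {args.output_dir}")

# trainer.add_callback(HfSaveLogger())  # assuming the usual Trainer.add_callback API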

paddlenlp/trainer/trainer_utils.py

Lines changed: 203 additions & 1 deletion
@@ -26,16 +26,18 @@
 import math
 import os
 import random
+import re
 import threading
 import time
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import paddle
 import paddle.distributed as dist
+from paddle import Tensor
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
@@ -44,6 +46,8 @@
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
+from safetensors import safe_open
+from safetensors.paddle import save_file
 
 from paddlenlp.ops import Topology
 
@@ -1445,3 +1449,201 @@ def buffer_params():
             continue
         param_list.append(param)
     optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+
+
+def _parse_size(size_str: str) -> int:
+    """Parses a size string like '100MB', '2GB' into the number of bytes."""
+    size_str = size_str.upper().strip()
+    match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str)
+    if not match:
+        raise ValueError(f"Could not parse size string: '{size_str}'")
+
+    num_str, unit = match.groups()
+    num = float(num_str)
+
+    if unit == "B" or unit is None:
+        return int(num)
+    elif unit == "KB":
+        return int(num * 1024)
+    elif unit == "MB":
+        return int(num * 1024**2)
+    elif unit == "GB":
+        return int(num * 1024**3)
+    elif unit == "TB":
+        return int(num * 1024**4)
+    else:
+        # This case should not be reached due to regex
+        raise ValueError(f"Unknown unit: '{unit}'")
+
+
+def save_full_param(
+    itr: Iterator[tuple[str, Tensor]],
+    save_dir: str,
+    rank: int,
+    moe_sharding_world_size: int,
+    max_shard_size: str = "2GB",
+    num_saver_ranks: int = 8,
+) -> None:
+    """
+    Saves model weights from an iterator into shards, supporting max shard size
+    and a limited number of saver ranks.
+
+    Only ranks less than `num_saver_ranks` will perform disk I/O. All other ranks
+    will iterate through the data to maintain synchronization but will not save.
+    The parameter distribution logic is based on `num_saver_ranks`, ensuring all
+    parameters are handled by a designated saver rank.
+
+    Args:
+        itr (Iterator): An iterator that yields (param_key, param_tensor).
+        save_dir (str): The directory where shard files will be saved.
+        rank (int): The rank of the current process.
+        moe_sharding_world_size (int): The total number of processes.
+        max_shard_size (str): The maximum size for each shard file, e.g., "500MB", "2GB".
+        num_saver_ranks (int): The number of ranks (starting from 0) that will save files.
+    """
+
+    # 1. Non-saver ranks simply consume the iterator to stay in sync.
+    if rank >= num_saver_ranks:
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...")
+        for _ in itr:
+            pass
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.")
+        return
+
+    max_shard_size_bytes = _parse_size(max_shard_size)
+    logger.info(
+        f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. "
+        f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB"
+    )
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    current_shard_state_dict = {}
+    current_shard_size_bytes = 0
+    sub_shard_index = 0
+
+    def _save_current_shard():
+        nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes
+        if not current_shard_state_dict:
+            return
+
+        # Filename includes the main shard number (rank) and the sub-shard index
+        cur_rank = paddle.distributed.get_rank()
+        shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors"
+        save_path = os.path.join(save_dir, shard_filename)
+
+        logger.info(
+            f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... "
+            f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, "
+            f"Params: {len(current_shard_state_dict)}, "
+            f"Path: {save_path}"
+        )
+
+        save_file(current_shard_state_dict, save_path)
+
+        # Reset for the next shard
+        sub_shard_index += 1
+        current_shard_state_dict = {}
+        current_shard_size_bytes = 0
+
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...")
+
+    total_size = 0
+
+    for i, (param_key, param) in enumerate(itr):
+        param_size_bytes = param.numel() * param.element_size()
+        total_size += param_size_bytes.item()
+        if i % num_saver_ranks == rank:
+            if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes):
+                _save_current_shard()
+
+            current_shard_state_dict[param_key] = param
+            current_shard_size_bytes += param_size_bytes
+
+            if current_shard_size_bytes >= max_shard_size_bytes:
+                _save_current_shard()
+    _save_current_shard()
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.")
+    return total_size
+
+
+def replace_name_and_gen_index(path, cur_rank_total_size):
+    index_mapping = {}
+    cur_rank = paddle.distributed.get_rank()
+    safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")]
+    files_num = len(safetensor_files)
+    all_files_num = []
+    paddle.distributed.all_gather_object(all_files_num, files_num)
+    total_files_num = sum(all_files_num)
+
+    all_sizes = []
+    paddle.distributed.all_gather_object(all_sizes, cur_rank_total_size)
+    total_size = sum(all_sizes)
+
+    start_idx = []
+    acc = 1
+    for files_num in all_files_num:
+        start_idx.append(acc)
+        acc += files_num
+
+    env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
+    env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
+    assert env_local_rank >= 0
+
+    cur_file_index = start_idx[cur_rank] // env_local_size
+    total_files_num = total_files_num // env_local_size
+
+    total_size = total_size // env_local_size
+
+    index_mapping = {}
+    if env_local_rank == 0:
+        for file in safetensor_files:
+            cur_file_index += 1
+            file_path = os.path.join(path, file)
+            new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
+            with safe_open(file_path, framework="np") as f:
+                for key in f.keys():
+                    index_mapping[key] = new_file_name
+            new_file_path = os.path.join(path, new_file_name)
+            os.rename(file_path, new_file_path)
+
+    index_mapping_list = []
+    paddle.distributed.all_gather_object(index_mapping_list, index_mapping)
+    index_mapping = {}
+    for mapping in index_mapping_list:
+        index_mapping.update(mapping)
+
+    if env_local_rank == 0:
+        index_file_name = "model.safetensors.index.json"
+        index_infos = {}
+        index_infos["metadata"] = {}
+        index_infos["metadata"]["total_size"] = total_size
+        index_infos["weight_map"] = dict(sorted(index_mapping.items()))
+        with open(os.path.join(path, index_file_name), "w") as f:
+            json.dump(index_infos, f, indent=4)
+
+
+def save_hf_checkpoint(
+    model,
+    aoa_config,
+    h_group,
+    v_group,
+    num_splits,
+    shard_idx,
+    path,
+):
+    itr = model.full(
+        aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx
+    )
+    num_saver_ranks = h_group.nranks * v_group.nranks
+    rank = h_group.rank + v_group.rank * h_group.nranks
+    total_saved_size = save_full_param(
+        itr=itr,
+        save_dir=path,
+        rank=rank,
+        moe_sharding_world_size=num_saver_ranks,
+        max_shard_size="16GB",
+        num_saver_ranks=num_saver_ranks,
+    )
+    paddle.distributed.barrier()
+    replace_name_and_gen_index(path, total_saved_size)
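
Two details of `save_full_param` worth spelling out: parameters are assigned to saver ranks round-robin by their position in the iterator (`i % num_saver_ranks == rank`), and a saver flushes a sub-shard once the accumulated bytes would exceed `max_shard_size`. A small self-contained illustration of that assignment rule; the parameter names and sizes below are invented:

# Toy walk-through of the round-robin sharding rule in save_full_param (sizes in GB).
params = [
    ("embed.weight", 8),
    ("layers.0.attn.qkv", 6),
    ("layers.0.mlp.up", 6),
    ("layers.1.attn.qkv", 6),
    ("layers.1.mlp.up", 6),
    ("lm_head.weight", 8),
]
num_saver_ranks = 2
max_shard_size_gb = 16

for rank in range(num_saver_ranks):
    shard, size, sub_idx = [], 0, 0
    for i, (name, gb) in enumerate(params):
        if i % num_saver_ranks != rank:
            continue  # this parameter is written by another saver rank
        if size > 0 and size + gb > max_shard_size_gb:
            print(f"rank {rank}: flush shard_{rank}-{sub_idx} -> {shard}")
            shard, size, sub_idx = [], 0, sub_idx + 1
        shard.append(name)
        size += gb
    if shard:
        print(f"rank {rank}: flush shard_{rank}-{sub_idx} -> {shard}")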

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
@@ -413,6 +413,8 @@ class TrainingArguments:
             Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
         aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
             The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
+        save_hf_steps (`int`, *optional*, defaults to 500):
+            Number of updates steps before two huggingface checkpoint saves if `save_strategy="steps"`.
     """
 
     output_dir: str = field(
@@ -1142,6 +1144,8 @@ class TrainingArguments:
         },
     )
 
+    save_hf_steps: int = field(default=-1, metadata={"help": "Save huggingface checkpoint every X updates steps."})
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
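
A hedged configuration sketch showing how the new flag is meant to combine with the existing step-based save settings (values below are arbitrary; by default `save_hf_steps=-1`, so huggingface-format saves stay disabled):

from paddlenlp.trainer import TrainingArguments

# Regular checkpoints every 500 steps, plus a huggingface-format checkpoint every 1000 steps.
# save_hf_steps only takes effect when save_strategy="steps" (see the on_step_end hunk above).
args = TrainingArguments(
    output_dir="./output",
    save_strategy="steps",
    save_steps=500,
    save_hf_steps=1000,
    save_total_limit=2,
)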

paddlenlp/utils/env.py

Lines changed: 3 additions & 0 deletions
@@ -166,3 +166,6 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
 
 USE_FAST_TOKENIZER: bool = _get_bool_env("USE_FAST_TOKENIZER", "false")
 PREFILL_USE_SAGE_ATTN: bool = _get_bool_env("PREFILL_USE_SAGE_ATTN", "false")
+
+# hf checkpoint dir name
+PREFIX_HF_CHECKPOINT_DIR = "hf_checkpoint"
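
The new prefix ties the pieces together: `_rotate_checkpoints` in trainer.py looks for directories carrying this prefix, so huggingface-format saves presumably land in directories like the ones sketched below (layout inferred from the rotation logic, not spelled out in this commit):

from paddlenlp.utils.env import PREFIX_HF_CHECKPOINT_DIR

# Presumed naming, assuming the usual "<prefix>-<global_step>" scheme:
#   output_dir/checkpoint-500/        regular FlexCheckpoint save
#   output_dir/hf_checkpoint-1000/    safetensors shards + model.safetensors.index.json
print(f"{PREFIX_HF_CHECKPOINT_DIR}-1000")  # "hf_checkpoint-1000"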
