 import math
 import os
 import random
+import re
 import threading
 import time
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union

 import numpy as np
 import paddle
 import paddle.distributed as dist
+from paddle import Tensor
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
 )
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
+from safetensors import safe_open
+from safetensors.paddle import save_file

 from paddlenlp.ops import Topology
@@ -1445,3 +1449,201 @@ def buffer_params(): |
             continue
         param_list.append(param)
     optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+
+
+def _parse_size(size_str: str) -> int:
+    """Parses a size string like '100MB', '2GB' into the number of bytes."""
+    size_str = size_str.upper().strip()
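+    # e.g. "2GB" -> 2 * 1024**3, "512 MB" -> 512 * 1024**2; a bare number is bytes.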
+    match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str)
+    if not match:
+        raise ValueError(f"Could not parse size string: '{size_str}'")
+
+    num_str, unit = match.groups()
+    num = float(num_str)
+
+    if unit == "B" or unit is None:
+        return int(num)
+    elif unit == "KB":
+        return int(num * 1024)
+    elif unit == "MB":
+        return int(num * 1024**2)
+    elif unit == "GB":
+        return int(num * 1024**3)
+    elif unit == "TB":
+        return int(num * 1024**4)
+    else:
+        # Unreachable: the regex only admits the units handled above.
+        raise ValueError(f"Unknown unit: '{unit}'")
+
+
+def save_full_param(
+    itr: Iterator[Tuple[str, Tensor]],
+    save_dir: str,
+    rank: int,
+    moe_sharding_world_size: int,
+    max_shard_size: str = "2GB",
+    num_saver_ranks: int = 8,
+) -> int:
+    """
+    Saves model weights from an iterator into shards, supporting a maximum
+    shard size and a limited number of saver ranks.
+
+    Only ranks smaller than `num_saver_ranks` perform disk I/O. All other ranks
+    iterate through the data to maintain synchronization but do not save.
+    Parameters are distributed over the `num_saver_ranks` saver ranks so that
+    every parameter is written by exactly one designated saver rank.
+
+    Args:
+        itr (Iterator): An iterator that yields (param_key, param_tensor).
+        save_dir (str): The directory where shard files will be saved.
+        rank (int): The rank of the current process.
+        moe_sharding_world_size (int): The total number of processes.
+        max_shard_size (str): The maximum size of each shard file, e.g., "500MB", "2GB".
+        num_saver_ranks (int): The number of ranks (starting from 0) that save files.
+
+    Returns:
+        int: The total number of parameter bytes seen by this rank.
+    """
+
+    # Non-saver ranks simply consume the iterator to stay in sync.
+    if rank >= num_saver_ranks:
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...")
+        for _ in itr:
+            pass
+        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.")
+        return 0
+
+    max_shard_size_bytes = _parse_size(max_shard_size)
+    logger.info(
+        f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. "
+        f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB"
+    )
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    current_shard_state_dict = {}
+    current_shard_size_bytes = 0
+    sub_shard_index = 0
+
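+    # Flush the accumulated parameters to one .safetensors file and reset the accumulator.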
+    def _save_current_shard():
+        nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes
+        if not current_shard_state_dict:
+            return
+
+        # Filename includes the main shard number (rank) and the sub-shard index.
+        cur_rank = paddle.distributed.get_rank()
+        shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors"
+        save_path = os.path.join(save_dir, shard_filename)
+
+        logger.info(
+            f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... "
+            f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, "
+            f"Params: {len(current_shard_state_dict)}, "
+            f"Path: {save_path}"
+        )
+
+        save_file(current_shard_state_dict, save_path)
+
+        # Reset for the next shard.
+        sub_shard_index += 1
+        current_shard_state_dict = {}
+        current_shard_size_bytes = 0
+
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...")
+
+    total_size = 0
+
+    for i, (param_key, param) in enumerate(itr):
+        # Cast to int so the byte counts stay plain Python ints even when
+        # Tensor.numel() returns a 0-D Tensor.
+        param_size_bytes = int(param.numel()) * param.element_size()
+        total_size += param_size_bytes
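+        # Round-robin assignment: parameter i is saved by rank (i % num_saver_ranks),
+        # so each parameter is written by exactly one saver rank.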
+        if i % num_saver_ranks == rank:
+            if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes):
+                _save_current_shard()
+
+            current_shard_state_dict[param_key] = param
+            current_shard_size_bytes += param_size_bytes
+
+            if current_shard_size_bytes >= max_shard_size_bytes:
+                _save_current_shard()
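+
+    # Flush any remaining parameters that have not reached the size threshold.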
+    _save_current_shard()
+    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.")
+    return total_size
+
+
+def replace_name_and_gen_index(path, cur_rank_total_size):
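+    # Renames the per-rank "shard_{rank}-{i}.safetensors" files to the Hugging Face
+    # "model-XXXXX-of-YYYYY.safetensors" convention and writes
+    # model.safetensors.index.json mapping each parameter key to its file.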
+    index_mapping = {}
+    cur_rank = paddle.distributed.get_rank()
+    safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")]
+    files_num = len(safetensor_files)
+    all_files_num = []
+    paddle.distributed.all_gather_object(all_files_num, files_num)
+    total_files_num = sum(all_files_num)
+
+    all_sizes = []
+    paddle.distributed.all_gather_object(all_sizes, cur_rank_total_size)
+    total_size = sum(all_sizes)
+
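+    # 1-based prefix sums of per-rank file counts: the global index of each rank's first file.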
+    start_idx = []
+    acc = 1
+    for n in all_files_num:
+        start_idx.append(acc)
+        acc += n
+
+    env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
+    env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
+    assert env_local_rank >= 0
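+    # Convert the globally gathered counts and sizes to per-node values: only
+    # local rank 0 on each node renames the files it sees on its filesystem
+    # (this assumes all ranks on a node write into the same node-local `path`).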
+
+    cur_file_index = start_idx[cur_rank] // env_local_size
+    total_files_num = total_files_num // env_local_size
+
+    total_size = total_size // env_local_size
+
+    if env_local_rank == 0:
+        for file in safetensor_files:
+            cur_file_index += 1
+            file_path = os.path.join(path, file)
+            new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
+            with safe_open(file_path, framework="np") as f:
+                for key in f.keys():
+                    index_mapping[key] = new_file_name
+            new_file_path = os.path.join(path, new_file_name)
+            os.rename(file_path, new_file_path)
+
+    index_mapping_list = []
+    paddle.distributed.all_gather_object(index_mapping_list, index_mapping)
+    index_mapping = {}
+    for mapping in index_mapping_list:
+        index_mapping.update(mapping)
+
+    if env_local_rank == 0:
+        index_file_name = "model.safetensors.index.json"
+        index_infos = {}
+        index_infos["metadata"] = {}
+        index_infos["metadata"]["total_size"] = total_size
+        index_infos["weight_map"] = dict(sorted(index_mapping.items()))
+        with open(os.path.join(path, index_file_name), "w") as f:
+            json.dump(index_infos, f, indent=4)
+
+
+def save_hf_checkpoint(
+    model,
+    aoa_config,
+    h_group,
+    v_group,
+    num_splits,
+    shard_idx,
+    path,
+):
+    itr = model.full(
+        aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx
+    )
+    num_saver_ranks = h_group.nranks * v_group.nranks
+    rank = h_group.rank + v_group.rank * h_group.nranks
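+    # Flatten the (h, v) group coordinates into a single saver rank id; since
+    # num_saver_ranks equals the combined group size, every rank acts as a saver.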
+    total_saved_size = save_full_param(
+        itr=itr,
+        save_dir=path,
+        rank=rank,
+        moe_sharding_world_size=num_saver_ranks,
+        max_shard_size="16GB",
+        num_saver_ranks=num_saver_ranks,
+    )
+    paddle.distributed.barrier()
+    replace_name_and_gen_index(path, total_saved_size)