From 370e5457df29c2df6c4b092f0266d52a31670a2d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 16 Jan 2025 18:11:54 +0800 Subject: [PATCH 1/8] support distribute checkpoint io --- .../booster/plugin/hybrid_parallel_plugin.py | 3 + colossalai/checkpoint_io/__init__.py | 3 + .../distributed_checkpoint_io.py | 621 ++++++++++++++++++ .../hybrid_parallel_checkpoint_io.py | 10 +- colossalai/checkpoint_io/utils.py | 13 +- .../shardformer/layer/parallel_module.py | 1 - .../test_dist_checkpointio.py | 147 +++++ 7 files changed, 784 insertions(+), 14 deletions(-) create mode 100644 colossalai/checkpoint_io/distributed_checkpoint_io.py create mode 100644 tests/test_checkpoint_io/test_dist_checkpointio.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bc9425a0b0cd..1fba6d4b5c4e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -78,6 +78,9 @@ def __init__( self.require_grad_sync = True self.overlap_allgather = overlap_allgather self.use_fp8 = use_fp8 + self.param_origin_shape = {} + for name, param in module.named_parameters(): + self.param_origin_shape[name] = param.shape shardformer = ShardFormer(shard_config) if custom_policy is not None: diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index ef37534fe01a..bf61a862aab4 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,6 +1,8 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO + +from .distributed_checkpoint_io import DistributedCheckpointIO from .index_file import CheckpointIndexFile from .moe_checkpoint import MoECheckpointIO @@ -10,4 +12,5 @@ "GeneralCheckpointIO", "HybridParallelCheckpointIO", "MoECheckpointIO", + "DistributedCheckpointIO", ] diff --git a/colossalai/checkpoint_io/distributed_checkpoint_io.py b/colossalai/checkpoint_io/distributed_checkpoint_io.py new file mode 100644 index 000000000000..05214833eb60 --- /dev/null +++ b/colossalai/checkpoint_io/distributed_checkpoint_io.py @@ -0,0 +1,621 @@ +import json +import logging +import os +from pathlib import Path +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group + +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper +from colossalai.utils import get_non_persistent_buffers_set + +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + async_save_state_dict_shards, + create_pinned_state_dict, + get_model_base_filenames, + load_state_dict, + save_state_dict, + save_state_dict_shards, + search_tp_partition_dim, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + +MODEL_META_PREFIX = "pytorch_model-meta-dist-" +MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" +MODEL_SHARD_SUUFIX = ".index.json" + + +class DistributedCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + dp_group (ProcessGroup): Process group along data parallel dimension. 
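+ sp_group (ProcessGroup): Process group along sequence parallel dimension.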
+ pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True. + """ + + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + sp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__() + self.global_dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group + self.sp_group = sp_group + self.dp_rank = dist.get_rank(self.global_dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.sp_rank = dist.get_rank(self.sp_group) + self.global_dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) + self.use_zero = zero_stage > 0 + self.verbose = verbose + self.coordinator = DistCoordinator() + self.model_metadata = None + self.optimizer_metadata = None + self.global_rank = dist.get_rank(_get_default_group()) + + @staticmethod + def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False): + destination = dict() + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + destination[prefix + name] = param + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + buffer = buf if keep_vars else buf.detach() + destination[prefix + name] = buffer + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + destination[extra_state_key] = extra_state + return destination + + @staticmethod + def load_state_dict( + model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False + ): + destination = dict() + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + with torch.no_grad(): + param.copy_(state_dict[prefix + name]) + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + with torch.no_grad(): + buf.copy_(state_dict[prefix + name]) + + # Save extra states. 
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + with torch.no_grad(): + extra_state.copy_(state_dict[extra_state_key]) + return destination + + def create_model_metadata( + self, + model: nn.Module, + prefix: str = "", + ): + param_origin_shape = model.param_origin_shape + model = model.unwrap() + self.model_metadata = {} + for name, param in model.named_parameters(): + if param is None: + continue + self.model_metadata[prefix + name] = {} + original_shape = param_origin_shape[name] + tp_partition_dim = search_tp_partition_dim( + current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size + ) + self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int) + self.model_metadata[prefix + name]["lengths"] = list(param.shape) + self.model_metadata[prefix + name]["global_shape"] = list(original_shape) + if tp_partition_dim is not None: + partition_size = param.shape[tp_partition_dim] + self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank + if self.tp_rank == self.tp_size - 1: + self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[ + tp_partition_dim + ] - (partition_size * (self.tp_size - 1)) + + def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None): + metadata_dicts = { + "checkpoint_version": "1.0", + "total_size": total_size, + "metadata": {}, + } + for name, data in self.model_metadata.items(): + metadata_dicts["metadata"][name] = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + v = v.tolist() + metadata_dicts["metadata"][name][k] = v + if checkpoint_file is not None: + metadata_dicts["metadata"][name]["file"] = checkpoint_file + metadata_dicts["metadata"][name]["rank"] = self.global_rank + with open(metadata_file, "w") as json_file: + json.dump(metadata_dicts, json_file, indent=4) + + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): + """ + Save model state dict to a single file with given checkpointing path. + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. + gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + self.create_model_metadata(model) + + model = model.unwrap() + if self.dp_rank != 0 and self.sp_rank == 0: + return + + # The logic of collecting parameter shards along tp degree + # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. 
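+ # Each (pp_rank, tp_rank) pair therefore writes only its local shard to its own weight file,
+ # together with a per-rank metadata JSON that records each parameter's offsets and lengths
+ # within the global (unsharded) shape.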
+ state_dict = DistributedCheckpointIO.model_state_dict(model) + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + model_dist_id = self.tp_size * self.pp_rank + self.tp_rank + file_name = f"{MODEL_WEIGHT_PREFIX}{model_dist_id:05d}.bin" + if use_async: + file_name = file_name.replace(".bin", ".safetensors") + checkpoint_file = os.path.join(checkpoint, file_name) + metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{model_dist_id:05d}.json") + self.save_metadata(metadata_file, file_name) + + if use_async: + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint_file, state_dict=state_dict) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint_file, use_safetensors) + + def load_metadata(self, checkpoint: str): + metadata_dict = {} + for filename in os.listdir(checkpoint): + if filename.startswith(MODEL_META_PREFIX) and filename.endswith(".json"): + file_path = os.path.join(checkpoint, filename) + try: + with open(file_path, "r") as f: + metadata_json = json.load(f) + for name, item in metadata_json["metadata"].items(): + if name not in metadata_dict: + metadata_dict[name] = {} + metadata_dict[name]["global_shape"] = item["global_shape"] + metadata_dict[name]["shards"] = {} + else: + assert metadata_dict[name]["global_shape"] == item["global_shape"] + shard = {} + shard[item["rank"]] = {} + shard[item["rank"]]["file"] = item["file"] + shard[item["rank"]]["offsets"] = item["offsets"] + shard[item["rank"]]["lengths"] = item["lengths"] + shard[item["rank"]]["global_shape"] = item["global_shape"] + metadata_dict[name]["shards"].update(shard) + except (json.JSONDecodeError, IOError) as e: + print(f"Unable to load file {file_path}: {e}") + return metadata_dict + + def find_covering_shards(self, shards, target_offsets, target_lengths): + """ + Parameters: + + shards: A list containing information about all shards. + target_offsets: A one-dimensional array representing the starting position of the target tensor in each dimension. + target_lengths: A one-dimensional array representing the lengths of the target tensor in each dimension. + Returns: + + A list of all shards that cover the target range. + """ + target_start = target_offsets + target_end = [start + length for start, length in zip(target_offsets, target_lengths)] + + covering_shards = {} + + global_shape = None + total_lengths = None + for rank, shard in shards.items(): + shard_start = shard["offsets"] + shard_lengths = shard["lengths"] + if global_shape == None: + global_shape = shard["global_shape"] + total_lengths = [0] * len(global_shape) + shard_end = [start + length for start, length in zip(shard_start, shard_lengths)] + + overlap = any( + not (target_end[dim] <= shard_start[dim] or target_start[dim] >= shard_end[dim]) + for dim in range(len(target_start)) + ) + if overlap: + covering_shards.update({rank: shard}) + for dim in range(len(shard_start)): + total_lengths[dim] = max(total_lengths[dim], shard_start[dim] + shard_lengths[dim]) + + assert total_lengths == global_shape + return covering_shards + + def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths): + """ + Extract the target range of weights from shard data, supporting partial overlap. 
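+ The shard and the target each describe an axis-aligned block of the global tensor; the
+ overlap is computed independently per dimension, and a shard whose overlap is empty in
+ any dimension contributes nothing.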
+ + param shard: A dictionary containing shard data, including 'offsets', 'lengths', and 'weight'. + param target_offsets: A 1D array indicating the starting position of the target tensor in each dimension. + param target_lengths: A 1D array indicating the length of the target tensor in each dimension. + return: The extracted sub-tensor of the target weights and its position within the target range. + """ + shard_offsets = shard["offsets"] + shard_lengths = shard["lengths"] + weight = shard["weight"] + + slices = [] + target_slices = [] + + for dim, (t_offset, t_length, s_offset, s_length) in enumerate( + zip(target_offsets, target_lengths, shard_offsets, shard_lengths) + ): + intersection_start = max(t_offset, s_offset) + intersection_end = min(t_offset + t_length, s_offset + s_length) + + if intersection_start >= intersection_end: + return None, None + + shard_slice_start = intersection_start - s_offset + shard_slice_end = intersection_end - s_offset + slices.append(slice(shard_slice_start, shard_slice_end)) + + target_slice_start = intersection_start - t_offset + target_slice_end = intersection_end - t_offset + target_slices.append(slice(target_slice_start, target_slice_end)) + + target_weight = weight[tuple(slices)] + return target_weight, target_slices + + def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype): + target_tensor = torch.zeros(target_lengths, dtype=dtype) + + for rank, shard in shards.items(): + target_weight, target_slices = self.extract_weight_from_shard_partial(shard, target_offsets, target_lengths) + + if target_weight is not None and target_slices is not None: + target_tensor[tuple(target_slices)] = target_weight + + return target_tensor + + def load_unsharded_model( + self, + model: ModelWrapper, + checkpoint: str, + strict: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" 
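+ # Loading proceeds in three steps: build this rank's metadata (per-parameter offsets and
+ # lengths), read only the saved shards that overlap those ranges, and reassemble each
+ # local tensor before copying it into the model.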
+ model._force_wait_all_gather() + model_before_wrapping = model + self.create_model_metadata(model) + model = model.unwrap() + + metadata_loaded = self.load_metadata(checkpoint) + + load_files = {} + covered_shards = {} + for key, item in self.model_metadata.items(): + offsets = item["offsets"] + lengths = item["lengths"] + assert ( + item["global_shape"] == metadata_loaded[key]["global_shape"] + ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" + shards = metadata_loaded[key]["shards"] + covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) + covered_shards[key] = covering_shards + for rank, shard in covering_shards.items(): + if rank not in load_files: + load_files[rank] = set() + load_files[rank].add(shard["file"]) + + dtype = None + for rank, files in load_files.items(): + for file in files: + file_path = os.path.join(checkpoint, file) + state_dict_shard = load_state_dict(file_path) + for key, weight in state_dict_shard.items(): + if key not in covered_shards: + continue + if dtype == None: + dtype = weight.dtype + covered_shards[key][rank]["weight"] = weight + state_dict = {} + for key, shards in covered_shards.items(): + state = self.assemble_tensor_from_shards_partial( + shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype + ) + state_dict[key] = state + + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) + + DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() + + @staticmethod + def _model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(param) + param = pinned_state_dicts[prefix + name] + block, block_size = state_dict_sharder.append_param(prefix + name, param) + if block is not None: + yield block, block_size + + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. 
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + use_async: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. + """ + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + self.create_model_metadata(model) + + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 and sp_rank == 0 save the model. + if self.dp_rank != 0 and self.sp_rank == 0: + return + + if use_async: + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = DistributedCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) + weights_name, _ = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + + # Manage filenames of sharded weights and index file for each pipeline stage. 
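+ # For example, with tp_size=2, pp_rank=1 and tp_rank=0, model_dist_id is 2, so (assuming
+ # the default base name pytorch_model.bin) the shard files carry a -dist-00002-shard suffix
+ # and the metadata is written to pytorch_model-meta-dist-00002.index.json.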
+ model_dist_id = self.tp_size * self.pp_rank + self.tp_rank + weights_name = weights_name.replace(".bin", f"-dist-{model_dist_id:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-dist-{model_dist_id:05d}-shard.safetensors") + metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{model_dist_id:05d}{MODEL_SHARD_SUUFIX}") + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=True, + state_preprocess=False, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + for k, _ in self.model_metadata.items(): + self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k) + + self.save_metadata(metadata_file, total_size=total_size) + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint: Path = None, + strict: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): + """ + Load sharded model with the given path of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint (str): Path of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model._force_wait_all_gather() + model_before_wrapping = model # backup for model before wrapping + self.create_model_metadata(model) + model = model.unwrap() + + metadata_loaded = self.load_metadata(checkpoint=checkpoint) + + load_files = {} + covered_shards = {} + for key, item in self.model_metadata.items(): + offsets = item["offsets"] + lengths = item["lengths"] + assert ( + item["global_shape"] == metadata_loaded[key]["global_shape"] + ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" + shards = metadata_loaded[key]["shards"] + covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) + covered_shards[key] = covering_shards + for rank, shard in covering_shards.items(): + if rank not in load_files: + load_files[rank] = set() + load_files[rank].add(shard["file"]) + + dtype = None + for rank, files in load_files.items(): + for file in files: + file_path = os.path.join(checkpoint, file) + state_dict_shard = load_state_dict(file_path) + for key, weight in state_dict_shard.items(): + if key not in covered_shards: + continue + if dtype == None: + dtype = weight.dtype + covered_shards[key][rank]["weight"] = weight + + state_dict = {} + for key, shards in covered_shards.items(): + state = self.assemble_tensor_from_shards_partial( + shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype + ) + state_dict[key] = state + + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) + + DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict) + + # Update master params if mixed-precision training is enabled. 
+ model_before_wrapping.update_master_params() diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 154d5cb5e5f3..b72c576386aa 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -126,7 +126,7 @@ def _model_sharder( buffer = buf if keep_vars else buf.detach() if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") pinned_state_dicts[prefix + name].copy_(buffer) buffer = pinned_state_dicts[prefix + name] block, block_size = state_dict_sharder.append_param(prefix + name, buffer) @@ -142,7 +142,7 @@ def _model_sharder( extra_state = model.get_extra_state() if pinned_state_dicts is not None: if extra_state_key not in pinned_state_dicts: - pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") pinned_state_dicts[extra_state_key].copy_(extra_state) extra_state = pinned_state_dicts[extra_state_key] block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) @@ -298,9 +298,9 @@ def save_sharded_model( Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. - weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) if use_async: total_size, writers = async_save_state_dict_shards( diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 50b6f1438961..6cfdf695ba5c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -854,14 +854,11 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: # check if there is only one a file ending with .index.json in this directory index_files = list(checkpoint_path.glob("*.index.*json")) - # if we found a .index.json file, make sure there is only one - if len(index_files) > 0: - assert ( - len(index_files) == 1 - ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" - if len(index_files) == 1: return True, index_files[0] + elif len(index_files) > 1: + # Used for distributed checkpoint IO, where the metadata is stored across multiple files. 
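+ # In that case, return the checkpoint directory itself so that the caller can resolve
+ # the per-rank metadata files inside it.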
+ return True, checkpoint_path else: return False, None else: @@ -943,8 +940,8 @@ def get_shard_filename(weights_name: str, idx: int): """ get shard file name """ - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") + shard_file = weights_name.replace(".bin", f"-{idx:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx:05d}.safetensors") return shard_file diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 11ef73538c36..c7c717d93cfa 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -120,7 +120,6 @@ def _load_from_state_dict( "received {}".format(key, type(input_param)) ) continue - if is_distributed_tensor(param): # shard the input param device_mesh = get_device_mesh(param) diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py new file mode 100644 index 000000000000..dcbb21789290 --- /dev/null +++ b/tests/test_checkpoint_io/test_dist_checkpointio.py @@ -0,0 +1,147 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.checkpoint_io import DistributedCheckpointIO +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + +TEST_CONFIGS = [ + ( + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 2, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ) +] + + +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) +@parameterize("size_per_shard", [1]) +@parameterize("test_config", TEST_CONFIGS) +@parameterize("use_async", [False, True]) +@parameterize("low_cpu_mem_mode", [False, True]) +@clear_cache_before_run() +def exam_state_dict( + shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool +): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) + criterion = loss_fn + test_config_0, test_config_1 = test_config + plugin_0 = HybridParallelPlugin(**test_config_0) + booster_0 = Booster(plugin=plugin_0) + hybrid_ckp_0 = booster_0.checkpoint_io + booster_0.checkpoint_io = DistributedCheckpointIO( + hybrid_ckp_0.global_dp_group, + hybrid_ckp_0.pp_group, + hybrid_ckp_0.tp_group, + hybrid_ckp_0.sp_group, + hybrid_ckp_0.use_zero, + ) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster_0.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to("cuda").repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model_0 = model_fn().cuda() + optimizer_0 = 
Adam(model_0.parameters(), lr=1e-3) + model_0, optimizer_0, criterion, _, _ = booster_0.boost(model_0, optimizer_0, criterion) + + data = data_gen_fn() + model_0.train() + if booster_0.plugin.stage_manager is not None: + booster_0.execute_pipeline(_preprocess_data(data), model_0, _criterion, optimizer_0, return_loss=True) + else: + output = model_0(**_preprocess_data(data)) + loss = criterion(output) + optimizer_0.backward(loss) + + optimizer_0.step() + optimizer_0.zero_grad() + with shared_tempdir() as tempdir: + model_ckpt_path_0 = f"{tempdir}/model_0" + + booster_0.save_model( + model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async + ) + booster_0.checkpoint_io._sync_d2h() + booster_0.checkpoint_io._sync_io() + dist.barrier() + + plugin_1 = HybridParallelPlugin(**test_config_1) + booster_1 = Booster(plugin=plugin_1) + hybrid_ckp_1 = booster_1.checkpoint_io + booster_1.checkpoint_io = DistributedCheckpointIO( + hybrid_ckp_1.global_dp_group, + hybrid_ckp_1.pp_group, + hybrid_ckp_1.tp_group, + hybrid_ckp_1.sp_group, + hybrid_ckp_1.use_zero, + ) + + model_1 = model_fn().cuda() + optimizer_1 = Adam(model_1.parameters(), lr=1e-3) + model_1, optimizer_1, criterion, _, _ = booster_1.boost(model_1, optimizer_1, criterion) + + booster_1.load_model(model_1, model_ckpt_path_0, low_cpu_mem_mode=low_cpu_mem_mode) + + model_ckpt_path_1 = f"{tempdir}/model_1" + booster_1.save_model( + model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async + ) + booster_1.checkpoint_io._sync_d2h() + booster_1.checkpoint_io._sync_io() + dist.barrier() + + model_2 = model_fn().cuda() + optimizer_2 = Adam(model_2.parameters(), lr=1e-3) + model_2, optimizer_2, criterion, _, _ = booster_0.boost(model_2, optimizer_2, criterion) + + booster_0.load_model(model_2, model_ckpt_path_1, low_cpu_mem_mode=low_cpu_mem_mode) + check_state_dict_equal(model_0.unwrap().state_dict(), model_2.unwrap().state_dict()) + + dist.barrier() + Randomizer.reset_index() + clear_layout_converter() + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_hybrid_ckpIO(4) From c34ba4e033749ba2d2d38c7d494c35e24e2a7dcb Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 17 Jan 2025 10:42:15 +0800 Subject: [PATCH 2/8] fix --- colossalai/shardformer/layer/parallel_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index c7c717d93cfa..11ef73538c36 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -120,6 +120,7 @@ def _load_from_state_dict( "received {}".format(key, type(input_param)) ) continue + if is_distributed_tensor(param): # shard the input param device_mesh = get_device_mesh(param) From e3f9de32087f77f3650571b841e0ee1aef70bdb0 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 20 Jan 2025 11:24:51 +0800 Subject: [PATCH 3/8] Modify the design --- colossalai/checkpoint_io/__init__.py | 3 - .../distributed_checkpoint_io.py | 621 ------------------ .../distributed_checkpoint_utils.py | 501 ++++++++++++++ .../checkpoint_io/general_checkpoint_io.py | 2 +- 
.../hybrid_parallel_checkpoint_io.py | 46 ++ colossalai/checkpoint_io/utils.py | 5 +- .../test_dist_checkpointio.py | 21 +- 7 files changed, 551 insertions(+), 648 deletions(-) delete mode 100644 colossalai/checkpoint_io/distributed_checkpoint_io.py create mode 100644 colossalai/checkpoint_io/distributed_checkpoint_utils.py diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index bf61a862aab4..ef37534fe01a 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,8 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO - -from .distributed_checkpoint_io import DistributedCheckpointIO from .index_file import CheckpointIndexFile from .moe_checkpoint import MoECheckpointIO @@ -12,5 +10,4 @@ "GeneralCheckpointIO", "HybridParallelCheckpointIO", "MoECheckpointIO", - "DistributedCheckpointIO", ] diff --git a/colossalai/checkpoint_io/distributed_checkpoint_io.py b/colossalai/checkpoint_io/distributed_checkpoint_io.py deleted file mode 100644 index 05214833eb60..000000000000 --- a/colossalai/checkpoint_io/distributed_checkpoint_io.py +++ /dev/null @@ -1,621 +0,0 @@ -import json -import logging -import os -from pathlib import Path -from typing import Dict, Iterator, Optional, OrderedDict, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_default_group - -from colossalai.cluster import DistCoordinator -from colossalai.interface import ModelWrapper -from colossalai.utils import get_non_persistent_buffers_set - -from .general_checkpoint_io import GeneralCheckpointIO -from .index_file import CheckpointIndexFile -from .utils import ( - StateDictSharder, - async_save_state_dict_shards, - create_pinned_state_dict, - get_model_base_filenames, - load_state_dict, - save_state_dict, - save_state_dict_shards, - search_tp_partition_dim, -) - -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = "_extra_state" - -MODEL_META_PREFIX = "pytorch_model-meta-dist-" -MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" -MODEL_SHARD_SUUFIX = ".index.json" - - -class DistributedCheckpointIO(GeneralCheckpointIO): - """ - CheckpointIO for Hybrid Parallel Training. - - Args: - dp_group (ProcessGroup): Process group along data parallel dimension. - pp_group (ProcessGroup): Process group along pipeline parallel dimension. - tp_group (ProcessGroup): Process group along tensor parallel dimension. - zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. - verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True. 
- """ - - def __init__( - self, - dp_group: ProcessGroup, - pp_group: ProcessGroup, - tp_group: ProcessGroup, - sp_group: ProcessGroup, - zero_stage: int, - verbose: bool = True, - ) -> None: - super().__init__() - self.global_dp_group = dp_group - self.pp_group = pp_group - self.tp_group = tp_group - self.sp_group = sp_group - self.dp_rank = dist.get_rank(self.global_dp_group) - self.tp_rank = dist.get_rank(self.tp_group) - self.pp_rank = dist.get_rank(self.pp_group) - self.sp_rank = dist.get_rank(self.sp_group) - self.global_dp_size = dist.get_world_size(dp_group) - self.pp_size = dist.get_world_size(pp_group) - self.tp_size = dist.get_world_size(tp_group) - self.use_zero = zero_stage > 0 - self.verbose = verbose - self.coordinator = DistCoordinator() - self.model_metadata = None - self.optimizer_metadata = None - self.global_rank = dist.get_rank(_get_default_group()) - - @staticmethod - def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False): - destination = dict() - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - destination[prefix + name] = param - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - buffer = buf if keep_vars else buf.detach() - destination[prefix + name] = buffer - - # Save extra states. - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - destination[extra_state_key] = extra_state - return destination - - @staticmethod - def load_state_dict( - model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False - ): - destination = dict() - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - with torch.no_grad(): - param.copy_(state_dict[prefix + name]) - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - with torch.no_grad(): - buf.copy_(state_dict[prefix + name]) - - # Save extra states. 
- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - with torch.no_grad(): - extra_state.copy_(state_dict[extra_state_key]) - return destination - - def create_model_metadata( - self, - model: nn.Module, - prefix: str = "", - ): - param_origin_shape = model.param_origin_shape - model = model.unwrap() - self.model_metadata = {} - for name, param in model.named_parameters(): - if param is None: - continue - self.model_metadata[prefix + name] = {} - original_shape = param_origin_shape[name] - tp_partition_dim = search_tp_partition_dim( - current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size - ) - self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int) - self.model_metadata[prefix + name]["lengths"] = list(param.shape) - self.model_metadata[prefix + name]["global_shape"] = list(original_shape) - if tp_partition_dim is not None: - partition_size = param.shape[tp_partition_dim] - self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank - if self.tp_rank == self.tp_size - 1: - self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[ - tp_partition_dim - ] - (partition_size * (self.tp_size - 1)) - - def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None): - metadata_dicts = { - "checkpoint_version": "1.0", - "total_size": total_size, - "metadata": {}, - } - for name, data in self.model_metadata.items(): - metadata_dicts["metadata"][name] = {} - for k, v in data.items(): - if isinstance(v, torch.Tensor): - v = v.tolist() - metadata_dicts["metadata"][name][k] = v - if checkpoint_file is not None: - metadata_dicts["metadata"][name]["file"] = checkpoint_file - metadata_dicts["metadata"][name]["rank"] = self.global_rank - with open(metadata_file, "w") as json_file: - json.dump(metadata_dicts, json_file, indent=4) - - def save_unsharded_model( - self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False - ): - """ - Save model state dict to a single file with given checkpointing path. - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. - gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. - """ - if self.coordinator.is_master(): - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - model._force_wait_all_gather() - self.create_model_metadata(model) - - model = model.unwrap() - if self.dp_rank != 0 and self.sp_rank == 0: - return - - # The logic of collecting parameter shards along tp degree - # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. 
- state_dict = DistributedCheckpointIO.model_state_dict(model) - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - model_dist_id = self.tp_size * self.pp_rank + self.tp_rank - file_name = f"{MODEL_WEIGHT_PREFIX}{model_dist_id:05d}.bin" - if use_async: - file_name = file_name.replace(".bin", ".safetensors") - checkpoint_file = os.path.join(checkpoint, file_name) - metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{model_dist_id:05d}.json") - self.save_metadata(metadata_file, file_name) - - if use_async: - from colossalai.utils.safetensors import save - - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - for name, param in state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = self.pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint_file, state_dict=state_dict) - self.async_writers.append(writer) - else: - save_state_dict(state_dict, checkpoint_file, use_safetensors) - - def load_metadata(self, checkpoint: str): - metadata_dict = {} - for filename in os.listdir(checkpoint): - if filename.startswith(MODEL_META_PREFIX) and filename.endswith(".json"): - file_path = os.path.join(checkpoint, filename) - try: - with open(file_path, "r") as f: - metadata_json = json.load(f) - for name, item in metadata_json["metadata"].items(): - if name not in metadata_dict: - metadata_dict[name] = {} - metadata_dict[name]["global_shape"] = item["global_shape"] - metadata_dict[name]["shards"] = {} - else: - assert metadata_dict[name]["global_shape"] == item["global_shape"] - shard = {} - shard[item["rank"]] = {} - shard[item["rank"]]["file"] = item["file"] - shard[item["rank"]]["offsets"] = item["offsets"] - shard[item["rank"]]["lengths"] = item["lengths"] - shard[item["rank"]]["global_shape"] = item["global_shape"] - metadata_dict[name]["shards"].update(shard) - except (json.JSONDecodeError, IOError) as e: - print(f"Unable to load file {file_path}: {e}") - return metadata_dict - - def find_covering_shards(self, shards, target_offsets, target_lengths): - """ - Parameters: - - shards: A list containing information about all shards. - target_offsets: A one-dimensional array representing the starting position of the target tensor in each dimension. - target_lengths: A one-dimensional array representing the lengths of the target tensor in each dimension. - Returns: - - A list of all shards that cover the target range. - """ - target_start = target_offsets - target_end = [start + length for start, length in zip(target_offsets, target_lengths)] - - covering_shards = {} - - global_shape = None - total_lengths = None - for rank, shard in shards.items(): - shard_start = shard["offsets"] - shard_lengths = shard["lengths"] - if global_shape == None: - global_shape = shard["global_shape"] - total_lengths = [0] * len(global_shape) - shard_end = [start + length for start, length in zip(shard_start, shard_lengths)] - - overlap = any( - not (target_end[dim] <= shard_start[dim] or target_start[dim] >= shard_end[dim]) - for dim in range(len(target_start)) - ) - if overlap: - covering_shards.update({rank: shard}) - for dim in range(len(shard_start)): - total_lengths[dim] = max(total_lengths[dim], shard_start[dim] + shard_lengths[dim]) - - assert total_lengths == global_shape - return covering_shards - - def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths): - """ - Extract the target range of weights from shard data, supporting partial overlap. 
- - param shard: A dictionary containing shard data, including 'offsets', 'lengths', and 'weight'. - param target_offsets: A 1D array indicating the starting position of the target tensor in each dimension. - param target_lengths: A 1D array indicating the length of the target tensor in each dimension. - return: The extracted sub-tensor of the target weights and its position within the target range. - """ - shard_offsets = shard["offsets"] - shard_lengths = shard["lengths"] - weight = shard["weight"] - - slices = [] - target_slices = [] - - for dim, (t_offset, t_length, s_offset, s_length) in enumerate( - zip(target_offsets, target_lengths, shard_offsets, shard_lengths) - ): - intersection_start = max(t_offset, s_offset) - intersection_end = min(t_offset + t_length, s_offset + s_length) - - if intersection_start >= intersection_end: - return None, None - - shard_slice_start = intersection_start - s_offset - shard_slice_end = intersection_end - s_offset - slices.append(slice(shard_slice_start, shard_slice_end)) - - target_slice_start = intersection_start - t_offset - target_slice_end = intersection_end - t_offset - target_slices.append(slice(target_slice_start, target_slice_end)) - - target_weight = weight[tuple(slices)] - return target_weight, target_slices - - def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype): - target_tensor = torch.zeros(target_lengths, dtype=dtype) - - for rank, shard in shards.items(): - target_weight, target_slices = self.extract_weight_from_shard_partial(shard, target_offsets, target_lengths) - - if target_weight is not None and target_slices is not None: - target_tensor[tuple(target_slices)] = target_weight - - return target_tensor - - def load_unsharded_model( - self, - model: ModelWrapper, - checkpoint: str, - strict: bool = False, - low_cpu_mem_mode: bool = True, - num_threads: int = 1, - ): - """ - Load model from a single file with the given path of checkpoint. - - Args: - model (nn.Module): The model to be loaded. - checkpoint_index_file (str): Path to the checkpoint file. - strict (bool, optional): For name matching during loading state_dict. Defaults to False. - This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. - """ - if self.coordinator.is_master(): - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - - assert isinstance(model, ModelWrapper), "Please boost the model before loading!" 
- model._force_wait_all_gather() - model_before_wrapping = model - self.create_model_metadata(model) - model = model.unwrap() - - metadata_loaded = self.load_metadata(checkpoint) - - load_files = {} - covered_shards = {} - for key, item in self.model_metadata.items(): - offsets = item["offsets"] - lengths = item["lengths"] - assert ( - item["global_shape"] == metadata_loaded[key]["global_shape"] - ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" - shards = metadata_loaded[key]["shards"] - covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) - covered_shards[key] = covering_shards - for rank, shard in covering_shards.items(): - if rank not in load_files: - load_files[rank] = set() - load_files[rank].add(shard["file"]) - - dtype = None - for rank, files in load_files.items(): - for file in files: - file_path = os.path.join(checkpoint, file) - state_dict_shard = load_state_dict(file_path) - for key, weight in state_dict_shard.items(): - if key not in covered_shards: - continue - if dtype == None: - dtype = weight.dtype - covered_shards[key][rank]["weight"] = weight - state_dict = {} - for key, shards in covered_shards.items(): - state = self.assemble_tensor_from_shards_partial( - shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype - ) - state_dict[key] = state - - if not low_cpu_mem_mode: - state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - - DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict) - - # Update master params if mixed-precision training is enabled. - model_before_wrapping.update_master_params() - - @staticmethod - def _model_sharder( - model: nn.Module, - prefix: str = "", - keep_vars: bool = False, - size_per_shard: int = 1024, - pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, - ) -> Iterator[Tuple[OrderedDict, int]]: - # An internel method that breaks state_dict of model into shards within limited size. - - state_dict_sharder = StateDictSharder(size_per_shard) - - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - if pinned_state_dicts is not None: - if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu") - pinned_state_dicts[prefix + name].copy_(param) - param = pinned_state_dicts[prefix + name] - block, block_size = state_dict_sharder.append_param(prefix + name, param) - if block is not None: - yield block, block_size - - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - buffer = buf if keep_vars else buf.detach() - if pinned_state_dicts is not None: - if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") - pinned_state_dicts[prefix + name].copy_(buffer) - buffer = pinned_state_dicts[prefix + name] - block, block_size = state_dict_sharder.append_param(prefix + name, buffer) - if block is not None: - yield block, block_size - - # Save extra states. 
- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - if pinned_state_dicts is not None: - if extra_state_key not in pinned_state_dicts: - pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") - pinned_state_dicts[extra_state_key].copy_(extra_state) - extra_state = pinned_state_dicts[extra_state_key] - block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) - if block is not None: - yield block, block_size - - # Return the last block in sharder. - yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - - def save_sharded_model( - self, - model: ModelWrapper, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False, - use_async: bool = False, - ) -> None: - """ - Save sharded model checkpoint under the given checkpointing path. - The following files will be created under the path: - - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. - - Multiple files that store state tensors of models. - If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". - If pipeline parallelism is not used, "pytorch_model.-000XX.bin" - - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint (str): Checkpointing path which should be a directory path. - gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. - prefix (str, optional): Perfix of file to save. Defaults to None. - size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. - """ - - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - model._force_wait_all_gather() - self.create_model_metadata(model) - - model = model.unwrap() - - if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 and sp_rank == 0 save the model. - if self.dp_rank != 0 and self.sp_rank == 0: - return - - if use_async: - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] - else: - pinned_state_dicts = None - state_dict_shard = DistributedCheckpointIO._model_sharder( - model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts - ) - weights_name, _ = get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint) - - # Manage filenames of sharded weights and index file for each pipeline stage. 
- model_dist_id = self.tp_size * self.pp_rank + self.tp_rank - weights_name = weights_name.replace(".bin", f"-dist-{model_dist_id:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-dist-{model_dist_id:05d}-shard.safetensors") - metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{model_dist_id:05d}{MODEL_SHARD_SUUFIX}") - if use_async: - total_size, writers = async_save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=True, - state_preprocess=False, - ) - self.async_writers.extend(writers) - else: - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=True, - use_safetensors=use_safetensors, - use_pp_format=True, - ) - for k, _ in self.model_metadata.items(): - self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k) - - self.save_metadata(metadata_file, total_size=total_size) - - def load_sharded_model( - self, - model: ModelWrapper, - checkpoint: Path = None, - strict: bool = False, - low_cpu_mem_mode: bool = True, - num_threads: int = 1, - ): - """ - Load sharded model with the given path of checkpoint folder. - - Args: - model (nn.Module): The model to be loaded. - checkpoint (str): Path of checkpointing folder. - strict (bool, optional): For name matching during loading state_dict. Defaults to False. - This argument should be manually set to False since params on same device might be stored in different files. - """ - assert isinstance(model, ModelWrapper), "Please boost the model before loading!" - model._force_wait_all_gather() - model_before_wrapping = model # backup for model before wrapping - self.create_model_metadata(model) - model = model.unwrap() - - metadata_loaded = self.load_metadata(checkpoint=checkpoint) - - load_files = {} - covered_shards = {} - for key, item in self.model_metadata.items(): - offsets = item["offsets"] - lengths = item["lengths"] - assert ( - item["global_shape"] == metadata_loaded[key]["global_shape"] - ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" - shards = metadata_loaded[key]["shards"] - covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) - covered_shards[key] = covering_shards - for rank, shard in covering_shards.items(): - if rank not in load_files: - load_files[rank] = set() - load_files[rank].add(shard["file"]) - - dtype = None - for rank, files in load_files.items(): - for file in files: - file_path = os.path.join(checkpoint, file) - state_dict_shard = load_state_dict(file_path) - for key, weight in state_dict_shard.items(): - if key not in covered_shards: - continue - if dtype == None: - dtype = weight.dtype - covered_shards[key][rank]["weight"] = weight - - state_dict = {} - for key, shards in covered_shards.items(): - state = self.assemble_tensor_from_shards_partial( - shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype - ) - state_dict[key] = state - - if not low_cpu_mem_mode: - state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - - DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict) - - # Update master params if mixed-precision training is enabled. 
- model_before_wrapping.update_master_params() diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py new file mode 100644 index 000000000000..72ee327d81ae --- /dev/null +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -0,0 +1,501 @@ +import json +import logging +import os +from pathlib import Path +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group + +from colossalai.interface import ModelWrapper +from colossalai.utils import get_non_persistent_buffers_set + +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + async_save_state_dict_shards, + create_pinned_state_dict, + get_model_base_filenames, + load_state_dict, + save_state_dict, + save_state_dict_shards, + search_tp_partition_dim, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + +MODEL_META_PREFIX = "pytorch_model-meta-dist-" +MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" +SHARD_META_SUFFIX = ".index.json" + + +def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False): + destination = dict() + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + destination[prefix + name] = param + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + buffer = buf if keep_vars else buf.detach() + destination[prefix + name] = buffer + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + destination[extra_state_key] = extra_state + return destination + +def load_state_dict_into_dist_model( + model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False +): + destination = dict() + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + with torch.no_grad(): + param.copy_(state_dict[prefix + name]) + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + with torch.no_grad(): + buf.copy_(state_dict[prefix + name]) + + # Save extra states. 
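The dist_model_state_dict helper defined above is a plain flat traversal of named_parameters() and persistent buffers, so each rank gets back references to its own local (possibly TP-sharded) tensors rather than a gathered copy. A minimal sketch of that behavior on an ordinary module, assuming the helper is imported from the new colossalai/checkpoint_io/distributed_checkpoint_utils.py module added here:

import torch.nn as nn

from colossalai.checkpoint_io.distributed_checkpoint_utils import dist_model_state_dict

m = nn.Linear(4, 4)
flat = dist_model_state_dict(m)
# Plain modules define no extra state, so only parameters (and persistent buffers) appear.
assert set(flat) == {"weight", "bias"}
assert flat["weight"] is m.weight  # references to the live tensors, not copies
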
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + with torch.no_grad(): + extra_state.copy_(state_dict[extra_state_key]) + return destination + +def create_model_metadata( + model: nn.Module, + prefix: str = "", + tp_size = None, + tp_rank = None, +): + param_origin_shape = model.param_origin_shape + model = model.unwrap() + model_metadata = {} + for name, param in model.named_parameters(): + if param is None: + continue + model_metadata[prefix + name] = {} + original_shape = param_origin_shape[name] + tp_partition_dim = search_tp_partition_dim( + current_shape=param.shape, original_shape=original_shape, tp_size=tp_size + ) + model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int) + model_metadata[prefix + name]["lengths"] = list(param.shape) + model_metadata[prefix + name]["global_shape"] = list(original_shape) + if tp_partition_dim is not None: + partition_size = param.shape[tp_partition_dim] + model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank + if tp_rank == tp_size - 1: + model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[ + tp_partition_dim + ] - (partition_size * (tp_size - 1)) + return model_metadata + +def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None): + metadata_dicts = { + "checkpoint_version": "1.0", + "total_size": total_size, + "metadata": {}, + } + for name, data in model_metadata.items(): + metadata_dicts["metadata"][name] = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + v = v.tolist() + metadata_dicts["metadata"][name][k] = v + if checkpoint_file is not None: + metadata_dicts["metadata"][name]["file"] = checkpoint_file + metadata_dicts["metadata"][name]["rank"] = dist.get_rank(_get_default_group()) + with open(metadata_file, "w") as json_file: + json.dump(metadata_dicts, json_file, indent=4) + +def load_metadata(checkpoint: str): + metadata_dict = {} + for filename in os.listdir(checkpoint): + if filename.startswith(MODEL_META_PREFIX) and filename.endswith(".json"): + file_path = os.path.join(checkpoint, filename) + try: + with open(file_path, "r") as f: + metadata_json = json.load(f) + for name, item in metadata_json["metadata"].items(): + if name not in metadata_dict: + metadata_dict[name] = {} + metadata_dict[name]["global_shape"] = item["global_shape"] + metadata_dict[name]["shards"] = {} + else: + assert metadata_dict[name]["global_shape"] == item["global_shape"] + shard = {item["rank"]: {}} + for k, v in item.items(): + if k == "rank": + continue + shard[item["rank"]][k] = v + metadata_dict[name]["shards"].update(shard) + except (json.JSONDecodeError, IOError) as e: + print(f"Unable to load file {file_path}: {e}") + return metadata_dict + + +def find_covering_shards(shards, target_offsets, target_lengths): + """ + Parameters: + + shards: A list containing information about all shards. + target_offsets: A one-dimensional array representing the starting position of the target tensor in each dimension. + target_lengths: A one-dimensional array representing the lengths of the target tensor in each dimension. + Returns: + + A list of all shards that cover the target range. 
+ """ + target_start = target_offsets + target_end = [start + length for start, length in zip(target_offsets, target_lengths)] + + covering_shards = {} + + global_shape = None + total_lengths = None + for rank, shard in shards.items(): + shard_start = shard["offsets"] + shard_lengths = shard["lengths"] + if global_shape == None: + global_shape = shard["global_shape"] + total_lengths = [0] * len(global_shape) + shard_end = [start + length for start, length in zip(shard_start, shard_lengths)] + + overlap = any( + not (target_end[dim] <= shard_start[dim] or target_start[dim] >= shard_end[dim]) + for dim in range(len(target_start)) + ) + if overlap: + covering_shards.update({rank: shard}) + for dim in range(len(shard_start)): + total_lengths[dim] = max(total_lengths[dim], shard_start[dim] + shard_lengths[dim]) + + assert total_lengths == global_shape + return covering_shards + +def extract_weight_from_shard_partial(shard, target_offsets, target_lengths): + """ + Extract the target range of weights from shard data, supporting partial overlap. + + param shard: A dictionary containing shard data, including 'offsets', 'lengths', and 'weight'. + param target_offsets: A 1D array indicating the starting position of the target tensor in each dimension. + param target_lengths: A 1D array indicating the length of the target tensor in each dimension. + return: The extracted sub-tensor of the target weights and its position within the target range. + """ + shard_offsets = shard["offsets"] + shard_lengths = shard["lengths"] + weight = shard["weight"] + + slices = [] + target_slices = [] + + for dim, (t_offset, t_length, s_offset, s_length) in enumerate( + zip(target_offsets, target_lengths, shard_offsets, shard_lengths) + ): + intersection_start = max(t_offset, s_offset) + intersection_end = min(t_offset + t_length, s_offset + s_length) + + if intersection_start >= intersection_end: + return None, None + + shard_slice_start = intersection_start - s_offset + shard_slice_end = intersection_end - s_offset + slices.append(slice(shard_slice_start, shard_slice_end)) + + target_slice_start = intersection_start - t_offset + target_slice_end = intersection_end - t_offset + target_slices.append(slice(target_slice_start, target_slice_end)) + + target_weight = weight[tuple(slices)] + return target_weight, target_slices + +def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype): + target_tensor = torch.zeros(target_lengths, dtype=dtype) + + for rank, shard in shards.items(): + target_weight, target_slices = extract_weight_from_shard_partial(shard, target_offsets, target_lengths) + + if target_weight is not None and target_slices is not None: + target_tensor[tuple(target_slices)] = target_weight + + return target_tensor + + +def is_pytorch_model_meta_dist_file(checkpoint_index_file): + if MODEL_META_PREFIX in str(checkpoint_index_file): + return True + return False + + +def dist_model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, +) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. 
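Taken together, find_covering_shards and assemble_tensor_from_shards_partial let a rank rebuild exactly the slice it needs from whichever saved shards overlap that slice. An illustrative round trip with made-up shapes (an [8, 4] weight split along dim 0 across two ranks); the import path follows the new module added in this patch:

import torch

from colossalai.checkpoint_io.distributed_checkpoint_utils import (
    assemble_tensor_from_shards_partial,
    find_covering_shards,
)

shards = {
    0: {"offsets": [0, 0], "lengths": [4, 4], "global_shape": [8, 4], "weight": torch.zeros(4, 4)},
    1: {"offsets": [4, 0], "lengths": [4, 4], "global_shape": [8, 4], "weight": torch.ones(4, 4)},
}

# Rebuild the full [8, 4] tensor; a TP rank would instead pass its own offsets/lengths.
covering = find_covering_shards(shards, target_offsets=[0, 0], target_lengths=[8, 4])
full = assemble_tensor_from_shards_partial(covering, [0, 0], [8, 4], dtype=torch.float32)
assert full[:4].eq(0).all() and full[4:].eq(1).all()
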
+ for name, param in model.named_parameters(): + if param is None: + continue + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(param) + param = pinned_state_dicts[prefix + name] + block, block_size = state_dict_sharder.append_param(prefix + name, param) + if block is not None: + yield block, block_size + + # Save buffers. + non_persist_buffers_set = get_non_persistent_buffers_set(model) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persist_buffers_set: + buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + +def save_dist_unshard_model( + model: ModelWrapper, model_metadata: Dict, checkpoint: str, use_safetensors: bool, use_async: bool = False, dist_id = 0, pinned_state_dicts = None +): + """ + Save model state dict to a single file with given checkpointing path. + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. + gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. + """ + + model = model.unwrap() + + # The logic of collecting parameter shards along tp degree + # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. 
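For reference, dist_model_sharder above is a generator: it accumulates tensors into blocks no larger than size_per_shard (in MB) and yields each finished block together with its accumulated size, with the partially filled last block yielded at the end. A rough sketch of the consumption pattern, using torch.save as a stand-in for the real shard writer used below:

import torch
import torch.nn as nn

from colossalai.checkpoint_io.distributed_checkpoint_utils import dist_model_sharder

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
for idx, (block, block_size) in enumerate(dist_model_sharder(model, size_per_shard=1)):
    # Each block is a name -> tensor mapping small enough to fit in one shard file.
    torch.save(block, f"demo-shard-{idx:05d}.bin")
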
+ state_dict = dist_model_state_dict(model) + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin" + if use_async: + file_name = file_name.replace(".bin", ".safetensors") + checkpoint_file = os.path.join(checkpoint, file_name) + metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json") + save_metadata(model_metadata, metadata_file, file_name) + + if use_async: + from colossalai.utils.safetensors import save + + if id(model) not in pinned_state_dicts: + pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint_file, state_dict=state_dict) + return writer + else: + save_state_dict(state_dict, checkpoint_file, use_safetensors) + return None + + +def load_dist_model( + model: ModelWrapper, + model_metadata: Dict, + checkpoint: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, +): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. + """ + + model_before_wrapping = model + model = model.unwrap() + + metadata_loaded = load_metadata(checkpoint) + + load_files = {} + covered_shards = {} + for key, item in model_metadata.items(): + offsets = item["offsets"] + lengths = item["lengths"] + assert ( + item["global_shape"] == metadata_loaded[key]["global_shape"] + ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" + shards = metadata_loaded[key]["shards"] + covering_shards = find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) + covered_shards[key] = covering_shards + for rank, shard in covering_shards.items(): + if rank not in load_files: + load_files[rank] = set() + load_files[rank].add(shard["file"]) + + dtype = None + for rank, files in load_files.items(): + for file in files: + file_path = os.path.join(checkpoint, file) + state_dict_shard = load_state_dict(file_path) + for key, weight in state_dict_shard.items(): + if key not in covered_shards: + continue + if dtype == None: + dtype = weight.dtype + covered_shards[key][rank]["weight"] = weight + state_dict = {} + for key, shards in covered_shards.items(): + state = assemble_tensor_from_shards_partial( + shards, model_metadata[key]["offsets"], model_metadata[key]["lengths"], dtype=dtype + ) + state_dict[key] = state + + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) + + load_state_dict_into_dist_model(model=model, state_dict=state_dict) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() + + +def save_dist_sharded_model( + model: ModelWrapper, + model_metadata: Dict, + checkpoint: str, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + use_async: bool = False, + dist_id: int = 0, + pinned_state_dicts = None, +) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. 
+ The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. + """ + + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 and sp_rank == 0 save the model. + + if use_async: + if id(model) not in pinned_state_dicts: + pinned_state_dicts[id(model)] = {} + pinned_state_dicts = pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = dist_model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) + weights_name, _ = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + + # Manage filenames of sharded weights and index file for each pipeline stage. 
+ weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors") + metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}") + async_writers = [] + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=True, + state_preprocess=False, + ) + async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + for k, _ in model_metadata.items(): + model_metadata[k]["file"] = index_file.get_checkpoint_file(k) + + save_metadata(model_metadata, metadata_file, total_size=total_size) + return async_writers diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index c38958ee31b9..d5ed5b848de3 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -309,4 +309,4 @@ def load_sharded_model( ) def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None: - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b72c576386aa..1ba8a617b8df 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -47,6 +47,14 @@ sharded_optimizer_loading_epilogue, ) +from .distributed_checkpoint_utils import ( + save_dist_sharded_model, + save_dist_unshard_model, + load_dist_model, + is_pytorch_model_meta_dist_file, + create_model_metadata +) + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: @@ -230,6 +238,16 @@ def save_sharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() + + if gather_dtensor: + if self.dp_rank != 0 and self.sp_rank != 0: + return + dist_id = self.tp_size * self.pp_rank + self.tp_rank + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + async_writers = save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) + self.async_writers.extend(async_writers) + return + model = model.unwrap() if os.path.isfile(checkpoint): @@ -374,6 +392,13 @@ def load_sharded_model( """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" 
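The metadata files written by save_metadata above are small JSON documents that map each parameter name to the slice a given rank holds. An illustrative entry for a single parameter (the name, shapes and sizes are hypothetical; the layout mirrors what the code writes):

# Python view of one pytorch_model-meta-dist-00001.index.json file.
example_meta_file = {
    "checkpoint_version": "1.0",
    "total_size": 384,                      # filled in for sharded saves
    "metadata": {
        "encoder.linear.weight": {
            "offsets": [4, 0],              # where this rank's shard starts in the full tensor
            "lengths": [4, 4],              # local shard shape
            "global_shape": [8, 4],         # unsharded shape, used later for reassembly
            "file": "pytorch_model-dist-00001-shard.bin",
            "rank": 1,                      # global rank that wrote this shard
        }
    },
}
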
model._force_wait_all_gather() + + if is_pytorch_model_meta_dist_file(checkpoint_index_file): + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + checkpoint = checkpoint_index_file.parent + load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) + return + model_before_wrapping = model # backup for model before wrapping model = model.unwrap() @@ -762,6 +787,18 @@ def save_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() + + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + + if gather_dtensor: + if self.dp_rank != 0 and self.sp_rank != 0: + return + dist_id = self.tp_size * self.pp_rank + self.tp_rank + writer= save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) + if writer is not None: + self.async_writers.append(writer) + return + model = model.unwrap() if self.dp_rank != 0: return @@ -829,6 +866,14 @@ def load_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before loading!" model._force_wait_all_gather() + + if os.path.isdir(checkpoint): + for filename in os.listdir(checkpoint): + if is_pytorch_model_meta_dist_file(filename): + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) + return + strict = False model_before_wrapping = model model = model.unwrap() @@ -1058,6 +1103,7 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=dp_group) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + # Then gather TP shards. partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) if partition_dim is not None: diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6cfdf695ba5c..0a55c52e0470 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -854,11 +854,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: # check if there is only one a file ending with .index.json in this directory index_files = list(checkpoint_path.glob("*.index.*json")) - if len(index_files) == 1: + if len(index_files) >= 1: return True, index_files[0] - elif len(index_files) > 1: - # Used for distributed checkpoint IO, where the metadata is stored across multiple files. 
- return True, checkpoint_path else: return False, None else: diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py index dcbb21789290..c009c1d26ecc 100644 --- a/tests/test_checkpoint_io/test_dist_checkpointio.py +++ b/tests/test_checkpoint_io/test_dist_checkpointio.py @@ -7,7 +7,6 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.checkpoint_io import DistributedCheckpointIO from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( @@ -44,14 +43,6 @@ def exam_state_dict( test_config_0, test_config_1 = test_config plugin_0 = HybridParallelPlugin(**test_config_0) booster_0 = Booster(plugin=plugin_0) - hybrid_ckp_0 = booster_0.checkpoint_io - booster_0.checkpoint_io = DistributedCheckpointIO( - hybrid_ckp_0.global_dp_group, - hybrid_ckp_0.pp_group, - hybrid_ckp_0.tp_group, - hybrid_ckp_0.sp_group, - hybrid_ckp_0.use_zero, - ) def _criterion(outputs, inputs): outputs = output_transform_fn(outputs) @@ -88,7 +79,7 @@ def _preprocess_data(data): model_ckpt_path_0 = f"{tempdir}/model_0" booster_0.save_model( - model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async + model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async ) booster_0.checkpoint_io._sync_d2h() booster_0.checkpoint_io._sync_io() @@ -96,14 +87,6 @@ def _preprocess_data(data): plugin_1 = HybridParallelPlugin(**test_config_1) booster_1 = Booster(plugin=plugin_1) - hybrid_ckp_1 = booster_1.checkpoint_io - booster_1.checkpoint_io = DistributedCheckpointIO( - hybrid_ckp_1.global_dp_group, - hybrid_ckp_1.pp_group, - hybrid_ckp_1.tp_group, - hybrid_ckp_1.sp_group, - hybrid_ckp_1.use_zero, - ) model_1 = model_fn().cuda() optimizer_1 = Adam(model_1.parameters(), lr=1e-3) @@ -113,7 +96,7 @@ def _preprocess_data(data): model_ckpt_path_1 = f"{tempdir}/model_1" booster_1.save_model( - model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async + model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async ) booster_1.checkpoint_io._sync_d2h() booster_1.checkpoint_io._sync_io() From a28fdde3780f9ce3e2a29aec9e1be8372e33dd85 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 03:55:39 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../distributed_checkpoint_utils.py | 30 ++++++---- .../checkpoint_io/general_checkpoint_io.py | 2 +- .../hybrid_parallel_checkpoint_io.py | 60 ++++++++++++++----- .../test_dist_checkpointio.py | 14 ++++- 4 files changed, 77 insertions(+), 29 deletions(-) diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py index 72ee327d81ae..3654ee94d6c7 100644 --- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -58,6 +58,7 @@ def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = destination[extra_state_key] = extra_state return destination + def load_state_dict_into_dist_model( model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False ): @@ 
-86,11 +87,12 @@ def load_state_dict_into_dist_model( extra_state.copy_(state_dict[extra_state_key]) return destination + def create_model_metadata( model: nn.Module, prefix: str = "", - tp_size = None, - tp_rank = None, + tp_size=None, + tp_rank=None, ): param_origin_shape = model.param_origin_shape model = model.unwrap() @@ -110,11 +112,12 @@ def create_model_metadata( partition_size = param.shape[tp_partition_dim] model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank if tp_rank == tp_size - 1: - model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[ - tp_partition_dim - ] - (partition_size * (tp_size - 1)) + model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - ( + partition_size * (tp_size - 1) + ) return model_metadata + def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None): metadata_dicts = { "checkpoint_version": "1.0", @@ -133,6 +136,7 @@ def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_siz with open(metadata_file, "w") as json_file: json.dump(metadata_dicts, json_file, indent=4) + def load_metadata(checkpoint: str): metadata_dict = {} for filename in os.listdir(checkpoint): @@ -197,6 +201,7 @@ def find_covering_shards(shards, target_offsets, target_lengths): assert total_lengths == global_shape return covering_shards + def extract_weight_from_shard_partial(shard, target_offsets, target_lengths): """ Extract the target range of weights from shard data, supporting partial overlap. @@ -233,6 +238,7 @@ def extract_weight_from_shard_partial(shard, target_offsets, target_lengths): target_weight = weight[tuple(slices)] return target_weight, target_slices + def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype): target_tensor = torch.zeros(target_lengths, dtype=dtype) @@ -310,7 +316,13 @@ def dist_model_sharder( def save_dist_unshard_model( - model: ModelWrapper, model_metadata: Dict, checkpoint: str, use_safetensors: bool, use_async: bool = False, dist_id = 0, pinned_state_dicts = None + model: ModelWrapper, + model_metadata: Dict, + checkpoint: str, + use_safetensors: bool, + use_async: bool = False, + dist_id=0, + pinned_state_dicts=None, ): """ Save model state dict to a single file with given checkpointing path. @@ -426,7 +438,7 @@ def save_dist_sharded_model( use_safetensors: bool = False, use_async: bool = False, dist_id: int = 0, - pinned_state_dicts = None, + pinned_state_dicts=None, ) -> None: """ Save sharded model checkpoint under the given checkpointing path. 
@@ -463,9 +475,7 @@ def save_dist_sharded_model( pinned_state_dicts = pinned_state_dicts[id(model)] else: pinned_state_dicts = None - state_dict_shard = dist_model_sharder( - model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts - ) + state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts) weights_name, _ = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index d5ed5b848de3..c38958ee31b9 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -309,4 +309,4 @@ def load_sharded_model( ) def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 1ba8a617b8df..cbad7d78854a 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -24,6 +24,13 @@ from colossalai.utils import get_current_device, get_non_persistent_buffers_set from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat +from .distributed_checkpoint_utils import ( + create_model_metadata, + is_pytorch_model_meta_dist_file, + load_dist_model, + save_dist_sharded_model, + save_dist_unshard_model, +) from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -47,14 +54,6 @@ sharded_optimizer_loading_epilogue, ) -from .distributed_checkpoint_utils import ( - save_dist_sharded_model, - save_dist_unshard_model, - load_dist_model, - is_pytorch_model_meta_dist_file, - create_model_metadata -) - try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: @@ -244,10 +243,20 @@ def save_sharded_model( return dist_id = self.tp_size * self.pp_rank + self.tp_rank model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) - async_writers = save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) + async_writers = save_dist_sharded_model( + model=model, + model_metadata=model_metadata, + checkpoint=checkpoint, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors, + use_async=use_async, + dist_id=dist_id, + pinned_state_dicts=self.pinned_state_dicts, + ) self.async_writers.extend(async_writers) return - + model = model.unwrap() if os.path.isfile(checkpoint): @@ -396,9 +405,15 @@ def load_sharded_model( if is_pytorch_model_meta_dist_file(checkpoint_index_file): model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) checkpoint = checkpoint_index_file.parent - load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) + load_dist_model( + model=model, + model_metadata=model_metadata, + checkpoint=checkpoint, + low_cpu_mem_mode=low_cpu_mem_mode, + num_threads=num_threads, + ) return - + model_before_wrapping = model # backup 
for model before wrapping model = model.unwrap() @@ -794,11 +809,19 @@ def save_unsharded_model( if self.dp_rank != 0 and self.sp_rank != 0: return dist_id = self.tp_size * self.pp_rank + self.tp_rank - writer= save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts) + writer = save_dist_unshard_model( + model=model, + model_metadata=model_metadata, + checkpoint=checkpoint, + use_safetensors=use_safetensors, + use_async=use_async, + dist_id=dist_id, + pinned_state_dicts=self.pinned_state_dicts, + ) if writer is not None: self.async_writers.append(writer) return - + model = model.unwrap() if self.dp_rank != 0: return @@ -871,7 +894,13 @@ def load_unsharded_model( for filename in os.listdir(checkpoint): if is_pytorch_model_meta_dist_file(filename): model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) - load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads) + load_dist_model( + model=model, + model_metadata=model_metadata, + checkpoint=checkpoint, + low_cpu_mem_mode=low_cpu_mem_mode, + num_threads=num_threads, + ) return strict = False @@ -1103,7 +1132,6 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=dp_group) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) - # Then gather TP shards. partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) if partition_dim is not None: diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py index c009c1d26ecc..09d6eb345bab 100644 --- a/tests/test_checkpoint_io/test_dist_checkpointio.py +++ b/tests/test_checkpoint_io/test_dist_checkpointio.py @@ -79,7 +79,12 @@ def _preprocess_data(data): model_ckpt_path_0 = f"{tempdir}/model_0" booster_0.save_model( - model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async + model_0, + model_ckpt_path_0, + shard=shard, + gather_dtensor=True, + size_per_shard=size_per_shard, + use_async=use_async, ) booster_0.checkpoint_io._sync_d2h() booster_0.checkpoint_io._sync_io() @@ -96,7 +101,12 @@ def _preprocess_data(data): model_ckpt_path_1 = f"{tempdir}/model_1" booster_1.save_model( - model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async + model_1, + model_ckpt_path_1, + shard=shard, + gather_dtensor=True, + size_per_shard=size_per_shard, + use_async=use_async, ) booster_1.checkpoint_io._sync_d2h() booster_1.checkpoint_io._sync_io() From f388bbe26008501294a0a90a2d4eb8373701e392 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 20 Jan 2025 14:10:44 +0800 Subject: [PATCH 5/8] fix --- colossalai/checkpoint_io/distributed_checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py index 3654ee94d6c7..22b79d977332 100644 --- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -408,7 +408,7 @@ def load_dist_model( file_path = os.path.join(checkpoint, file) state_dict_shard = load_state_dict(file_path) for key, weight in state_dict_shard.items(): - if key not in 
covered_shards: + if key not in covered_shards or rank not in covered_shards[key]: continue if dtype == None: dtype = weight.dtype From c5b088219f4f3ab08fa97f441208bc47268da656 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 21 Jan 2025 15:00:56 +0800 Subject: [PATCH 6/8] Remove duplicates --- .../distributed_checkpoint_utils.py | 302 ++---------------- .../hybrid_parallel_checkpoint_io.py | 197 ++++++------ .../test_dist_checkpointio.py | 4 +- 3 files changed, 136 insertions(+), 367 deletions(-) diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py index 22b79d977332..a56386a1ffa3 100644 --- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -1,8 +1,6 @@ import json -import logging import os -from pathlib import Path -from typing import Dict, Iterator, Optional, OrderedDict, Tuple +from typing import Dict import torch import torch.distributed as dist @@ -10,89 +8,43 @@ from torch.distributed.distributed_c10d import _get_default_group from colossalai.interface import ModelWrapper -from colossalai.utils import get_non_persistent_buffers_set +from colossalai.shardformer.layer.parallel_module import ParallelModule +from contextlib import contextmanager -from .index_file import CheckpointIndexFile from .utils import ( - StateDictSharder, - async_save_state_dict_shards, - create_pinned_state_dict, - get_model_base_filenames, load_state_dict, - save_state_dict, - save_state_dict_shards, search_tp_partition_dim, ) -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = "_extra_state" - MODEL_META_PREFIX = "pytorch_model-meta-dist-" MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" SHARD_META_SUFFIX = ".index.json" +UNSHARD_META_SUFFIX = ".json" -def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False): - destination = dict() - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - destination[prefix + name] = param - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - buffer = buf if keep_vars else buf.detach() - destination[prefix + name] = buffer - - # Save extra states. - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - destination[extra_state_key] = extra_state - return destination - - -def load_state_dict_into_dist_model( - model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False -): - destination = dict() - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - with torch.no_grad(): - param.copy_(state_dict[prefix + name]) - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - with torch.no_grad(): - buf.copy_(state_dict[prefix + name]) - - # Save extra states. 
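The new RestoreDefaultStateDictBehavior context manager above is what makes the removed flat helpers unnecessary: it temporarily swaps every ParallelModule back to nn.Module's stock _save_to_state_dict/_load_from_state_dict hooks, so state_dict() and load_state_dict() see each rank's local shards untouched. A hypothetical usage sketch, where sharded_module stands for an already shardformer-sharded module:

from colossalai.checkpoint_io.distributed_checkpoint_utils import RestoreDefaultStateDictBehavior

with RestoreDefaultStateDictBehavior(sharded_module):
    local_state = sharded_module.state_dict()  # per-rank shards, no TP gather
# Outside the block the ParallelModule hooks are restored, so saving gathers again.
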
- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - with torch.no_grad(): - extra_state.copy_(state_dict[extra_state_key]) - return destination +@contextmanager +def RestoreDefaultStateDictBehavior(model): + original_methods = {} + for name, module in model.named_modules(): + if isinstance(module, ParallelModule): + original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict) + module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module) + module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module) + try: + yield model + finally: + for module, original_method in original_methods.items(): + module._save_to_state_dict, module._load_from_state_dict = original_method + def create_model_metadata( - model: nn.Module, + model: ModelWrapper, prefix: str = "", - tp_size=None, - tp_rank=None, + tp_size: int = None, + tp_rank: int = None, + zero_size: int = None, + zero_rank: int = None, ): param_origin_shape = model.param_origin_shape model = model.unwrap() @@ -105,7 +57,7 @@ def create_model_metadata( tp_partition_dim = search_tp_partition_dim( current_shape=param.shape, original_shape=original_shape, tp_size=tp_size ) - model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int) + model_metadata[prefix + name]["offsets"] = [0] * len(original_shape) model_metadata[prefix + name]["lengths"] = list(param.shape) model_metadata[prefix + name]["global_shape"] = list(original_shape) if tp_partition_dim is not None: @@ -257,119 +209,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file): return False -def dist_model_sharder( - model: nn.Module, - prefix: str = "", - keep_vars: bool = False, - size_per_shard: int = 1024, - pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, -) -> Iterator[Tuple[OrderedDict, int]]: - # An internel method that breaks state_dict of model into shards within limited size. - - state_dict_sharder = StateDictSharder(size_per_shard) - - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - if pinned_state_dicts is not None: - if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu") - pinned_state_dicts[prefix + name].copy_(param) - param = pinned_state_dicts[prefix + name] - block, block_size = state_dict_sharder.append_param(prefix + name, param) - if block is not None: - yield block, block_size - - # Save buffers. - non_persist_buffers_set = get_non_persistent_buffers_set(model) - for name, buf in model.named_buffers(): - if buf is not None and name not in non_persist_buffers_set: - buffer = buf if keep_vars else buf.detach() - if pinned_state_dicts is not None: - if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") - pinned_state_dicts[prefix + name].copy_(buffer) - buffer = pinned_state_dicts[prefix + name] - block, block_size = state_dict_sharder.append_param(prefix + name, buffer) - if block is not None: - yield block, block_size - - # Save extra states. 
- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if ( - getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) - is not torch.nn.Module.get_extra_state - ): - extra_state = model.get_extra_state() - if pinned_state_dicts is not None: - if extra_state_key not in pinned_state_dicts: - pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") - pinned_state_dicts[extra_state_key].copy_(extra_state) - extra_state = pinned_state_dicts[extra_state_key] - block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) - if block is not None: - yield block, block_size - - # Return the last block in sharder. - yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - - -def save_dist_unshard_model( - model: ModelWrapper, - model_metadata: Dict, - checkpoint: str, - use_safetensors: bool, - use_async: bool = False, - dist_id=0, - pinned_state_dicts=None, -): - """ - Save model state dict to a single file with given checkpointing path. - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. - gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. - """ - - model = model.unwrap() - - # The logic of collecting parameter shards along tp degree - # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. - state_dict = dist_model_state_dict(model) - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin" - if use_async: - file_name = file_name.replace(".bin", ".safetensors") - checkpoint_file = os.path.join(checkpoint, file_name) - metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json") - save_metadata(model_metadata, metadata_file, file_name) - - if use_async: - from colossalai.utils.safetensors import save - - if id(model) not in pinned_state_dicts: - pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - for name, param in state_dict.items(): - pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint_file, state_dict=state_dict) - return writer - else: - save_state_dict(state_dict, checkpoint_file, use_safetensors) - return None - - def load_dist_model( - model: ModelWrapper, model_metadata: Dict, checkpoint: str, - low_cpu_mem_mode: bool = True, - num_threads: int = 1, ): """ Load model from a single file with the given path of checkpoint. @@ -380,10 +222,6 @@ def load_dist_model( strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. """ - - model_before_wrapping = model - model = model.unwrap() - metadata_loaded = load_metadata(checkpoint) load_files = {} @@ -420,92 +258,14 @@ def load_dist_model( ) state_dict[key] = state - if not low_cpu_mem_mode: - state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - - load_state_dict_into_dist_model(model=model, state_dict=state_dict) - - # Update master params if mixed-precision training is enabled. 
- model_before_wrapping.update_master_params() + return state_dict - -def save_dist_sharded_model( - model: ModelWrapper, - model_metadata: Dict, - checkpoint: str, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False, - use_async: bool = False, - dist_id: int = 0, - pinned_state_dicts=None, -) -> None: - """ - Save sharded model checkpoint under the given checkpointing path. - The following files will be created under the path: - - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. - - Multiple files that store state tensors of models. - If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". - If pipeline parallelism is not used, "pytorch_model.-000XX.bin" - - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint (str): Checkpointing path which should be a directory path. - gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. - prefix (str, optional): Perfix of file to save. Defaults to None. - size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. - """ - - model = model.unwrap() - - if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 and sp_rank == 0 save the model. - - if use_async: - if id(model) not in pinned_state_dicts: - pinned_state_dicts[id(model)] = {} - pinned_state_dicts = pinned_state_dicts[id(model)] - else: - pinned_state_dicts = None - state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts) - weights_name, _ = get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint) - - # Manage filenames of sharded weights and index file for each pipeline stage. 
+def get_dist_files_name(weights_name, dist_id): weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin") weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors") - metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}") - async_writers = [] - if use_async: - total_size, writers = async_save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=True, - state_preprocess=False, - ) - async_writers.extend(writers) - else: - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=True, - use_safetensors=use_safetensors, - use_pp_format=True, - ) - for k, _ in model_metadata.items(): - model_metadata[k]["file"] = index_file.get_checkpoint_file(k) + return weights_name - save_metadata(model_metadata, metadata_file, total_size=total_size) - return async_writers +def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors): + if use_safetensors: + return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}") + return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}") \ No newline at end of file diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index cbad7d78854a..5827020eb8d8 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -5,6 +5,7 @@ from pathlib import Path from shutil import rmtree from typing import Dict, Iterator, Optional, OrderedDict, Tuple +from contextlib import nullcontext import torch import torch.distributed as dist @@ -28,8 +29,11 @@ create_model_metadata, is_pytorch_model_meta_dist_file, load_dist_model, - save_dist_sharded_model, - save_dist_unshard_model, + save_metadata, + get_dist_files_name, + get_dist_meta_file_name, + MODEL_WEIGHT_PREFIX, + RestoreDefaultStateDictBehavior ) from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -97,13 +101,14 @@ def __init__( self.verbose = verbose self.coordinator = DistCoordinator() - @staticmethod def _model_sharder( + self, model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, + gather_dtensor: bool = True, ) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. @@ -113,10 +118,15 @@ def _model_sharder( for name, param in model.named_parameters(): if param is None: continue - # Gather tensor pieces when using tensor parallel. - if is_padded_tensor(param): - param = to_unpadded_tensor(param) - param_ = gather_distributed_param(param, keep_vars=False) + + if gather_dtensor: + # Gather tensor pieces when using tensor parallel. + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + param_ = gather_distributed_param(param, keep_vars=False) + else: + param_ = param + if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") @@ -237,26 +247,14 @@ def save_sharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" 
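The two naming helpers introduced above keep the per-rank file layout in one place. Their outputs for an illustrative dist_id (expected values shown as comments, assuming a POSIX path separator):

from colossalai.checkpoint_io.distributed_checkpoint_utils import (
    get_dist_files_name,
    get_dist_meta_file_name,
)

print(get_dist_files_name("pytorch_model.bin", dist_id=3))
# pytorch_model-dist-00003-shard.bin
print(get_dist_meta_file_name("ckpt", dist_id=3, use_safetensors=True))
# ckpt/pytorch_model-meta-dist-00003.index.json
print(get_dist_meta_file_name("ckpt", dist_id=3, use_safetensors=False))
# ckpt/pytorch_model-meta-dist-00003.json
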
model._force_wait_all_gather() - - if gather_dtensor: - if self.dp_rank != 0 and self.sp_rank != 0: - return - dist_id = self.tp_size * self.pp_rank + self.tp_rank - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) - async_writers = save_dist_sharded_model( - model=model, - model_metadata=model_metadata, - checkpoint=checkpoint, - prefix=prefix, - size_per_shard=size_per_shard, - use_safetensors=use_safetensors, - use_async=use_async, - dist_id=dist_id, - pinned_state_dicts=self.pinned_state_dicts, - ) - self.async_writers.extend(async_writers) + if self.dp_rank != 0 and self.sp_rank != 0: return - + + model_metadata = None + if not gather_dtensor: + # Manage filenames of sharded weights and index file for each pipeline stage. + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + model = model.unwrap() if os.path.isfile(checkpoint): @@ -264,28 +262,30 @@ def save_sharded_model( return Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 save the model. - if self.dp_rank != 0: - return # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. - control_saving = self.tp_rank == 0 and self.sp_rank == 0 + control_saving = self.tp_rank == 0 if gather_dtensor else True if control_saving and use_async: if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = {} pinned_state_dicts = self.pinned_state_dicts[id(model)] else: pinned_state_dicts = None - state_dict_shard = HybridParallelCheckpointIO._model_sharder( - model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + state_dict_shard = self._model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts, gather_dtensor=gather_dtensor ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - if self.pp_size == 1: + if self.pp_size == 1 or not gather_dtensor: # When pipeline is not used, save the model shards as in general checkpointIO + if not gather_dtensor: + dist_id = self.tp_size * self.pp_rank + self.tp_rank + weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id) + metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors) + if use_async: total_size, writers = async_save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -305,16 +305,22 @@ def save_sharded_model( is_master=control_saving, use_safetensors=use_safetensors, ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." 
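With gather_dtensor now gating the distributed path in save_sharded_model, a caller opts into per-rank shards straight from the Booster API. A hypothetical round trip that mirrors the updated test (the boosters, models and parallel configs are assumed to be set up as in the test):

# Save per-rank distributed shards instead of a gathered checkpoint.
booster_0.save_model(model_0, "ckpt/model_0", shard=True,
                     gather_dtensor=False, size_per_shard=32)
booster_0.checkpoint_io._sync_d2h()
booster_0.checkpoint_io._sync_io()

# A booster built with a different tp/pp layout can reassemble the weights from the metadata.
booster_1.load_model(model_1, "ckpt/model_0")
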
- ) + if not gather_dtensor: + # saving metadata for distributed checkpoint + for k, _ in model_metadata.items(): + model_metadata[k]["file"] = index_file.get_checkpoint_file(k) + save_metadata(model_metadata, metadata_file, total_size=total_size) + else: + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -405,13 +411,15 @@ def load_sharded_model( if is_pytorch_model_meta_dist_file(checkpoint_index_file): model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) checkpoint = checkpoint_index_file.parent - load_dist_model( - model=model, + state_dict = load_dist_model( model_metadata=model_metadata, checkpoint=checkpoint, - low_cpu_mem_mode=low_cpu_mem_mode, - num_threads=num_threads, ) + model = model.unwrap() + with RestoreDefaultStateDictBehavior(model): + load_state_dict_into_model( + model, state_dict, missing_keys=[], strict=False, load_sub_module=True + ) return model_before_wrapping = model # backup for model before wrapping @@ -803,47 +811,43 @@ def save_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + if self.dp_rank != 0 and self.sp_rank != 0: + return - if gather_dtensor: - if self.dp_rank != 0 and self.sp_rank != 0: - return + if not gather_dtensor: dist_id = self.tp_size * self.pp_rank + self.tp_rank - writer = save_dist_unshard_model( - model=model, - model_metadata=model_metadata, - checkpoint=checkpoint, - use_safetensors=use_safetensors, - use_async=use_async, - dist_id=dist_id, - pinned_state_dicts=self.pinned_state_dicts, - ) - if writer is not None: - self.async_writers.append(writer) - return + Path(checkpoint).mkdir(parents=True, exist_ok=True) + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin") + if use_async: + checkpoint_file = checkpoint_file.replace(".bin", f".safetensors") + metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_async) + save_metadata(model_metadata=model_metadata, metadata_file=metadata_file, checkpoint_file=checkpoint_file) + else: + checkpoint_file = checkpoint model = model.unwrap() - if self.dp_rank != 0: - return # The logic of collecting parameter shards along tp degree # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. - state_dict = model.state_dict() - if self.pp_size == 1: - # When pipeline is not used, let master rank directly save the collected state_dict. 
- if self.tp_rank == 0: - if use_async: - from colossalai.utils.safetensors import save + ctx = RestoreDefaultStateDictBehavior(model) if not gather_dtensor else nullcontext() + with ctx: + state_dict = model.state_dict() - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - for name, param in state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = self.pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint, state_dict=state_dict) - self.async_writers.append(writer) - else: - save_state_dict(state_dict, checkpoint, use_safetensors) + if (self.pp_size == 1 and self.tp_rank == 0) or not gather_dtensor: + # When pipeline is not used, let master rank directly save the collected state_dict. + if use_async: + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint_file, state_dict=state_dict) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint_file, use_safetensors) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] @@ -862,10 +866,10 @@ def save_unsharded_model( for name, param in complete_state_dict.items(): self.pinned_state_dicts[id(model)][name].copy_(param) complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint, state_dict=complete_state_dict) + writer = save(path=checkpoint_file, state_dict=complete_state_dict) self.async_writers.append(writer) else: - save_state_dict(complete_state_dict, checkpoint, use_safetensors) + save_state_dict(complete_state_dict, checkpoint_file, use_safetensors) def load_unsharded_model( self, @@ -890,18 +894,16 @@ def load_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before loading!" model._force_wait_all_gather() + load_dtensor = False if os.path.isdir(checkpoint): for filename in os.listdir(checkpoint): if is_pytorch_model_meta_dist_file(filename): - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) - load_dist_model( - model=model, - model_metadata=model_metadata, - checkpoint=checkpoint, - low_cpu_mem_mode=low_cpu_mem_mode, - num_threads=num_threads, - ) - return + load_dtensor = True + break + + model_metadata = None # used for dist model + if load_dtensor: + model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) strict = False model_before_wrapping = model @@ -910,10 +912,17 @@ def load_unsharded_model( # Load from checkpoint. Since the logic of breaking parameter shards along tp degree # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, # model.load_state_dict can be directly called. 
- state_dict = load_state_dict(checkpoint) + if load_dtensor: + state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint) + else: + state_dict = load_state_dict(checkpoint) + if not low_cpu_mem_mode: state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - model.load_state_dict(state_dict, strict=strict) + + ctx = RestoreDefaultStateDictBehavior(model) if load_dtensor else nullcontext() + with ctx: + model.load_state_dict(state_dict, strict=strict) # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py index 09d6eb345bab..5aa9c4f1fadd 100644 --- a/tests/test_checkpoint_io/test_dist_checkpointio.py +++ b/tests/test_checkpoint_io/test_dist_checkpointio.py @@ -82,7 +82,7 @@ def _preprocess_data(data): model_0, model_ckpt_path_0, shard=shard, - gather_dtensor=True, + gather_dtensor=False, size_per_shard=size_per_shard, use_async=use_async, ) @@ -104,7 +104,7 @@ def _preprocess_data(data): model_1, model_ckpt_path_1, shard=shard, - gather_dtensor=True, + gather_dtensor=False, size_per_shard=size_per_shard, use_async=use_async, ) From 8e6902c8466acb3780cd629bd690d3bda913ecd7 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 22 Jan 2025 10:51:32 +0800 Subject: [PATCH 7/8] fix --- .../distributed_checkpoint_utils.py | 33 ----------------- .../hybrid_parallel_checkpoint_io.py | 37 ++++++++++++++++--- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py index a56386a1ffa3..563ec99dc21e 100644 --- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -35,39 +35,6 @@ def RestoreDefaultStateDictBehavior(model): finally: for module, original_method in original_methods.items(): module._save_to_state_dict, module._load_from_state_dict = original_method - - - -def create_model_metadata( - model: ModelWrapper, - prefix: str = "", - tp_size: int = None, - tp_rank: int = None, - zero_size: int = None, - zero_rank: int = None, -): - param_origin_shape = model.param_origin_shape - model = model.unwrap() - model_metadata = {} - for name, param in model.named_parameters(): - if param is None: - continue - model_metadata[prefix + name] = {} - original_shape = param_origin_shape[name] - tp_partition_dim = search_tp_partition_dim( - current_shape=param.shape, original_shape=original_shape, tp_size=tp_size - ) - model_metadata[prefix + name]["offsets"] = [0] * len(original_shape) - model_metadata[prefix + name]["lengths"] = list(param.shape) - model_metadata[prefix + name]["global_shape"] = list(original_shape) - if tp_partition_dim is not None: - partition_size = param.shape[tp_partition_dim] - model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank - if tp_rank == tp_size - 1: - model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - ( - partition_size * (tp_size - 1) - ) - return model_metadata def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None): diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5827020eb8d8..2c06cf4e80c1 100644 --- 
a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -26,7 +26,6 @@ from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat from .distributed_checkpoint_utils import ( - create_model_metadata, is_pytorch_model_meta_dist_file, load_dist_model, save_metadata, @@ -216,6 +215,34 @@ def _optimizer_sharder( # Return the last block in sharder. yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + def create_model_metadata( + self, + model: ModelWrapper, + prefix: str = "", + ): + param_origin_shape = model.param_origin_shape + model = model.unwrap() + model_metadata = {} + for name, param in model.named_parameters(): + if param is None: + continue + model_metadata[prefix + name] = {} + original_shape = param_origin_shape[name] + tp_partition_dim = search_tp_partition_dim( + current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size + ) + model_metadata[prefix + name]["offsets"] = [0] * len(original_shape) + model_metadata[prefix + name]["lengths"] = list(param.shape) + model_metadata[prefix + name]["global_shape"] = list(original_shape) + if tp_partition_dim is not None: + partition_size = param.shape[tp_partition_dim] + model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank + if self.tp_rank == self.tp_size - 1: + model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - ( + partition_size * (self.tp_size - 1) + ) + return model_metadata + def save_sharded_model( self, model: ModelWrapper, @@ -253,7 +280,7 @@ def save_sharded_model( model_metadata = None if not gather_dtensor: # Manage filenames of sharded weights and index file for each pipeline stage. 
- model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + model_metadata = self.create_model_metadata(model) model = model.unwrap() @@ -409,7 +436,7 @@ def load_sharded_model( model._force_wait_all_gather() if is_pytorch_model_meta_dist_file(checkpoint_index_file): - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + model_metadata = self.create_model_metadata(model) checkpoint = checkpoint_index_file.parent state_dict = load_dist_model( model_metadata=model_metadata, @@ -817,7 +844,7 @@ def save_unsharded_model( if not gather_dtensor: dist_id = self.tp_size * self.pp_rank + self.tp_rank Path(checkpoint).mkdir(parents=True, exist_ok=True) - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + model_metadata = self.create_model_metadata(model) checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin") if use_async: checkpoint_file = checkpoint_file.replace(".bin", f".safetensors") @@ -903,7 +930,7 @@ def load_unsharded_model( model_metadata = None # used for dist model if load_dtensor: - model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank) + model_metadata = self.create_model_metadata(model) strict = False model_before_wrapping = model From 7e880a13012686c3b2260b5d2278602412277b46 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 13 Feb 2025 13:59:04 +0800 Subject: [PATCH 8/8] fix async io --- colossalai/checkpoint_io/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 0a55c52e0470..524fc3b2190e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -309,12 +309,13 @@ def async_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".") + state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".") else: state_dict = shard + metadata = None # Only save on master rank. - writer = save(checkpoint_file_path, state_dict=state_dict) + writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata) writers.append(writer) shard_filenames.append(shard_file) del shard @@ -371,9 +372,10 @@ def async_move_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard) + state_dict, metadata = _flatten_optim_state_dict(state_dict=shard) else: state_dict = shard + metadata = None if pinned_state_dict is not None: sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()} @@ -382,7 +384,7 @@ def async_move_save_state_dict_shards( returned_state_dict.update(sub_pinned_state_dict) # Only save on master rank. - writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict) + writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata) writers.append(writer) shard_filenames.append(shard_file) del shard
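For reference, the metadata file written next to each pytorch_model-dist-* shard records, for every parameter, which slice of the globally shaped tensor this (tp, pp) rank holds. The sketch below mirrors the per-parameter layout built by create_model_metadata in the hunk above; it is illustrative only, and find_tp_partition_dim / build_param_metadata are hypothetical stand-ins (the real code uses search_tp_partition_dim from colossalai.checkpoint_io.utils and operates on the wrapped model's param_origin_shape).

# Illustrative sketch of the per-parameter metadata layout used by the
# distributed checkpoint: offsets/lengths describe this rank's slice of the
# global tensor. find_tp_partition_dim is a simplified stand-in for
# search_tp_partition_dim; the real helper may behave differently.
from typing import Dict, List, Optional

import torch


def find_tp_partition_dim(
    current_shape: torch.Size, original_shape: torch.Size, tp_size: int
) -> Optional[int]:
    # Treat the first dimension whose size shrank as the TP partition dim.
    if tp_size == 1:
        return None
    for dim, (cur, orig) in enumerate(zip(current_shape, original_shape)):
        if cur != orig:
            return dim
    return None


def build_param_metadata(
    current_shape: torch.Size,
    original_shape: torch.Size,
    tp_size: int,
    tp_rank: int,
) -> Dict[str, List[int]]:
    meta = {
        "offsets": [0] * len(original_shape),
        "lengths": list(current_shape),
        "global_shape": list(original_shape),
    }
    dim = find_tp_partition_dim(current_shape, original_shape, tp_size)
    if dim is not None:
        partition_size = current_shape[dim]
        meta["offsets"][dim] = partition_size * tp_rank
        if tp_rank == tp_size - 1:
            # Last rank takes whatever remains of the partitioned dimension.
            meta["lengths"][dim] = original_shape[dim] - partition_size * (tp_size - 1)
    return meta


if __name__ == "__main__":
    # A column-parallel weight of global shape (8, 4) split across tp_size=2:
    # tp_rank 1 holds offsets [4, 0] with lengths [4, 4].
    print(build_param_metadata(torch.Size([4, 4]), torch.Size([8, 4]), tp_size=2, tp_rank=1))

The updated test in test_dist_checkpointio.py exercises exactly this path by passing gather_dtensor=False when saving, so each rank emits its own weight shard plus a metadata file rather than gathering full tensors first.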