Skip to content

Commit 2a15001

Browse files
committed
support distributed checkpoint io
1 parent 5b094a8 commit 2a15001

File tree

7 files changed

+781
-14
lines changed

7 files changed

+781
-14
lines changed

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def __init__(
7878
self.require_grad_sync = True
7979
self.overlap_allgather = overlap_allgather
8080
self.use_fp8 = use_fp8
81+
self.param_origin_shape = {}
82+
for name, param in module.named_parameters():
83+
self.param_origin_shape[name] = param.shape
8184

8285
shardformer = ShardFormer(shard_config)
8386
if custom_policy is not None:

colossalai/checkpoint_io/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .checkpoint_io_base import CheckpointIO
22
from .general_checkpoint_io import GeneralCheckpointIO
33
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
4+
from .distributed_checkpoint_io import DistributedCheckpointIO
45
from .index_file import CheckpointIndexFile
56
from .moe_checkpoint import MoECheckpointIO
67

@@ -10,4 +11,5 @@
1011
"GeneralCheckpointIO",
1112
"HybridParallelCheckpointIO",
1213
"MoECheckpointIO",
14+
"DistributedCheckpointIO",
1315
]

0 commit comments

Comments
 (0)