6 changes: 6 additions & 0 deletions ppfleetx/configs/nlp/gpt/gpt_base.yaml
@@ -23,6 +23,12 @@ Engine:
   output_dir: ./output
   ckpt_dir:
 
+Model:
+  module: "GPTModule"
+  name: "GPT"
+  fused_linear: False
+
+
 Data:
   Train:
     dataset:
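Note: with module, name, and fused_linear now in gpt_base.yaml, each size-specific config below overrides only what differs and inherits the rest through its _base_ key. A minimal sketch of how such an include chain could be resolved, assuming a simple recursive dict merge (load_config and merge are illustrative names, not the actual ppfleetx loader):

import os
import yaml  # assumes PyYAML is available


def merge(base, override):
    """Recursively merge override into base; override values win."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            merge(base[key], value)
        else:
            base[key] = value
    return base


def load_config(path):
    """Load a YAML config, resolving its _base_ include relative to it."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_rel = cfg.pop("_base_", None)
    if base_rel is None:
        return cfg
    base = load_config(os.path.join(os.path.dirname(path), base_rel))
    return merge(base, cfg)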
31 changes: 31 additions & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B.yaml
@@ -0,0 +1,31 @@
_base_: ./gpt_base.yaml

Global:
  global_batch_size: 8
  local_batch_size: 8
  micro_batch_size: 8


Model:
  vocab_size: 50304
  hidden_size: 2048
  num_layers: 24
  num_attention_heads: 16
  ffn_hidden_size:
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: True
  recompute_granularity:


Distributed:
  dp_degree: 1
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
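Note: the three batch sizes follow the usual hierarchy, global_batch_size = local_batch_size * dp_degree and local_batch_size = micro_batch_size * gradient-accumulation steps, so with dp_degree 1 and no accumulation all three are 8. ffn_hidden_size is left blank, presumably falling back to a default such as 4 * hidden_size (the 345M config below sets it to 4096 = 4 * 1024 explicitly). A quick illustrative check of that hierarchy, not ppfleetx code:

# Assumed Megatron-style batch hierarchy; how ppfleetx derives the
# blank fields is not shown in this diff.
def acc_steps(local_bs, micro_bs):
    assert local_bs % micro_bs == 0
    return local_bs // micro_bs

assert acc_steps(local_bs=8, micro_bs=8) == 1        # this 1.3B config
assert acc_steps(local_bs=1536, micro_bs=1) == 1536  # the 175B config below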
3 changes: 0 additions & 3 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml
@@ -7,8 +7,6 @@ Global:
 
 
 Model:
-  module: "GPTModule"
-  name: "GPT"
   vocab_size: 50304
   hidden_size: 2048
   num_layers: 24
@@ -21,7 +19,6 @@ Model:
   initializer_range: 0.02
   use_recompute: True
   recompute_granularity:
-  fused_linear: False
 
 
 Distributed:
31 changes: 31 additions & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml
@@ -0,0 +1,31 @@
_base_: ./gpt_base.yaml

Global:
  global_batch_size:
  local_batch_size: 1536
  micro_batch_size: 1


Model:
  vocab_size: 51200
  hidden_size: 12288
  num_layers: 96
  num_attention_heads: 96
  ffn_hidden_size:
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: True
  recompute_granularity:


Distributed:
  dp_degree:
  mp_degree: 8
  pp_degree: 16
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
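Note: mp_degree 8 times pp_degree 16 means one 175B replica spans 128 GPUs, i.e. 16 of the 8-card nodes driven by the launch scripts in this PR; dp_degree is left blank, presumably derived from the actual world size. A sketch of that accounting, assuming the standard world_size = dp * mp * pp layout:

# Hypothetical device accounting for the 175B config; how ppfleetx
# actually derives the blank dp_degree is not shown in this diff.
mp_degree, pp_degree = 8, 16
gpus_per_replica = mp_degree * pp_degree  # 128 GPUs per model replica

def derived_dp_degree(world_size):
    assert world_size % gpus_per_replica == 0
    return world_size // gpus_per_replica

print(gpus_per_replica)        # 128
print(derived_dp_degree(128))  # 1, a single replica on 128 GPUs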
31 changes: 31 additions & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_345M.yaml
@@ -0,0 +1,31 @@
_base_: ./gpt_base.yaml

Global:
  global_batch_size: 8
  local_batch_size: 8
  micro_batch_size: 8


Model:
  vocab_size: 50304
  hidden_size: 1024
  num_layers: 24
  num_attention_heads: 16
  ffn_hidden_size: 4096
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: False
  recompute_granularity:


Distributed:
  dp_degree: 1
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
31 changes: 31 additions & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml
@@ -0,0 +1,31 @@
_base_: ./gpt_base.yaml

Global:
  global_batch_size:
  local_batch_size: 8
  micro_batch_size: 8


Model:
  vocab_size: 50304
  hidden_size: 4096
  num_layers: 32
  num_attention_heads: 32
  ffn_hidden_size:
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: True
  recompute_granularity:


Distributed:
  dp_degree:
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 16
    sharding_stage: 2
    sharding_offload: False
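Note: hidden_size 4096 with 32 layers does land near 6.7B parameters, and sharding_stage 2 splits optimizer states and gradients (ZeRO-2 style) across the 16 sharding ranks; since the companion launch script drives 8 devices per node, sharding_degree 16 implies a two-node job. A back-of-the-envelope parameter count using the Megatron-style approximation (an estimate, not ppfleetx code):

# P ~= 12*L*H^2 * (1 + 13/(12H) + (V+S)/(12*L*H)), the approximation
# from the Megatron-LM scaling papers; treat the result as an estimate.
L, H = 32, 4096                  # num_layers, hidden_size
V, S = 50304, 1024               # vocab_size, max_position_embeddings
params = 12 * L * H**2 * (1 + 13 / (12 * H) + (V + S) / (12 * L * H))
print(f"{params / 1e9:.1f}B")    # ~6.7B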
2 changes: 1 addition & 1 deletion ppfleetx/models/language_model/gpt/__init__.py
@@ -16,4 +16,4 @@
     GPTPretrainingCriterionHybird,
     GPTForPretrainingHybrid)
 
-from .single_model import GPTForPretraining, GPTPretrainingCriterion
+from .single_model import GPTForPretraining, GPTPretrainingCriterion, GPTModel
2 changes: 0 additions & 2 deletions ppfleetx/models/language_model/utils.py
@@ -120,8 +120,6 @@ def process_engine_configs(config):
 
 
 def process_configs(config):
-
-    # process_dist_configs(config)
     process_data_configs(config)
     process_fused_configs(config)
     process_model_configs(config)
15 changes: 9 additions & 6 deletions ppfleetx/utils/env.py
@@ -24,12 +24,15 @@
 __all__ = ['init_dist_env']
 
 
-def set_dist_seed(seed):
-    # obtain rank message of hybrid parallel
-    hcg = fleet.get_hybrid_communicate_group()
-    mp_rank = hcg.get_model_parallel_rank()
-    pp_rank = hcg.get_stage_id()
-    data_world_rank = get_data_world_rank()
+def set_seed(seed):
+    if dist.get_world_size() > 1:
+        # obtain rank message of hybrid parallel
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_rank = hcg.get_model_parallel_rank()
+        pp_rank = hcg.get_stage_id()
+        data_world_rank = get_data_world_rank()
+    else:
+        mp_rank, pp_rank, data_world_rank = 1, 1, 1
 
     random.seed(seed + data_world_rank)
     np.random.seed(seed + data_world_rank)
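Note: offsetting the seed by data_world_rank gives each data-parallel rank its own random stream (so data shuffling differs across replicas) while ranks that share the same data stay in lockstep. A minimal, Paddle-free illustration of the effect:

import random

# Two data-parallel ranks seeded as in set_seed above: same base seed,
# offset by data_world_rank, yielding different shuffle orders per rank.
base_seed, samples = 2022, list(range(8))
for data_world_rank in (0, 1):
    rng = random.Random(base_seed + data_world_rank)
    order = samples[:]
    rng.shuffle(order)
    print(data_world_rank, order)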
18 changes: 18 additions & 0 deletions projects/gpt/pretrain_gpt_1.3B.sh
@@ -0,0 +1,18 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


python ../../tools/train.py -c ../../ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B.yaml
23 changes: 23 additions & 0 deletions projects/gpt/pretrain_gpt_1.3B_dp8.sh
@@ -0,0 +1,23 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

log_dir=log_hybrid
rm -rf $log_dir

# 1.3B+dp8 run_pretrain
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
../../tools/train.py \
-c ../../ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml
23 changes: 23 additions & 0 deletions projects/gpt/pretrain_gpt_175B_mp8_pp16.sh
@@ -0,0 +1,23 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

log_dir=log_hybrid
rm -rf $log_dir

# 175B+mp8_pp16 run_pretrain
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
../../tools/train.py \
-c ../../ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml
19 changes: 19 additions & 0 deletions projects/gpt/pretrain_gpt_345M.sh
@@ -0,0 +1,19 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



python ../../tools/train.py -c ../../ppfleetx/configs/nlp/gpt/pretrain_gpt_345M.yaml
23 changes: 23 additions & 0 deletions projects/gpt/pretrain_gpt_6.7B_sharding16.sh
@@ -0,0 +1,23 @@
#! /bin/bash
# Runs the "1.3B" parameter model
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

log_dir=log_hybrid
rm -rf $log_dir

# 6.7B+sharding16 run_pretrain
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
../../tools/train.py \
-c ../../ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml
7 changes: 5 additions & 2 deletions tools/train.py
@@ -29,15 +29,18 @@
 # from ppfleetx.data import build_dataloader
 from ppfleetx.models import build_module
 from ppfleetx.optims import build_lr_scheduler, build_optimizer
+import paddle.distributed as dist
 
 init_logger()
 
 if __name__ == "__main__":
     args = config.parse_args()
     cfg = config.get_config(args.config, overrides=args.override, show=False)
 
-    fleet.init(is_collective=True, strategy=env.init_dist_env(cfg))
-    env.set_dist_seed(cfg.Global.seed)
+    if dist.get_world_size() > 1:
+        fleet.init(is_collective=True, strategy=env.init_dist_env(cfg))
+
+    env.set_seed(cfg.Global.seed)
 
     module = build_module(cfg)
     config.print_config(cfg)
1 change: 1 addition & 0 deletions tools/train.sh
@@ -20,3 +20,4 @@
 # for multi-cards train
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 python3.7 -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" train.py -c ../ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml
+# python train.py -c ../ppfleetx/configs/nlp/gpt/pretrain_gpt_345M.yaml