diff --git a/paddlenlp/transformers/deepseek_v2/configuration.py b/paddlenlp/transformers/deepseek_v2/configuration.py
index 53be3a6fa7c0..99b68b93b57d 100644
--- a/paddlenlp/transformers/deepseek_v2/configuration.py
+++ b/paddlenlp/transformers/deepseek_v2/configuration.py
@@ -182,8 +182,9 @@ def __init__(
         use_dualpipev=False,
         send_mtp_embed=False,
         using_post_norm_recompute=False,
-        recompute_fwd_gate_up=False,
+        recompute_fwd_gate_up=0,
         is_split_group_gemm=False,
+        fakse_gate_restrict_balance=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -237,6 +238,7 @@ def __init__(
         self.using_post_norm_recompute = using_post_norm_recompute
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.is_split_group_gemm = is_split_group_gemm
+        self.fakse_gate_restrict_balance = fakse_gate_restrict_balance
 
         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index a47fe66fa896..3fef76f40413 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -772,9 +772,20 @@ def backward(ctx, d_gate_logits, d_norm_output):
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
 
+def balance_expert_assignment(n, m, k):
+    assert k * n % m == 0
+    matrix = paddle.zeros((n, m), dtype=paddle.int32)
+    for row in range(n):
+        start_col = row % m
+        for i in range(k):
+            col = (start_col + i) % m
+            matrix[row, col] = 1
+    return matrix
+
+
 class FakeGate(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, hidden_states, weight):
+    def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8):
         expert_num = weight.shape[1]
         bsz, seq, _ = hidden_states.shape
 
@@ -782,8 +793,12 @@ def forward(ctx, hidden_states, weight):
         ctx.x_dtype = hidden_states.dtype
         ctx.y_shape = weight.shape
         ctx.y_dtype = weight.dtype
-
-        return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)
+        if fakse_gate_restrict_balance:
+            return paddle.reshape(
+                balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num]
+            )
+        else:
+            return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -841,11 +856,23 @@ def forward(self, hidden_states):
         # compute gating score
         if self.using_post_norm_recompute:
             logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
+            if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
+                logits = FakeGate.apply(
+                    hidden_states,
+                    self.weight,
+                    self.config.fakse_gate_restrict_balance,
+                    self.config.num_experts_per_tok,
+                )
         else:
             with paddle.amp.auto_cast(False):
                 hidden_states = hidden_states.cast(self.weight.dtype)
                 if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
-                    logits = FakeGate.apply(hidden_states, self.weight)
+                    logits = FakeGate.apply(
+                        hidden_states,
+                        self.weight,
+                        self.config.fakse_gate_restrict_balance,
+                        self.config.num_experts_per_tok,
+                    )
                 else:
                     logits = F.linear(hidden_states, self.weight, None)
 
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index c51a2055516d..9e05faba3593 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 from typing import OrderedDict, Tuple, Union
 
@@ -29,6 +30,7 @@
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
 
+from ...utils.log import logger
 from ...utils.tools import get_env_device
 from ..model_utils import PipelinePretrainedModel
 from .modeling import (
@@ -1445,6 +1447,43 @@ def get_hcg():
             LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix
         )
 
+        def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up):
+            all_layers_nums = all_dl_nums + 4  # embedding, rms, lm_head, mtp
+            segment_size = all_layers_nums // pp_nums
+            boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size
+            recompute_fwd_gate_up_list = [dense_dl_nums]
+            for idx in range(boundary - 1, all_dl_nums, segment_size):
+                recompute_fwd_gate_up_list.append(idx)
+
+            # If `recompute_fwd_gate_up` is a boolean and is True, all O1 outputs are recomputed.
+            # Otherwise `recompute_fwd_gate_up` should be an integer specifying how many O1 outputs are recomputed per segment.
+            assert isinstance(recompute_fwd_gate_up, (int, bool))
+            if type(recompute_fwd_gate_up) is bool:
+                enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0
+            else:
+                enable_k_o1_rc = recompute_fwd_gate_up
+
+            ret = []
+            for i in range(len(recompute_fwd_gate_up_list)):
+                for k in range(min(segment_size, enable_k_o1_rc)):
+                    ret.append(recompute_fwd_gate_up_list[i] + k)
+            return ret
+
+        pp_nums = (
+            self.config["pipeline_parallel_degree"] * 2
+            if self.config.use_dualpipev
+            else self.config["pipeline_parallel_degree"]
+        )
+        recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list(
+            pp_nums,
+            self.config.num_hidden_layers,
+            self.config.first_k_dense_replace,
+            self.config.recompute_fwd_gate_up,
+        )
+
+        logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}")
+        config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list
+
         for i in range(config.num_hidden_layers):
             self.add_sequential_layer(
                 LayerDesc(
@@ -1519,8 +1558,8 @@ def overlapped_forward_backward(
         backward_loss_fn_node,
         backward_input_grads,
         scaler,
-        combine_bw_event_to_wait = None,
-        pp_stream=None
+        combine_bw_event_to_wait=None,
+        pp_stream=None,
     ):
         if backward_loss_fn_node is not None:
             if scaler:
diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py
index 57fc729075e1..a5d19e65a2ca 100644
--- a/paddlenlp/transformers/fp8_utils.py
+++ b/paddlenlp/transformers/fp8_utils.py
@@ -1117,6 +1117,11 @@ def backward_dx(self, out_grad):
 
         self.out_grad = out_grad
 
+        # clear status to save memory
+        self.m_indices = None
+        self.unzipped_probs = None
+        self.input = None
+
         # dx
         dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)
 
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index ff834ea0f46d..da3d62016a93 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -668,7 +668,7 @@ def backward(self, output_grad, event_to_wait=None):
         if DSV3_USE_FP8_DISPATCH:
             if event_to_wait is not None:
                 assert self.moe_group is not None
-                event_to_wait.comm_stream_wait( self.moe_group.id)
+                event_to_wait.comm_stream_wait(self.moe_group.id)
                 buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
                 custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
             else:
@@ -697,19 +697,19 @@
 class FusionMlpNode:
     def __init__(self, custom_map, max_topk, recompute_fwd_gate_up=False, is_split_group_gemm=True):
         self.token_dispatcher = custom_map.token_dispatcher
         self.experts = custom_map.experts
+        self.unzip_node = UnZipNode()
+        self.zip_node = ZipNode()
         self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode(
             custom_map,
             recompute_fwd_gate_up=recompute_fwd_gate_up,
             is_split_group_gemm=is_split_group_gemm,
         )
-        self.unzip_node = UnZipNode(self.token_dispatcher)
-        self.zip_node = ZipNode(self.token_dispatcher)
         self.dispatched_indices = None
         self.dispatched_probs = None
         self.tokens_per_expert = None
         self.router_topk = max_topk
 
-    def reset_statue(self):
+    def reset_statue(self, with_dw=False):
         """
         Reset all state variables.
@@ -724,8 +724,15 @@ def reset_statue(self):
         self.dispatched_probs = None
         self.tokens_per_expert = None
         self.router_topk = None
-        self.experts_group_gemm_node.reset_statue()
-        self.experts_group_gemm_node = None
+
+        del self.unzip_node
+        del self.zip_node
+        self.unzip_node = None
+        self.zip_node = None
+
+        if with_dw:
+            self.experts_group_gemm_node.reset_statue()
+            self.experts_group_gemm_node = None
 
     @paddle.no_grad()
     def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
@@ -847,13 +854,14 @@ def backward(self, hidden_states_out_grad, with_dw=True):
             self.dispatched_indices,
             num_experts=len(self.tokens_per_expert),
         )
-        if with_dw:
-            self.reset_statue()
+
+        self.reset_statue(with_dw)
         return hs_dispatched_grad, dispatched_probs_grad
 
     @paddle.no_grad()
     def backward_dw(self):
         self.experts_group_gemm_node.backward_dw()
+        self.reset_statue(True)
 
 
 class FusionMoeNode:
diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py
index a1716a6bc30c..dd9756746015 100644
--- a/paddlenlp/transformers/moe_utils.py
+++ b/paddlenlp/transformers/moe_utils.py
@@ -125,8 +125,7 @@ def unpermute(
 
 
 class UnZipNode:
-    def __init__(self, token_dispatcher, name="unzip"):
-        self.token_dispatcher = token_dispatcher
+    def __init__(self, name="unzip"):
         self.name = name
         self.unzipped_probs = None
         self.zipped_expertwise_rowmap = None
@@ -199,8 +198,7 @@ def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, n
 
 
 class ZipNode:
-    def __init__(self, token_dispatcher, name="zip"):
-        self.token_dispatcher = token_dispatcher
+    def __init__(self, name="zip"):
         self.name = name
 
     @paddle.no_grad()
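Note on the new `fakse_gate_restrict_balance` path in modeling.py: instead of random fake-gate logits, every token activates `num_experts_per_tok` consecutive experts starting at `row % num_experts`, so the fake routing is perfectly balanced. Below is a minimal standalone sketch of the same round-robin rule (plain Python, no paddle; the name `balance_expert_assignment_ref` and the 4-token/4-expert numbers are hypothetical, chosen only for illustration):

```python
# Mirrors the round-robin logic of `balance_expert_assignment` in this diff,
# using plain lists instead of paddle tensors.
def balance_expert_assignment_ref(n, m, k):
    # n tokens, m experts, k experts activated per token; k * n must be divisible by m.
    assert k * n % m == 0
    matrix = [[0] * m for _ in range(n)]
    for row in range(n):
        start_col = row % m
        for i in range(k):
            matrix[row][(start_col + i) % m] = 1
    return matrix


if __name__ == "__main__":
    assignment = balance_expert_assignment_ref(n=4, m=4, k=2)
    for row in assignment:
        print(row)
    # Each expert is selected exactly k * n / m = 2 times:
    print([sum(col) for col in zip(*assignment)])  # -> [2, 2, 2, 2]
```

Because every expert receives the same number of tokens, the grouped GEMM workload is identical across experts, which appears to be the intent of the balance restriction.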
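Likewise, a small standalone sketch of how the new `compute_recompute_fwd_gate_up_list` helper in modeling_pp.py interprets an integer `recompute_fwd_gate_up`. The function body is copied from the diff; the pipeline degree, layer count, and dense-layer count below are hypothetical values chosen only to make the result easy to check:

```python
import math


# Copied from the helper added in modeling_pp.py: picks the decoder-layer
# indices whose O1 (gate_up) output is recomputed, per pipeline segment.
def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up):
    all_layers_nums = all_dl_nums + 4  # embedding, rms, lm_head, mtp
    segment_size = all_layers_nums // pp_nums
    boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size
    recompute_fwd_gate_up_list = [dense_dl_nums]
    for idx in range(boundary - 1, all_dl_nums, segment_size):
        recompute_fwd_gate_up_list.append(idx)
    assert isinstance(recompute_fwd_gate_up, (int, bool))
    if type(recompute_fwd_gate_up) is bool:
        enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0
    else:
        enable_k_o1_rc = recompute_fwd_gate_up
    ret = []
    for i in range(len(recompute_fwd_gate_up_list)):
        for k in range(min(segment_size, enable_k_o1_rc)):
            ret.append(recompute_fwd_gate_up_list[i] + k)
    return ret


if __name__ == "__main__":
    # Hypothetical setup: pipeline degree 4 (no dualpipev), 12 decoder layers,
    # first_k_dense_replace=1, and recompute the first 2 O1 outputs per segment.
    print(compute_recompute_fwd_gate_up_list(4, 12, 1, 2))
    # -> [1, 2, 3, 4, 7, 8, 11, 12]
```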