Balance_gate & O1 recompute configuration #10883

Merged: 2 commits, Jul 28, 2025.

4 changes: 3 additions & 1 deletion paddlenlp/transformers/deepseek_v2/configuration.py
@@ -182,8 +182,9 @@ def __init__(
use_dualpipev=False,
send_mtp_embed=False,
using_post_norm_recompute=False,
recompute_fwd_gate_up=False,
recompute_fwd_gate_up=0,
is_split_group_gemm=False,
fakse_gate_restrict_balance=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -237,6 +238,7 @@ def __init__(
self.using_post_norm_recompute = using_post_norm_recompute
self.recompute_fwd_gate_up = recompute_fwd_gate_up
self.is_split_group_gemm = is_split_group_gemm
self.fakse_gate_restrict_balance = fakse_gate_restrict_balance

super().__init__(
pad_token_id=pad_token_id,
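For orientation, a minimal configuration sketch using the two knobs this file now exposes (assuming the remaining `DeepseekV2Config` fields keep their defaults; `using_fake_gate` is not defined in this file and is assumed to be an extra kwarg the config stores and modeling.py later probes with `hasattr`):

```python
# Hypothetical usage sketch, not part of the PR diff.
from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

config = DeepseekV2Config(
    # 0 (the new default) disables O1 recompute; True recomputes every O1 output;
    # an integer k recomputes only the first k O1 outputs of each pipeline segment.
    recompute_fwd_gate_up=2,
    # When the fake gate is active, use a balanced round-robin expert assignment
    # instead of random gate logits, keeping per-expert load even.
    fakse_gate_restrict_balance=True,
    using_fake_gate=True,  # assumed companion flag, passed through **kwargs
)
print(config.recompute_fwd_gate_up, config.fakse_gate_restrict_balance)
```
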
35 changes: 31 additions & 4 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -772,18 +772,33 @@ def backward(ctx, d_gate_logits, d_norm_output):
return dx, d_rms_norm_weight, d_moe_gate_weight


def balance_expert_assignment(n, m, k):
assert k * n % m == 0
matrix = paddle.zeros((n, m), dtype=paddle.int32)
for row in range(n):
start_col = row % m
for i in range(k):
col = (start_col + i) % m
matrix[row, col] = 1
return matrix


class FakeGate(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, hidden_states, weight):
def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8):
expert_num = weight.shape[1]
bsz, seq, _ = hidden_states.shape

ctx.x_shape = hidden_states.shape
ctx.x_dtype = hidden_states.dtype
ctx.y_shape = weight.shape
ctx.y_dtype = weight.dtype

return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)
if fakse_gate_restrict_balance:
return paddle.reshape(
balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num]
)
else:
return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)

@staticmethod
def backward(ctx, grad_output):
@@ -841,11 +856,23 @@ def forward(self, hidden_states):
# compute gating score
if self.using_post_norm_recompute:
logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
logits = FakeGate.apply(
hidden_states,
self.weight,
self.config.fakse_gate_restrict_balance,
self.config.num_experts_per_tok,
)
else:
with paddle.amp.auto_cast(False):
hidden_states = hidden_states.cast(self.weight.dtype)
if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
logits = FakeGate.apply(hidden_states, self.weight)
logits = FakeGate.apply(
hidden_states,
self.weight,
self.config.fakse_gate_restrict_balance,
self.config.num_experts_per_tok,
)
else:
logits = F.linear(hidden_states, self.weight, None)

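As a standalone illustration of the round-robin scheme behind `balance_expert_assignment` (a NumPy re-implementation for readability, not the Paddle code above): row i activates experts i % m through (i + k - 1) % m, so every token selects exactly k experts, and when the token count n is a multiple of the expert count m each expert receives exactly k * n / m tokens.

```python
import numpy as np

def balance_expert_assignment_np(n, m, k):
    """Round-robin one-hot assignment: n tokens, m experts, k experts per token."""
    assert k * n % m == 0
    matrix = np.zeros((n, m), dtype=np.int32)
    for row in range(n):
        for i in range(k):
            matrix[row, (row + i) % m] = 1
    return matrix

assignment = balance_expert_assignment_np(n=8, m=4, k=2)
print(assignment.sum(axis=1))  # per-token picks: [2 2 2 2 2 2 2 2]
print(assignment.sum(axis=0))  # per-expert load: [4 4 4 4]
```
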
43 changes: 41 additions & 2 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from typing import OrderedDict, Tuple, Union

@@ -29,6 +30,7 @@
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

from ...utils.log import logger
from ...utils.tools import get_env_device
from ..model_utils import PipelinePretrainedModel
from .modeling import (
@@ -1445,6 +1447,43 @@ def get_hcg():
LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix
)

def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up):
all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp
segment_size = all_layers_nums // pp_nums
boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size
recompute_fwd_gate_up_list = [dense_dl_nums]
for idx in range(boundary - 1, all_dl_nums, segment_size):
recompute_fwd_gate_up_list.append(idx)

# If `recompute_fwd_gate_up` is a boolean and is True, every O1 output is recomputed.
# Otherwise `recompute_fwd_gate_up` should be an integer giving how many O1 outputs to recompute.
assert isinstance(recompute_fwd_gate_up, (int, bool))
if type(recompute_fwd_gate_up) is bool:
enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0
else:
enable_k_o1_rc = recompute_fwd_gate_up

ret = []
for i in range(len(recompute_fwd_gate_up_list)):
for k in range(min(segment_size, enable_k_o1_rc)):
ret.append(recompute_fwd_gate_up_list[i] + k)
return ret

pp_nums = (
self.config["pipeline_parallel_degree"] * 2
if self.config.use_dualpipev
else self.config["pipeline_parallel_degree"]
)
recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list(
pp_nums,
self.config.num_hidden_layers,
self.config.first_k_dense_replace,
self.config.recompute_fwd_gate_up,
)

logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}")
config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list

for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(
@@ -1519,8 +1558,8 @@ def overlapped_forward_backward(
backward_loss_fn_node,
backward_input_grads,
scaler,
combine_bw_event_to_wait = None,
pp_stream=None
combine_bw_event_to_wait=None,
pp_stream=None,
):
if backward_loss_fn_node is not None:
if scaler:
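To make the O1 recompute selection concrete, here is the same logic as `compute_recompute_fwd_gate_up_list` lifted into a standalone sketch (copied for illustration; the toy numbers at the bottom are invented, not a real DeepseekV2 configuration):

```python
import math

def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up):
    # pp_nums: virtual pipeline stages (doubled when dualpipev is enabled)
    # all_dl_nums: total decoder layers; dense_dl_nums: leading dense (non-MoE) layers
    all_layers_nums = all_dl_nums + 4  # embedding, rms, lm_head, mtp
    segment_size = all_layers_nums // pp_nums
    boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size
    starts = [dense_dl_nums]
    for idx in range(boundary - 1, all_dl_nums, segment_size):
        starts.append(idx)

    # bool True -> recompute every O1 in a segment; int k -> recompute k per segment
    if type(recompute_fwd_gate_up) is bool:
        enable_k_o1_rc = segment_size if recompute_fwd_gate_up else 0
    else:
        enable_k_o1_rc = recompute_fwd_gate_up

    return [s + k for s in starts for k in range(min(segment_size, enable_k_o1_rc))]

# Toy setup: 8-way pipeline with dualpipev -> 16 virtual stages,
# 28 decoder layers, the first 4 of them dense.
print(compute_recompute_fwd_gate_up_list(16, 28, 4, True))  # every O1 in each MoE segment
print(compute_recompute_fwd_gate_up_list(16, 28, 4, 2))     # only 2 O1 layers per MoE segment
```
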
5 changes: 5 additions & 0 deletions paddlenlp/transformers/fp8_utils.py
@@ -1117,6 +1117,11 @@ def backward_dx(self, out_grad):

self.out_grad = out_grad

# clear status to save memory
self.m_indices = None
self.unzipped_probs = None
self.input = None

# dx
dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)

24 changes: 16 additions & 8 deletions paddlenlp/transformers/moe_layer.py
@@ -668,7 +668,7 @@ def backward(self, output_grad, event_to_wait=None):
if DSV3_USE_FP8_DISPATCH:
if event_to_wait is not None:
assert self.moe_group is not None
event_to_wait.comm_stream_wait( self.moe_group.id)
event_to_wait.comm_stream_wait(self.moe_group.id)
buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
else:
@@ -697,19 +697,19 @@ class FusionMlpNode:
def __init__(self, custom_map, max_topk, recompute_fwd_gate_up=False, is_split_group_gemm=True):
self.token_dispatcher = custom_map.token_dispatcher
self.experts = custom_map.experts
self.unzip_node = UnZipNode()
self.zip_node = ZipNode()
self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode(
custom_map,
recompute_fwd_gate_up=recompute_fwd_gate_up,
is_split_group_gemm=is_split_group_gemm,
)
self.unzip_node = UnZipNode(self.token_dispatcher)
self.zip_node = ZipNode(self.token_dispatcher)
self.dispatched_indices = None
self.dispatched_probs = None
self.tokens_per_expert = None
self.router_topk = max_topk

def reset_statue(self):
def reset_statue(self, with_dw=False):
"""
Reset all state variables.

@@ -724,8 +724,15 @@ def reset_statue(self):
self.dispatched_probs = None
self.tokens_per_expert = None
self.router_topk = None
self.experts_group_gemm_node.reset_statue()
self.experts_group_gemm_node = None

del self.unzip_node
del self.zip_node
self.unzip_node = None
self.zip_node = None

if with_dw:
self.experts_group_gemm_node.reset_statue()
self.experts_group_gemm_node = None

@paddle.no_grad()
def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
@@ -847,13 +854,14 @@ def backward(self, hidden_states_out_grad, with_dw=True):
self.dispatched_indices,
num_experts=len(self.tokens_per_expert),
)
if with_dw:
self.reset_statue()

self.reset_statue(with_dw)
return hs_dispatched_grad, dispatched_probs_grad

@paddle.no_grad()
def backward_dw(self):
self.experts_group_gemm_node.backward_dw()
self.reset_statue(True)


class FusionMoeNode:
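The cleanup ordering this hunk introduces, shown as a hypothetical stand-in class (not the real `FusionMlpNode`): dispatch/zip buffers are released as soon as `backward` has produced dx, while the grouped-GEMM state needed for weight gradients survives until `backward_dw` runs.

```python
class MlpNodeSketch:
    """Minimal stand-in for the deferred-cleanup pattern, assumptions only."""

    def __init__(self):
        self.unzip_node = object()   # dispatch-side buffers
        self.zip_node = object()
        self.gemm_node = object()    # state required by backward_dw

    def reset_statue(self, with_dw=False):
        self.unzip_node = None       # always drop activation-sized buffers
        self.zip_node = None
        if with_dw:                  # drop GEMM state only once dw is finished
            self.gemm_node = None

    def backward(self, with_dw=True):
        # ... compute dx here ...
        self.reset_statue(with_dw)   # with_dw=False keeps gemm_node alive

    def backward_dw(self):
        # ... compute weight gradients from gemm_node here ...
        self.reset_statue(True)      # final cleanup

node = MlpNodeSketch()
node.backward(with_dw=False)  # dx now, dw deferred, gemm_node retained
node.backward_dw()            # dw later, everything released
```
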
6 changes: 2 additions & 4 deletions paddlenlp/transformers/moe_utils.py
@@ -125,8 +125,7 @@ def unpermute(


class UnZipNode:
def __init__(self, token_dispatcher, name="unzip"):
self.token_dispatcher = token_dispatcher
def __init__(self, name="unzip"):
self.name = name
self.unzipped_probs = None
self.zipped_expertwise_rowmap = None
@@ -199,8 +198,7 @@ def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, n


class ZipNode:
def __init__(self, token_dispatcher, name="zip"):
self.token_dispatcher = token_dispatcher
def __init__(self, name="zip"):
self.name = name

@paddle.no_grad()