
Commit c7b0059

refine fp8_utils (#10848)
* refine fp8_utils
* refine fp8_utils
* refine fp8_utils
* fix
* fix after review
* fix
* fix
* fix
* fix
1 parent 8a0986e commit c7b0059

File tree

4 files changed, +408 -528 lines changed


paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -2008,7 +2008,6 @@ def forward(
 
         if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
             hidden_states = self.post_attention_layernorm(hidden_states)
-
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
```

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 7 additions & 10 deletions
```diff
@@ -59,12 +59,7 @@
 )
 from paddlenlp.transformers.moe_layer import FusionMoeNode
 
-from ..fp8_utils import (
-    fp8_mlp_bwd,
-    fp8_mlp_bwd_norm_rc,
-    fp8_mlp_fwd,
-    fp8_mlp_fwd_norm_rc,
-)
+from ..fp8_utils import FP8LinearFunctionBase
 
 __all__ = [
     "DeepseekV2ForCausalLMPipe",
```
```diff
@@ -175,15 +170,17 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = fp8_mlp_fwd_norm_rc(
+                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
                         hidden_states,
                         self.shared_experts.norm_weight,
                         self.shared_experts.norm_eps,
                         self.shared_experts.w1,
                         self.shared_experts.w2,
                     )
                 else:
-                    shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        hidden_states, self.shared_experts.w1, self.shared_experts.w2
+                    )
                 final_hidden_states = final_hidden_states + shared_expert_output
 
         self.x = hidden_states
```
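Note the second call site: the refactored `fp8_mlp_fwd` now returns a 3-tuple rather than a bare tensor, and the shared-expert output is its last element; the first two values are discarded here. For call sites that still expect the old single-return API, an adapter along these lines would work (the name `legacy_fp8_mlp_fwd` is invented for illustration; only the import path and the 3-tuple return are taken from the diff):

```python
from paddlenlp.transformers.fp8_utils import FP8LinearFunctionBase

def legacy_fp8_mlp_fwd(hidden_states, w1, w2):
    # The new static method returns a 3-tuple; only the last element is the
    # MLP output that the old free function used to return.
    _, _, out = FP8LinearFunctionBase.fp8_mlp_fwd(hidden_states, w1, w2)
    return out
```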
```diff
@@ -201,7 +198,7 @@ def backward(self, output_grad):
 
         assert not self.send_mtp_embed, "not support have mtp have yet"
         if self.using_post_norm_recompute:
-            dx = fp8_mlp_bwd_norm_rc(
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
                 do3,
                 self.x,
                 self.shared_experts.norm_weight,
@@ -210,7 +207,7 @@
                 self.shared_experts.w2,
             )
         else:
-            dx = fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
 
         self.x = None
```
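The `backward` changes keep the surrounding activation-management pattern intact: `forward` stashes its input in `self.x`, `backward` reuses it to compute `dx`, then clears the reference to release activation memory. A generic sketch of that pattern (plain Python, not PaddleNLP API; all names are illustrative):

```python
class SharedExpertNode:
    """Illustrative only: mirrors the save-input / free-on-backward pattern above."""

    def __init__(self, mlp_fwd, mlp_bwd, w1, w2):
        self.mlp_fwd, self.mlp_bwd = mlp_fwd, mlp_bwd
        self.w1, self.w2 = w1, w2
        self.x = None

    def forward(self, x):
        self.x = x  # keep the input around for the backward pass
        _, _, out = self.mlp_fwd(x, self.w1, self.w2)  # output is the tuple's last element
        return out

    def backward(self, dout):
        dx = self.mlp_bwd(dout, self.x, self.w1, self.w2)
        self.x = None  # drop the stored activation to free memory
        return dx
```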
