@@ -59,12 +59,7 @@
 )
 from paddlenlp.transformers.moe_layer import FusionMoeNode
 
-from ..fp8_utils import (
-    fp8_mlp_bwd,
-    fp8_mlp_bwd_norm_rc,
-    fp8_mlp_fwd,
-    fp8_mlp_fwd_norm_rc,
-)
+from ..fp8_utils import FP8LinearFunctionBase
 
 __all__ = [
     "DeepseekV2ForCausalLMPipe",
@@ -175,15 +170,17 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = fp8_mlp_fwd_norm_rc(
+                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
                         hidden_states,
                         self.shared_experts.norm_weight,
                         self.shared_experts.norm_eps,
                         self.shared_experts.w1,
                         self.shared_experts.w2,
                     )
                 else:
-                    shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        hidden_states, self.shared_experts.w1, self.shared_experts.w2
+                    )
                 final_hidden_states = final_hidden_states + shared_expert_output
 
         self.x = hidden_states
@@ -201,7 +198,7 @@ def backward(self, output_grad):
 
         assert not self.send_mtp_embed, "not support have mtp have yet"
         if self.using_post_norm_recompute:
-            dx = fp8_mlp_bwd_norm_rc(
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
                 do3,
                 self.x,
                 self.shared_experts.norm_weight,
@@ -210,7 +207,7 @@ def backward(self, output_grad):
                 self.shared_experts.w2,
             )
         else:
-            dx = fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
 
         self.x = None
 
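
Reviewer note: below is a minimal sketch of the call pattern these hunks introduce, not the `fp8_utils` implementation. Only the method names on `FP8LinearFunctionBase`, the argument order, and the tuple unpacking of `fp8_mlp_fwd` are taken from the diff; the wrapper function names, the assumption that `..fp8_utils` resolves to `paddlenlp.transformers.fp8_utils`, and the meaning of the two discarded return values are illustrative.

```python
# Sketch only: mirrors the call sites changed in this diff.
import paddle

# Assumption: the relative import "..fp8_utils" maps to this absolute path.
from paddlenlp.transformers.fp8_utils import FP8LinearFunctionBase


def shared_expert_fwd(hidden_states, shared_experts, using_post_norm_recompute):
    """Compute the shared-expert output via the class-based FP8 helpers."""
    with paddle.no_grad():
        if using_post_norm_recompute:
            # Fused norm + FP8 MLP forward, same argument order as the hunk above.
            return FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
                hidden_states,
                shared_experts.norm_weight,
                shared_experts.norm_eps,
                shared_experts.w1,
                shared_experts.w2,
            )
        # fp8_mlp_fwd now returns a tuple; the diff keeps only the last element
        # as the MLP output and discards the first two values.
        _, _, out = FP8LinearFunctionBase.fp8_mlp_fwd(
            hidden_states, shared_experts.w1, shared_experts.w2
        )
        return out


def shared_expert_bwd(do3, x, shared_experts):
    """Gradient w.r.t. the saved input x, non-recompute path."""
    return FP8LinearFunctionBase.fp8_mlp_bwd(do3, x, shared_experts.w1, shared_experts.w2)
```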