@@ -223,21 +223,6 @@ def compute_expert_w_grad(
223
223
weight ._apply_backward_hook ()
224
224
return result
225
225
226
- @staticmethod
227
- def common_fp8_mlp_fwd (x , w1 , w2 ):
228
- # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
229
- o1 , x_fp8 , x_scale = FP8LinearFunctionBase .compute_fp8_linear (
230
- x , w1 , weight_transpose = True , return_transpose_only = True , return_mode = "with_input_quant"
231
- )
232
-
233
- # ===== o2 = swiglu(o1) =====
234
- o2 = swiglu (o1 )
235
-
236
- # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
237
- o3 = FP8LinearFunctionBase .compute_fp8_linear (o2 , w2 , weight_transpose = True , return_transpose_only = True )
238
-
239
- return x_fp8 , x_scale , o3
240
-
241
226
@staticmethod
242
227
def common_fp8_mlp_bwd (do3 , x_fp8 , x_scale , x_t_fp8 , x_t_scale , w1 , w2 , apply_backward_hook = False ):
243
228
@@ -303,12 +288,21 @@ def fp8_mlp_fwd(x, w1, w2):
303
288
x_orig_shape = x .shape
304
289
x = x .reshape ([- 1 , x_orig_shape [- 1 ]])
305
290
306
- _ , _ , o3 = FP8LinearFunctionBase .common_fp8_mlp_fwd (x , w1 , w2 )
291
+ # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
292
+ o1 , x_fp8 , x_scale = FP8LinearFunctionBase .compute_fp8_linear (
293
+ x , w1 , weight_transpose = True , return_transpose_only = True , return_mode = "with_input_quant"
294
+ )
295
+
296
+ # ===== o2 = swiglu(o1) =====
297
+ o2 = swiglu (o1 )
298
+
299
+ # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
300
+ o3 = FP8LinearFunctionBase .compute_fp8_linear (o2 , w2 , weight_transpose = True , return_transpose_only = True )
307
301
308
302
if len (x_orig_shape ) > 2 :
309
303
o3 = o3 .reshape ([x_orig_shape [0 ], - 1 , o3 .shape [- 1 ]])
310
304
311
- return o3
305
+ return x_fp8 , x_scale , o3
312
306
313
307
@staticmethod
314
308
def fp8_mlp_fwd_norm_rc (x , norm_w , norm_eps , w1 , w2 ):
@@ -462,7 +456,7 @@ def forward(self, x):
462
456
return FP8LinearFunction .apply (x , self , keep_x = True )
463
457
464
458
465
- class FP8NormMlpRecomputeFunction (paddle .autograd .PyLayer ):
459
+ class FusedNormFP8MLPFunction (paddle .autograd .PyLayer ):
466
460
@staticmethod
467
461
def forward (ctx , x , norm_w , w1 , w2 , norm_eps ):
468
462
# ===== compute norm_output =====
@@ -529,7 +523,7 @@ def forward(ctx, x, w1, w2):
529
523
x = x .reshape ([- 1 , x_orig_shape [- 1 ]])
530
524
531
525
# ===== call func fp8_mlp_fwd =====
532
- x_fp8 , x_scale , o3 = FP8LinearFunctionBase .common_fp8_mlp_fwd (x , w1 , w2 )
526
+ x_fp8 , x_scale , o3 = FP8LinearFunctionBase .fp8_mlp_fwd (x , w1 , w2 )
533
527
# ===== reshape to origin shape =====
534
528
if len (x_orig_shape ) > 2 :
535
529
o3 = o3 .reshape ([x_orig_shape [0 ], - 1 , o3 .shape [- 1 ]])
@@ -610,7 +604,7 @@ def __init__(
610
604
611
605
def forward (self , x ):
612
606
if self .using_post_norm_recompute :
613
- return FP8NormMlpRecomputeFunction .apply (x , self .norm_weight , self .w1 , self .w2 , self .norm_eps )
607
+ return FusedNormFP8MLPFunction .apply (x , self .norm_weight , self .w1 , self .w2 , self .norm_eps )
614
608
else :
615
609
return FP8MlpFunction .apply (x , self .w1 , self .w2 )
616
610
0 commit comments