Skip to content

Commit 7e4c254

Browse files
committed
fix
1 parent a6702f7 commit 7e4c254

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,6 @@ def compute_expert_w_grad(
223223
weight._apply_backward_hook()
224224
return result
225225

226-
@staticmethod
227-
def common_fp8_mlp_fwd(x, w1, w2):
228-
# ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
229-
o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
230-
x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
231-
)
232-
233-
# ===== o2 = swiglu(o1) =====
234-
o2 = swiglu(o1)
235-
236-
# ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
237-
o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)
238-
239-
return x_fp8, x_scale, o3
240-
241226
@staticmethod
242227
def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False):
243228

@@ -303,12 +288,21 @@ def fp8_mlp_fwd(x, w1, w2):
303288
x_orig_shape = x.shape
304289
x = x.reshape([-1, x_orig_shape[-1]])
305290

306-
_, _, o3 = FP8LinearFunctionBase.common_fp8_mlp_fwd(x, w1, w2)
291+
# ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
292+
o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
293+
x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
294+
)
295+
296+
# ===== o2 = swiglu(o1) =====
297+
o2 = swiglu(o1)
298+
299+
# ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
300+
o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)
307301

308302
if len(x_orig_shape) > 2:
309303
o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
310304

311-
return o3
305+
return x_fp8, x_scale, o3
312306

313307
@staticmethod
314308
def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
@@ -462,7 +456,7 @@ def forward(self, x):
462456
return FP8LinearFunction.apply(x, self, keep_x=True)
463457

464458

465-
class FP8NormMlpRecomputeFunction(paddle.autograd.PyLayer):
459+
class FusedNormFP8MLPFunction(paddle.autograd.PyLayer):
466460
@staticmethod
467461
def forward(ctx, x, norm_w, w1, w2, norm_eps):
468462
# ===== compute norm_output =====
@@ -529,7 +523,7 @@ def forward(ctx, x, w1, w2):
529523
x = x.reshape([-1, x_orig_shape[-1]])
530524

531525
# ===== call func fp8_mlp_fwd =====
532-
x_fp8, x_scale, o3 = FP8LinearFunctionBase.common_fp8_mlp_fwd(x, w1, w2)
526+
x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
533527
# ===== reshape to origin shape =====
534528
if len(x_orig_shape) > 2:
535529
o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
@@ -610,7 +604,7 @@ def __init__(
610604

611605
def forward(self, x):
612606
if self.using_post_norm_recompute:
613-
return FP8NormMlpRecomputeFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
607+
return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
614608
else:
615609
return FP8MlpFunction.apply(x, self.w1, self.w2)
616610

0 commit comments

Comments (0)