diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 3fef76f40413..cebd72af2c14 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -2008,7 +2008,6 @@ def forward( if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 9e05faba3593..1b90fe83877b 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -59,12 +59,7 @@ ) from paddlenlp.transformers.moe_layer import FusionMoeNode -from ..fp8_utils import ( - fp8_mlp_bwd, - fp8_mlp_bwd_norm_rc, - fp8_mlp_fwd, - fp8_mlp_fwd_norm_rc, -) +from ..fp8_utils import FP8LinearFunctionBase __all__ = [ "DeepseekV2ForCausalLMPipe", @@ -175,7 +170,7 @@ def forward(self, inputs): with paddle.no_grad(): if self.shared_experts is not None: if self.using_post_norm_recompute: - shared_expert_output = fp8_mlp_fwd_norm_rc( + shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc( hidden_states, self.shared_experts.norm_weight, self.shared_experts.norm_eps, @@ -183,7 +178,9 @@ def forward(self, inputs): self.shared_experts.w2, ) else: - shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2) + _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) final_hidden_states = final_hidden_states + shared_expert_output self.x = hidden_states @@ -201,7 +198,7 @@ def backward(self, output_grad): assert not self.send_mtp_embed, "not support have mtp have yet" if self.using_post_norm_recompute: - dx = fp8_mlp_bwd_norm_rc( + dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc( do3, self.x, self.shared_experts.norm_weight, @@ -210,7 +207,7 @@ def backward(self, output_grad): self.shared_experts.w2, ) else: - dx = fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2) + dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2) self.x = None diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py index 7a632401b7b4..72110eab3ba1 100644 --- a/paddlenlp/transformers/fp8_utils.py +++ b/paddlenlp/transformers/fp8_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os + import numpy import paddle import paddle.nn.functional as F @@ -39,519 +40,447 @@ def swiglu(x, y=None): __all__ = [ - "kitchen_fp8_gemm", - "dequantize_fp8_to_fp32", + "FP8LinearFunctionBase", "FP8Linear", "FP8GroupGemmMlpFunctionNode", ] -def kitchen_fp8_gemm( - x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16 -): - if USE_DS_GEMM: - if out is None: - out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) - if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: - deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112) - return out - - if out is not None: - accumulate = True - out_dtype = out.dtype - else: - accumulate = False - out_dtype = rtn_dtype - if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: - y = paddle.incubate.nn.functional.fp8_gemm_blockwise( - a=x_fp8, - a_decode_scale=x_scale, - b=w_fp8, - b_decode_scale=w_scale, - out_dtype=out_dtype, - out=out, - accumulate=accumulate, - use_split_accumulator=True, - is_a_1d_scaled=is_a_1d_scaled, - is_b_1d_scaled=is_b_1d_scaled, - ) - else: - y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype) - if out is not None: - out = out + y - return out - - return y - +class FP8LinearFunctionBase: + @staticmethod + def dequantize_fp8_to_fp32(fp8_tensor, scale): + res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) + return res.reshape(fp8_tensor.shape) -def dequantize_fp8_to_fp32(fp8_tensor, scale): - # expanded_scale = paddle.repeat_interleave(scale, repeats=128, axis=-1) - res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) - res = res.reshape(fp8_tensor.shape) + @staticmethod + def padding(x, axis): + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x - return res + @staticmethod + def padding_and_quant_input(tensor): + """Quantize input to FP8, with fallback to padded transposed version if shape not aligned.""" + if tensor.shape[0] % 512 != 0: + tensor_fp8, tensor_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + tensor = FP8LinearFunctionBase.padding(tensor, 0) + tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + else: + tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=True ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + @staticmethod + def kitchen_gemm( + x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16 + ): + if USE_DS_GEMM: + if out is None: + out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112) +
return out -def padding(x, axis): - if x.shape[axis] % 512 != 0: - if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: - padding_size = 512 + if out is not None: + accumulate = True + out_dtype = out.dtype else: - padding_size = 128 - pad_size = padding_size - (x.shape[axis] % padding_size) - if axis == 0: - x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + accumulate = False + out_dtype = rtn_dtype + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + y = paddle.incubate.nn.functional.fp8_gemm_blockwise( + a=x_fp8, + a_decode_scale=x_scale, + b=w_fp8, + b_decode_scale=w_scale, + out_dtype=out_dtype, + out=out, + accumulate=accumulate, + use_split_accumulator=True, + is_a_1d_scaled=is_a_1d_scaled, + is_b_1d_scaled=is_b_1d_scaled, + ) else: - x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) - return x + y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype) + if out is not None: + out = out + y + return out + return y -class FP8LinearFunction(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, x, custom_map): - weight = custom_map.weight - x_orig_shape = x.shape - x_t = x.T - - # deep_gemm only support 2D - x = x.reshape([-1, x_orig_shape[-1]]).contiguous() - - # quant - if x.shape[0] % 512 != 0: - x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - x = padding(x, 0) - x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True - ) + def compute_fp8_linear( + input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only" + ): + """ + FP8 linear computation supporting several return modes and both quantized and unquantized inputs. + + Args: + input: Input tensor, either raw or already quantized as an (input_fp8, input_scale) tuple. + weight: Weight tensor. + weight_transpose (bool): Whether to transpose the weight. + return_transpose_only (bool): Whether to return only the transposed weight. + return_mode (str): Return mode, one of: + - "output_only": return only the output tensor. + - "with_input_quant": return the output plus the quantized input (input_fp8, input_scale). + - "with_input_transpose_quant": return the output (out) plus the transposed quantized input (input_t_fp8, input_t_scale). + Returns: + Tensors in different combinations depending on return_mode. + + Raises: + RuntimeError: If return_mode is not supported. + """ + # check input + is_input_quantized = isinstance(input, tuple) and len(input) == 2 + + if is_input_quantized: + input_fp8, input_scale = input + if return_mode == "with_input_transpose_quant": + raise RuntimeError( + "Cannot return transposed quant if input is already quantized. " "Use raw input instead."
+ ) else: - x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=True - ) + # quant input (with optional transposed output) + if return_mode == "with_input_transpose_quant": + input_fp8, input_scale, input_t_fp8, input_t_scale = FP8LinearFunctionBase.padding_and_quant_input( + input + ) + else: + input_fp8, input_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + input, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=False, + return_transpose_only=False, + ) - w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + # quant weight + weight_fp8, weight_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( weight, output_scale_transpose=False, quant_method="128x128", - input_transpose=True, - return_transpose_only=True, + input_transpose=weight_transpose, + return_transpose_only=return_transpose_only, ) - out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112) - out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + # FP8 GEMM + out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=input.dtype) - # save for bwd - ctx.save_for_backward(x_t_fp8, x_t_scale, weight) - ctx.x_t_shape = x_t.shape - return out + deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=112) - @staticmethod - def backward(ctx, dout): - x_t_fp8, x_t_scale, weight = ctx.saved_tensor() - - # ===== dx = deep_gemm(dout_fp8, w_fp8) - dout_2d = dout.reshape([-1, dout.shape[-1]]) - if dout_2d.shape[0] % 512 != 0: - dout_fp8, dout_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - dout_2d = padding(dout_2d, 0) - dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, - output_scale_transpose=True, - quant_method="1x128", - input_transpose=True, - return_transpose_only=True, - ) + # Return outputs + if return_mode == "output_only": + return out + elif return_mode == "with_input_quant": + return (out, input_fp8, input_scale) + elif return_mode == "with_input_transpose_quant": + return (out, input_t_fp8, input_t_scale) else: - dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True + raise RuntimeError( + f"Unsupported return_mode: {return_mode}. 
" + "Supported modes: 'output_only', 'with_input_quant', 'with_input_transpose_quant'" ) - w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False - ) - dx = paddle.empty([ctx.x_t_shape[1], ctx.x_t_shape[0]], dout.dtype) - dx_orig_shape = dout.shape[:-1] - dx_orig_shape.append(ctx.x_t_shape[0]) - deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx) - dx = dx.reshape(dx_orig_shape) - # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + @staticmethod + def compute_expert_w_grad( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled=True, + is_b_1d_scaled=True, + weight=None, + rtn_dtype=paddle.bfloat16, + ): + """ + 统一处理 expert_w 的梯度计算(支持 main_grad 和普通 grad) + """ if hasattr(weight, "main_grad"): if weight.main_grad is None: weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.main_grad, rtn_dtype=paddle.float32 + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, ) else: if weight.grad is None: weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.grad, rtn_dtype=paddle.float32 + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, input_t_scale, dout_t, dout_t_scale, is_a_1d_scaled, is_b_1d_scaled, weight.grad, rtn_dtype ) if hasattr(weight, "_apply_backward_hook"): weight._apply_backward_hook() + return result - return dx - - -class FP8Linear(paddle.nn.Layer): - def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: - super().__init__() - self._dtype = self._helper.get_default_dtype() - - self.weight = self.create_parameter( - shape=[in_features, out_features], - dtype="bfloat16", - is_bias=False, - ) - - def forward(self, x): - return FP8LinearFunction.apply(x, self) - - -class FP8LinearKeepXFunction(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, x, custom_map): - weight = custom_map.weight - x_orig_shape = x.shape + def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False): - # deep_gemm only support 2D - x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + # o1, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + # x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" + # ) - # quant - x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=False + w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True ) - w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - weight, - output_scale_transpose=False, - quant_method="128x128", - input_transpose=True, - return_transpose_only=True, - ) - - # compute out = mm(x, w_t) - out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112) - out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) - - ctx.save_for_backward(x, weight) - return out + o1 = 
paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112) - @staticmethod - def backward(ctx, dout): - x, weight = ctx.saved_tensor() - dx_orig_shape = x.shape + # ===== [recompute] o2 = swiglu(o1) ===== + o2 = swiglu(o1) - # padding - x = padding(x, 0) - x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + # ===== do2 = deep_gemm(do3_fp8, w2_fp8) + do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do3, w2, return_mode="with_input_transpose_quant" ) - w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False + # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) + o2 = FP8LinearFunctionBase.padding(o2, 0) + o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True ) - - dout_2d = dout.reshape([-1, dout.shape[-1]]) - if dout_2d.shape[0] % 512 != 0: - dout_fp8, dout_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - dout_2d = padding(dout_2d, 0) - dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, - output_scale_transpose=True, - quant_method="1x128", - input_transpose=True, - return_transpose_only=True, + if apply_backward_hook: + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32 ) else: - dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True + dw2 = FP8LinearFunctionBase.kitchen_gemm( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32 ) - dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx, num_sms=112) - dx = dx.reshape(dx_orig_shape) + # ===== do1 = swiglu_grad(o1, None, do2) ===== + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) - # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) - if hasattr(weight, "main_grad"): - if weight.main_grad is None: - weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.main_grad, rtn_dtype=paddle.float32 + # ===== dx = deep_gemm(do1_fp8, w1_fp8) ===== + dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do1, w1, return_mode="with_input_transpose_quant" + ) + + # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) ===== + if apply_backward_hook: + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32 ) else: - if weight.grad is None: - weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.grad, rtn_dtype=paddle.float32 + dw1 = FP8LinearFunctionBase.kitchen_gemm( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32 ) - if hasattr(weight, "_apply_backward_hook"): - weight._apply_backward_hook() - - return dx - + if 
apply_backward_hook: + return dx + else: + assert dw1 is not None and dw2 is not None + return dx, dw1, dw2 -class FP8KeepXLinear(paddle.nn.Layer): - def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: - super().__init__() - self._dtype = self._helper.get_default_dtype() + @staticmethod + def fp8_mlp_fwd(x, w1, w2): + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) - self.weight = self.create_parameter( - shape=[in_features, out_features], - dtype="bfloat16", - is_bias=False, + # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant" ) - def forward(self, x): - return FP8LinearKeepXFunction.apply(x, self) + # ===== o2 = swiglu(o1) ===== + o2 = swiglu(o1) + # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) ===== + o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True) -def fp8_mlp_fwd(x, w1, w2): - x_orig_shape = x.shape - x = x.reshape([-1, x_orig_shape[-1]]) + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) - _, _, o3 = common_fp8_mlp_fwd(x, w1, w2) + return x_fp8, x_scale, o3 - if len(x_orig_shape) > 2: - o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + @staticmethod + def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): + # ===== compute norm_output ===== + norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== compute fp8_mlp_fwd ===== + _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + return o3 - return o3 + @staticmethod + def fp8_mlp_bwd(do3, x, w1, w2): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) -def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): - # ===== compute norm_output ===== - norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) - # ===== compute fp8_mlp_fwd ===== - o3 = fp8_mlp_fwd(norm_output, w1, w2) - return o3 + x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x) + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=True + ) -def fp8_mlp_bwd(do3, x, w1, w2): - do3_orig_shape = do3.shape - do3 = do3.reshape([-1, do3_orig_shape[-1]]) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) - x_orig_shape = x.shape - x = x.reshape([-1, x_orig_shape[-1]]) + return dx - if x.shape[0] % 128 == 0: - x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=True - ) - else: - x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - x = padding(x, 0) - x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True - ) + @staticmethod + def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): + # ===== recompute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) - dx = common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=True) + # ===== compute fp8_mlp_fwd ===== + d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2) - if len(x_orig_shape) > 
2: - dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + # ===== compute norm grad ===== + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) - return dx + if hasattr(norm_w, "main_grad"): + if norm_w.main_grad is None: + norm_w.main_grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32) + norm_w.main_grad += d_rms_norm_weight + else: + if norm_w.grad is None: + norm_w.grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32) + norm_w.grad += d_rms_norm_weight + if hasattr(norm_w, "_apply_backward_hook"): + norm_w._apply_backward_hook() -def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): - # ===== recompute norm_output ===== - norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + return dx - # ===== compute fp8_mlp_fwd ===== - d_norm_output = fp8_mlp_bwd(do3, norm_output, w1, w2) - # ===== compute norm grad ===== - dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) +class FP8LinearFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, custom_map, keep_x=False): + weight = custom_map.weight + x_orig_shape = x.shape - if hasattr(norm_w, "main_grad"): - if norm_w.main_grad is None: - norm_w.main_grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32) - norm_w.main_grad += d_rms_norm_weight - else: - if norm_w.grad is None: - norm_w.grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32) - norm_w.grad += d_rms_norm_weight - - if hasattr(norm_w, "_apply_backward_hook"): - norm_w._apply_backward_hook() - - return dx - - -def common_fp8_mlp_fwd(x, w1, w2): - # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) ===== - x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - - w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True - ) - o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112) - - # ===== o2 = swiglu(o1) ===== - o2 = swiglu(o1) - o2_fp8, o2_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - o2, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - - # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) ===== - w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True - ) - o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112) - return x_fp8, x_scale, o3 - - -def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False): - w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True - ) - o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112) - - # ===== [recompute] o2 = swiglu(o1) ===== - o2 = swiglu(o1) - - # ===== do2 = deep_gemm(do3_fp8, w2_fp8) - if do3.shape[0] % 512 != 0: - do3_fp8, do3_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do3, output_scale_transpose=True, quant_method="1x128", 
input_transpose=False - ) - do3 = padding(do3, 0) - do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True - ) - else: - do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True - ) - w2_fp8, w2_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - w2, output_scale_transpose=False, quant_method="128x128", input_transpose=False - ) - do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2, num_sms=112) - - # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) - o2 = padding(o2, 0) - o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True - ) - - if apply_backward_hook: - if hasattr(w2, "main_grad"): - if w2.main_grad is None: - w2.main_grad = paddle.zeros(shape=w2.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - o2_t_fp8, - o2_t_scale, - do3_t_fp8, - do3_t_scale, - True, - True, - w2.main_grad, - paddle.float32, + # deep_gemm only support 2D + x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + + if keep_x: + out = FP8LinearFunctionBase.compute_fp8_linear( + x, + weight, + weight_transpose=True, + return_transpose_only=True, ) + # save for bwd + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward(x, weight) + return out else: - if w2.grad is None: - w2.grad = paddle.zeros(shape=w2.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - o2_t_fp8, - o2_t_scale, - do3_t_fp8, - do3_t_scale, - True, - True, - w2.grad, - paddle.float32, + x_t = x.T + out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, weight, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" ) - if hasattr(w2, "_apply_backward_hook"): - w2._apply_backward_hook() - else: - dw2 = kitchen_fp8_gemm(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32) + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward((x_t_fp8, x_t_scale), weight) + ctx.x_t_shape = x_t.shape + return out - # ===== do1 = swiglu_grad(o1, None, do2) ===== - do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + dout_2d = dout.reshape([-1, dout.shape[-1]]) - # ===== dx = deep_gemm(do1_fp8, w1_fp8) - if do1.shape[0] % 512 != 0: - do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False - ) - do1 = padding(do1, 0) - do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True - ) - else: - do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True - ) - w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False - ) - dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype) - deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx, num_sms=112) - - # 
===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) - if apply_backward_hook: - if hasattr(w1, "main_grad"): - if w1.main_grad is None: - w1.main_grad = paddle.zeros(shape=w1.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, - x_t_scale, - do1_t_fp8, - do1_t_scale, - True, - True, - w1.main_grad, - paddle.float32, + keep_x = not isinstance(x, tuple) + + if keep_x: + # padding x and quant + dx_orig_shape = x.shape + x = FP8LinearFunctionBase.padding(x, 0) + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True ) + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx = dx.reshape(dx_orig_shape) + else: - if w1.grad is None: - w1.grad = paddle.zeros(shape=w1.shape, dtype=paddle.float32) - kitchen_fp8_gemm( - x_t_fp8, - x_t_scale, - do1_t_fp8, - do1_t_scale, - True, - True, - w1.grad, - paddle.float32, + x_t_fp8, x_t_scale = x + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" ) - if hasattr(w1, "_apply_backward_hook"): - w1._apply_backward_hook() - else: - dw1 = kitchen_fp8_gemm(x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32) + dx_orig_shape = dout.shape[:-1] + dx_orig_shape.append(ctx.x_t_shape[0]) + dx = dx.reshape(dx_orig_shape) - if apply_backward_hook: + # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32 + ) return dx - else: - assert dw1 is not None and dw2 is not None - return dx, dw1, dw2 -class FP8MlpFunction(paddle.autograd.PyLayer): +class FP8Linear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=False) + + +class FP8KeepXLinear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=True) + + +class FusedNormFP8MLPFunction(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, x, w1, w2): + def forward(ctx, x, norm_w, w1, w2, norm_eps): + # ===== compute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== - x_orig_shape = x.shape - x = x.reshape([-1, x_orig_shape[-1]]) + x_orig_shape = norm_output.shape + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) - # ===== call func common_fp8_mlp_fwd ===== - x_fp8, x_scale, o3 = common_fp8_mlp_fwd(x, w1, w2) + # ===== call func fp8_mlp_fwd ===== + _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) # ===== reshape to origin shape ===== if len(x_orig_shape) > 2: @@ -559,10 +488,11 @@ def forward(ctx, x, 
w1, w2): # ===== save for backward ===== ctx.save_for_backward( - x_fp8, - x_scale, + x, + norm_w, w1, w2, + norm_eps, paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), ) return o3 @@ -574,53 +504,50 @@ def backward(ctx, do3): do3 = do3.reshape([-1, do3_orig_shape[-1]]) # ===== recive saved tensors ===== - x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() + + # ===== recompute norm ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) # ===== compute x_t_fp8, x_t_scale for dw1 ===== - x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) - x_dequant_fp16 = padding(x_dequant_fp16, 0) + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) - x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x_dequant_fp16, - output_scale_transpose=True, - quant_method="1x128", - input_transpose=True, - return_transpose_only=True, + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True ) # ===== call func common_fp8_mlp_bwd ===== - dx, dw1, dw2 = common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2) + d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2) # ===== reshape to origin shape ===== if len(x_orig_shape) > 2: - dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) - return dx, dw1, dw2 + # ===== compute norm grad ===== + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + + return dx, d_rms_norm_weight, dw1, dw2 -class FP8NormMlpRecomputeFunction(paddle.autograd.PyLayer): +class FP8MlpFunction(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, x, norm_w, w1, w2, norm_eps): - # ===== compute norm_output ===== - norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + def forward(ctx, x, w1, w2): # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== - x_orig_shape = norm_output.shape - norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) - - # ===== call func common_fp8_mlp_fwd ===== - _, _, o3 = common_fp8_mlp_fwd(norm_output, w1, w2) + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + # ===== call func fp8_mlp_fwd ===== + x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2) # ===== reshape to origin shape ===== if len(x_orig_shape) > 2: o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) # ===== save for backward ===== ctx.save_for_backward( - x, - norm_w, + x_fp8, + x_scale, w1, w2, - norm_eps, paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), ) return o3 @@ -632,28 +559,28 @@ def backward(ctx, do3): do3 = do3.reshape([-1, do3_orig_shape[-1]]) # ===== recive saved tensors ===== - x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() - - # ===== recompute norm ===== - norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() # ===== compute x_t_fp8, x_t_scale for dw1 ===== - norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) - x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True + x_dequant_fp16 = 
paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0) + + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x_dequant_fp16, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, ) # ===== call func common_fp8_mlp_bwd ===== - d_norm_output, dw1, dw2 = common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2) + dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2) # ===== reshape to origin shape ===== if len(x_orig_shape) > 2: - d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) - - # ===== compute norm grad ===== - dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) - return dx, d_rms_norm_weight, dw1, dw2 + return dx, dw1, dw2 class FP8Mlp(paddle.nn.Layer): @@ -691,7 +618,7 @@ def __init__( def forward(self, x): if self.using_post_norm_recompute: - return FP8NormMlpRecomputeFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) + return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) else: return FP8MlpFunction.apply(x, self.w1, self.w2) @@ -955,34 +882,9 @@ def bwd_down_weight(self, do3, o2, expert_w2): do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, self.tokens_per_expert, True) for i in range(len(expert_w2)): - if hasattr(expert_w2[i], "main_grad"): - if expert_w2[i].main_grad is None: - expert_w2[i].main_grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) - kitchen_fp8_gemm( - o2_t_fp8[i], - o2_t_scale[i], - do3_t_fp8[i], - do3_t_scale[i], - True, - True, - expert_w2[i].main_grad, - paddle.float32, - ) - else: - if expert_w2[i].grad is None: - expert_w2[i].grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) - kitchen_fp8_gemm( - o2_t_fp8[i], - o2_t_scale[i], - do3_t_fp8[i], - do3_t_scale[i], - True, - True, - expert_w2[i].grad, - paddle.float32, - ) - if hasattr(expert_w2[i], "_apply_backward_hook"): - expert_w2[i]._apply_backward_hook() + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8[i], o2_t_scale[i], do3_t_fp8[i], do3_t_scale[i], True, True, expert_w2[i], paddle.float32 + ) def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False): """ @@ -1001,34 +903,16 @@ def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False): do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, self.tokens_per_expert, True) for i in range(len(expert_w1)): - if hasattr(expert_w1[i], "main_grad"): - if expert_w1[i].main_grad is None: - expert_w1[i].main_grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) - kitchen_fp8_gemm( - input_x_t_fp8[i], - input_x_t_scale[i], - do1_t_fp8[i], - do1_t_scale[i], - True, - True, - expert_w1[i].main_grad, - paddle.float32, - ) - else: - if expert_w1[i].grad is None: - expert_w1[i].grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) - kitchen_fp8_gemm( - input_x_t_fp8[i], - input_x_t_scale[i], - do1_t_fp8[i], - do1_t_scale[i], - True, - True, - expert_w1[i].grad, - paddle.float32, - ) - if hasattr(expert_w1[i], "_apply_backward_hook"): - expert_w1[i]._apply_backward_hook() + FP8LinearFunctionBase.compute_expert_w_grad( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i], + 
paddle.float32, + ) @paddle.no_grad() def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None): diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py index dd9756746015..0a209ca29e71 100644 --- a/paddlenlp/transformers/moe_utils.py +++ b/paddlenlp/transformers/moe_utils.py @@ -17,7 +17,7 @@ import numpy as np import paddle -from .fp8_utils import dequantize_fp8_to_fp32 +from .fp8_utils import FP8LinearFunctionBase if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"): @@ -349,7 +349,7 @@ def forward( def backward(self, out_grad, out_grad_scale): hidden_states_grad = paddle.gather(out_grad, self.token_permuted_indices) - output_tokens_grad = dequantize_fp8_to_fp32(out_grad, out_grad_scale) + output_tokens_grad = FP8LinearFunctionBase.dequantize_fp8_to_fp32(out_grad, out_grad_scale) permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype)
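For reference, after this refactor the former free functions (kitchen_fp8_gemm, dequantize_fp8_to_fp32, fp8_mlp_fwd, fp8_mlp_bwd, ...) are reached as static methods on FP8LinearFunctionBase. The sketch below is a minimal, illustrative usage only, not part of the patch: the shapes (1024 tokens, hidden size 2048, intermediate size 4096) are assumptions chosen to satisfy the 128/512 alignment the kernels expect, and it presumes an environment where deep_gemm and paddle.incubate.nn.functional.fp8_quant_blockwise are available.

    import paddle
    from paddlenlp.transformers.fp8_utils import FP8LinearFunctionBase

    # Illustrative tensors (assumed shapes): w1 is the fused gate/up projection,
    # w2 the down projection, as used by the shared experts in this patch.
    x = paddle.randn([1024, 2048]).cast("bfloat16")
    w1 = paddle.create_parameter([2048, 8192], "bfloat16")
    w2 = paddle.create_parameter([4096, 2048], "bfloat16")

    # Forward: returns the 1x128-quantized input together with the MLP output o3.
    x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)

    # Backward: recomputes o1/o2 from the saved bf16 input, accumulates dw1/dw2
    # into w1/w2 (main_grad or grad) via compute_expert_w_grad, and returns dx only.
    dx = FP8LinearFunctionBase.fp8_mlp_bwd(paddle.ones_like(o3), x, w1, w2)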