refine fp8_utils #10848

Merged
merged 13 commits into PaddlePaddle:dsv3_dev from rewrite_fp8_utils on Jul 31, 2025
Conversation

risemeup1 (Contributor) commented Jul 15, 2025

Before submitting

  • Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --file XXXX.py
  • Add test cases into the tests folder. If there are codecov issues, please add test cases first.

PR types

Others

PR changes

Others

Description

Refactored the functions in fp8_utils to reduce duplicated implementations and improve maintainability. The main changes (a sketch of the resulting structure follows this list):

  • Built an FP8LinearFunctionBase base class that encapsulates the common padding, quant, and gemm logic.
  • Merged the near-identical functions shared by the similar classes FP8Linear and FP8KeepXLinear.
  • Merged functions with overlapping functionality, such as common_fp8_mlp_fwd and fp8_mlp_fwd.
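
To make the shape of the refactor concrete, here is a minimal sketch of the base-class idea; the method bodies, default arguments, and the quant placeholder are illustrative assumptions, not the actual fp8_utils implementation.

import paddle


class FP8LinearFunctionBase:
    # Hypothetical sketch: shared helpers that FP8Linear and FP8KeepXLinear
    # delegate to instead of carrying private copies of the same logic.

    @staticmethod
    def padding(x, axis=-1, align=512):
        # Pad `x` along `axis` up to a multiple of `align` (illustrative).
        pad = (-x.shape[axis]) % align
        if pad == 0:
            return x
        pad_shape = list(x.shape)
        pad_shape[axis] = pad
        return paddle.concat([x, paddle.zeros(pad_shape, dtype=x.dtype)], axis=axis)

    @staticmethod
    def quant(x):
        # Placeholder for blockwise FP8 quantization returning (tensor, scale);
        # the real kernel lives in fp8_utils.
        raise NotImplementedError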

paddle-bot (bot) commented Jul 15, 2025

Thanks for your contribution!


return res
@staticmethod
def run_deep_gemm(a, a_scale, b, b_scale, out=None, num_sms=112, m_indices=None, is_grouped=False):
Contributor:
For this run_deep_gemm, we should think through how it differs from wgrad_gemm.

Contributor Author:
OK.
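
As background for this thread, a hedged sketch of the kind of dispatch run_deep_gemm likely performs, in contrast to a dedicated weight-gradient GEMM; the deep_gemm entry points, set_num_sms, and their signatures are assumptions based on the public DeepGEMM library, not something this PR confirms.

import deep_gemm  # assumption: DeepSeek's DeepGEMM package is available


def run_deep_gemm(a, a_scale, b, b_scale, out=None, num_sms=112, m_indices=None, is_grouped=False):
    # Illustrative dispatch only; all deep_gemm calls below are assumed APIs.
    deep_gemm.set_num_sms(num_sms)  # assumed utility capping SM usage
    if is_grouped:
        # Grouped (MoE-style) GEMM: m_indices maps each row of `a` to a group.
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), out, m_indices)
    else:
        deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out)
    return out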

if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0:
padding_size = 512
@staticmethod
def kitchen_fp8_gemm(
Contributor:
No need to call out fp8 in the name here.

Contributor Author:
OK.
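
As an aside on the diff context above: the guard rounds the axis length up past the next 128 boundary and pads when the result is not 512-aligned. A small worked example of how the expression evaluates (my reading, not part of the PR):

# Evaluate the guard for a few axis lengths.
for n in (256, 384, 640, 1024):
    rounded = n + 128 - (n % 128)  # overshoots by a full 128 when n is already 128-aligned
    print(n, rounded, rounded % 512 != 0)
# 256 -> 384, needs padding; 384 -> 512, no padding;
# 640 -> 768, needs padding; 1024 -> 1152, needs padding.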

do3_orig_shape = do3.shape
do3 = do3.reshape([-1, do3_orig_shape[-1]])
@staticmethod
def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
Contributor:
Has the correctness of this function been verified?

Contributor Author:
Let me double-check.
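
One concrete way to answer the question above is a tolerance check of the fused FP8 backward against a higher-precision reference; a minimal sketch, assuming the gradients are paddle tensors and that these loose FP8-appropriate tolerances are acceptable:

import numpy as np


def assert_grads_close(fused_grads, ref_grads, rtol=1e-2, atol=1e-2):
    # Hypothetical test helper: compare each gradient from the fused FP8
    # path against a reference computed in higher precision.
    for fused, ref in zip(fused_grads, ref_grads):
        np.testing.assert_allclose(
            fused.astype("float32").numpy(),
            ref.astype("float32").numpy(),
            rtol=rtol,
            atol=atol,
        )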


# ===== compute norm grad =====
dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)
if if_keep_x:
Contributor:
`if if_keep_x`: the two ifs back to back read really oddly.

Contributor Author:
Let's just call it keep_x.
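
For context on the fused_rms_norm_grad_func call in the diff above: the backward consumes the inverse RMS (invar) saved by the forward pass. A plain-paddle reference forward, as a hedged sketch; the fused op's exact saved-tensor contract is an assumption:

import paddle


def rms_norm_ref(x, norm_w, norm_eps):
    # Reference RMSNorm forward; `invar` is the saved inverse RMS that a
    # fused backward like fused_rms_norm_grad_func would consume.
    x32 = x.astype("float32")
    invar = paddle.rsqrt(x32.pow(2).mean(axis=-1, keepdim=True) + norm_eps)
    return (x32 * invar).astype(x.dtype) * norm_w, invar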

if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0:
padding_size = 512
@staticmethod
def kitchen_gemm(
Contributor:
All of the logic under USE_DS_GEMM has been removed.

@risemeup1 force-pushed the rewrite_fp8_utils branch from 1fb7393 to 6f04692 on Jul 31, 2025, 02:44
@phlrain merged commit c7b0059 into PaddlePaddle:dsv3_dev on Jul 31, 2025
2 of 5 checks passed