refine fp8_utils #10848
Conversation
Thanks for your contribution!
paddlenlp/transformers/fp8_utils.py (outdated)

return res

@staticmethod
def run_deep_gemm(a, a_scale, b, b_scale, out=None, num_sms=112, m_indices=None, is_grouped=False):
This run_deep_gemm needs some thought on how it differs from wgrad_gemm.
OK.
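[Editor's note] To make the reviewer's point concrete, here is a minimal sketch of a dispatch along the lines of the diff's signature. It assumes the kernel names from the open-source DeepGEMM library; treat them and the module layout as assumptions, not the actual code in this PR.

import paddle
import deep_gemm  # assumed import; module layout may differ


def run_deep_gemm(a, a_scale, b, b_scale, out=None, num_sms=112, m_indices=None, is_grouped=False):
    # num_sms: kept for signature parity; the real helper presumably
    # forwards it to the kernel scheduler.
    if out is None:
        out = paddle.empty([a.shape[0], b.shape[0]], dtype=paddle.bfloat16)
    if is_grouped:
        # Grouped over M: m_indices maps each row of `a` to its expert/group,
        # so one launch covers all experts.
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a, a_scale), (b, b_scale), out, m_indices)
    else:
        deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out)
    return out

# A wgrad GEMM differs in at least two ways, so it likely cannot share this
# helper blindly: it typically accumulates in fp32 (out += a^T @ b) across
# micro-batches, and for MoE it is grouped over K (tokens) rather than M.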
paddlenlp/transformers/fp8_utils.py (outdated)

if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0:
    padding_size = 512

@staticmethod
def kitchen_fp8_gemm(
There's no need to single out fp8 in the name.
OK.
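[Editor's note] A possible clean rewrite of the padding rule above, with helper names of my own invention. Note that the original expression `x.shape[axis] + 128 - (x.shape[axis] % 128)` adds a full extra 128 even when the size is already 128-aligned, which may or may not be intended; the sketch below rounds up instead.

def _round_up(n, multiple):
    # Smallest multiple of `multiple` that is >= n.
    return ((n + multiple - 1) // multiple) * multiple


def choose_padding_alignment(dim):
    # Align to 128 first; if the 128-aligned size is still not a multiple
    # of 512, fall back to 512 so downstream kernels see a friendly shape.
    padded = _round_up(dim, 128)
    return 512 if padded % 512 != 0 else 128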
do3_orig_shape = do3.shape
do3 = do3.reshape([-1, do3_orig_shape[-1]])

@staticmethod
def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
Has the correctness of this function been verified?
Let me double-check.
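[Editor's note] One way to answer the reviewer's question: compare the fp8 backward against a reference backward built from plain autograd. A minimal sketch; `reference_bwd` is a placeholder the test author would supply, and the tolerances are placeholders too, since fp8 quantization error is large.

import paddle


def check_fp8_mlp_bwd(fp8_mlp_bwd_norm_rc, reference_bwd, do3, x, norm_w, norm_eps, w1, w2):
    got = fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2)
    want = reference_bwd(do3, x, norm_w, norm_eps, w1, w2)
    for g, r in zip(got, want):
        # fp8 quantization introduces noticeable error, so tolerances are loose.
        assert paddle.allclose(
            g.astype("float32"), r.astype("float32"), rtol=1e-2, atol=1e-2
        ), "fp8_mlp_bwd_norm_rc diverges from the reference backward"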
paddlenlp/transformers/fp8_utils.py (outdated)

# ===== compute norm grad =====
dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)
if if_keep_x:
"if if_keep_x": the doubled "if" reads very oddly.
Let's just call it keep_x.
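[Editor's note] For reference alongside the snippet above, a plain-Paddle sketch of the RMSNorm backward that fused_rms_norm_grad_func computes. It assumes `invar` is the saved 1/sqrt(mean(x**2) + eps) from the forward pass; that convention is an assumption about the fused op.

import paddle


def rms_norm_grad_reference(x, norm_w, invar, d_norm_output):
    # Forward: y = norm_w * x * invar, with invar = 1/sqrt(mean(x**2) + eps).
    h = x.shape[-1]
    r = invar.unsqueeze(-1)  # broadcast over the hidden dimension
    # Weight grad: sum of dy * x * invar over all rows.
    d_rms_norm_weight = (d_norm_output * x * r).reshape([-1, h]).sum(axis=0)
    # Input grad: dy * w * r minus the correction term from d(invar)/dx.
    inner = (d_norm_output * norm_w * x).sum(axis=-1, keepdim=True)
    dx = d_norm_output * norm_w * r - x * (r**3) * inner / h
    return dx, d_rms_norm_weight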
(merge commit) …into rewrite_fp8_utils
if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0:
    padding_size = 512

@staticmethod
def kitchen_gemm(
All of the logic under USE_DS_GEMM has been deleted.
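[Editor's note] To restore the dropped branch, a sketch of an env-guarded dispatch. The flag name follows the comment above; how the real code reads the flag, and the two path helpers, are assumptions (the stubs stand in for the real kernels).

import os

# Assumed flag: the real code may read this differently.
USE_DS_GEMM = os.getenv("USE_DS_GEMM", "0").lower() in ("1", "true")


def run_deep_gemm_path(a, a_scale, b, b_scale, out=None):
    raise NotImplementedError("stand-in for the DeepSeek-style GEMM kernel")


def run_kitchen_path(a, a_scale, b, b_scale, out=None):
    raise NotImplementedError("stand-in for the kitchen kernel")


def kitchen_gemm(a, a_scale, b, b_scale, out=None):
    if USE_DS_GEMM:
        return run_deep_gemm_path(a, a_scale, b, b_scale, out=out)
    return run_kitchen_path(a, a_scale, b, b_scale, out=out)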
Compare 1fb7393 to 6f04692
Before submitting
Add test cases into the tests folder. If there are codecov issues, please add test cases first.
PR types
Others
PR changes
Others
Description
Refactored the functions in fp8_utils: reduced duplicated implementations and improved maintainability. The main changes are as follows: