Commit 8a0986e

support wgrad gemm (#10892)
1 parent ba94cdc

1 file changed: +16 -2 lines

paddlenlp/transformers/fp8_utils.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +27,13 @@ def swiglu(x, y=None):
     return F.silu(x) * y
 
 
+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
 except:
     pass
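Since USE_DS_GEMM is read once at import time, the environment variable has to be set before fp8_utils is first imported. A minimal sketch of enabling the DeepGEMM path (not part of the commit):

import os

# Set the flag before the module is imported; it is read at import time.
os.environ["USE_DS_GEMM"] = "true"

from paddlenlp.transformers import fp8_utils  # picks the standalone deep_gemm

# If the standalone deep_gemm package is missing, the bare except above
# silently leaves the name undefined, so this only checks the flag itself.
print(fp8_utils.USE_DS_GEMM)  # True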

@@ -43,6 +49,13 @@ def swiglu(x, y=None):
 def kitchen_fp8_gemm(
     x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
 ):
+    if USE_DS_GEMM:
+        if out is None:
+            out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+        if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112)
+        return out
+
     if out is not None:
         accumulate = True
         out_dtype = out.dtype
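On the new path, out is allocated with shape [x_fp8.shape[0], w_fp8.shape[0]] (the _nt layout, out = A @ B^T), the DeepGEMM wgrad kernel accumulates FP32 partials into it, and zero-sized operands skip the kernel while still returning the zero-initialized buffer. A hedged sketch of the call pattern; a_fp8/a_scale/b_fp8/b_scale are placeholders for FP8 operands and their scales, produced by whatever quantization the caller already uses:

import paddle

m, n, k = 256, 512, 1024  # out = A @ B^T gives an [m, n] result

# The kernel accumulates in FP32, so pass an FP32 buffer explicitly and
# reuse it across micro-batches to accumulate the weight gradient in place:
dw = paddle.zeros([m, n], dtype=paddle.float32)
# kitchen_fp8_gemm(a_fp8, a_scale, b_fp8, b_scale,
#                  True, True, out=dw, rtn_dtype=paddle.float32)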
@@ -1118,13 +1131,14 @@ def backward_dx(self, out_grad):
         self.out_grad = out_grad
 
         # clear status for save memory
-        self.m_indices = None
         self.unzipped_probs = None
         self.input = None
 
         # dx
         dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)
 
+        self.m_indices = None
+
         return dx, probs_grad
 
     @paddle.no_grad()
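Note the m_indices reset now happens after dx is computed, presumably because bwd_gate_up_input still consumes the routing indices while producing dx (the wgrad path included); deferring the clear keeps the memory saving without freeing state that is still in use.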
