
Commit 2d42fc6

fix
1 parent 7e4c254 commit 2d42fc6

1 file changed: 17 additions and 3 deletions


paddlenlp/transformers/fp8_utils.py

Lines changed: 17 additions & 3 deletions
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +28,13 @@ def swiglu(x, y=None):
     return F.silu(x) * y


+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
 except:
     pass

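Note that USE_DS_GEMM is read once, at module import time, so the environment variable has to be set before this module is first imported. A minimal sketch of opting in (not part of the commit; the module path is inferred from the file header above):

import os

# Must be set before the first import of fp8_utils; any value other than
# "true" (case-insensitive) keeps the default paddle.incubate.fp8 kernels.
os.environ["USE_DS_GEMM"] = "true"

import paddlenlp.transformers.fp8_utils as fp8_utils

print(fp8_utils.USE_DS_GEMM)  # True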

@@ -82,9 +89,16 @@ def padding_and_quant_input(tensor):
         return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale

     @staticmethod
-    def kitchen_gemm(
-        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True, out=None, rtn_dtype=paddle.bfloat16
+    def kitchen_fp8_gemm(
+        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
     ):
+        if USE_DS_GEMM:
+            if out is None:
+                out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+            if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112)
+            return out
+
         if out is not None:
             accumulate = True
             out_dtype = out.dtype
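The new branch zero-initializes out when the caller does not supply one and skips the kernel for empty operands, which is consistent with wgrad_gemm_fp8_fp8_fp32_nt accumulating into out in place. A behavioral sketch of that guard logic, with a plain matmul standing in for the FP8 kernel (the function name and dtypes here are illustrative, not from the commit):

import numpy
import paddle

def wgrad_accumulate_sketch(x, w, out=None, rtn_dtype=paddle.float32):
    # Mirror of the USE_DS_GEMM branch: allocate a zeroed accumulator when
    # none is supplied, since the kernel adds into `out` rather than
    # overwriting it.
    if out is None:
        out = paddle.zeros([x.shape[0], w.shape[0]], rtn_dtype)
    # A zero-sized GEMM contributes nothing, so skip the kernel entirely.
    if numpy.prod(x.shape) != 0 and numpy.prod(w.shape) != 0:
        # Stand-in for deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt: out += x @ w.T
        out += paddle.matmul(x, w, transpose_y=True).cast(rtn_dtype)
    return out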
