# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+
import numpy
import paddle
import paddle.nn.functional as F
@@ -26,8 +28,13 @@ def swiglu(x, y=None):
    return F.silu(x) * y


+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
except:
    pass

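For reference, the toggle introduced above is read once, at import time of the patched module, so it has to be set in the environment before that import happens. A minimal sketch of enabling it from Python (setting the variable in the launching shell works the same way):

import os

# "true" in any casing routes the FP8 GEMMs through the standalone deep_gemm
# package; any other value keeps the Paddle-bundled paddle.incubate.fp8 kernels.
os.environ["USE_DS_GEMM"] = "true"

assert os.getenv("USE_DS_GEMM", "False").lower() == "true"
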
@@ -82,9 +89,16 @@ def padding_and_quant_input(tensor):
        return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale

    @staticmethod
-    def kitchen_gemm(
-        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True, out=None, rtn_dtype=paddle.bfloat16
+    def kitchen_fp8_gemm(
+        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
    ):
+        if USE_DS_GEMM:
+            if out is None:
+                out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+            if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112)
+            return out
+
        if out is not None:
            accumulate = True
            out_dtype = out.dtype
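A hypothetical end-to-end call of the renamed helper, tying the two hunks together. The enclosing class name FP8LinearFunctionBase and its import path are assumptions (only the staticmethods are visible in this diff), and the tensor shapes and dtypes are illustrative. With USE_DS_GEMM=true the call dispatches to deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt and accumulates into a zero-initialized output; otherwise it falls through to the original kitchen GEMM path further down in the method.

import os

os.environ["USE_DS_GEMM"] = "true"  # must be set before the patched module is imported

import paddle

# Assumed location and class name; neither is visible in the hunks above.
from paddlenlp.transformers.fp8_utils import FP8LinearFunctionBase

x = paddle.randn([4096, 7168]).astype(paddle.bfloat16)  # illustrative shapes
w = paddle.randn([2048, 7168]).astype(paddle.bfloat16)

# Quantize both operands to FP8 plus scales (the transposed variants are unused here).
x_fp8, x_scale, _, _ = FP8LinearFunctionBase.padding_and_quant_input(x)
w_fp8, w_scale, _, _ = FP8LinearFunctionBase.padding_and_quant_input(w)

# is_a_1d_scaled / is_b_1d_scaled lost their defaults in this patch, so they are passed
# explicitly; out has shape [x_fp8.shape[0], w_fp8.shape[0]] = [4096, 2048].
out = FP8LinearFunctionBase.kitchen_fp8_gemm(
    x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True
)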