 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +27,13 @@ def swiglu(x, y=None):
     return F.silu(x) * y


+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
 except:
     pass

@@ -43,6 +49,13 @@ def swiglu(x, y=None):
 def kitchen_fp8_gemm(
     x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
 ):
+    if USE_DS_GEMM:
+        if out is None:
+            out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+        if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112)
+        return out
+
     if out is not None:
         accumulate = True
         out_dtype = out.dtype
@@ -1118,13 +1131,14 @@ def backward_dx(self, out_grad):
         self.out_grad = out_grad

         # clear status to save memory
-        self.m_indices = None
         self.unzipped_probs = None
         self.input = None

         # dx
         dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)

+        self.m_indices = None
+
         return dx, probs_grad

     @paddle.no_grad()
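A note on the USE_DS_GEMM toggle introduced above: it is evaluated once, at module import time, so the environment variable must be set before the module is first imported. A minimal sketch of the intended usage, where my_fp8_module is a hypothetical stand-in for the module this diff touches:

    import os

    # Any casing of "true" selects the standalone deep_gemm package; anything
    # else (or leaving the variable unset) keeps paddle.incubate.fp8.deep_gemm.
    os.environ["USE_DS_GEMM"] = "true"

    import my_fp8_module  # must happen after the environment variable is set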
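The deep_gemm branch in kitchen_fp8_gemm zero-initializes out and guards the kernel launch so that empty operands still return a well-formed zero buffer; together with the pre-existing path setting accumulate = True whenever a caller passes out, this suggests wgrad_gemm_fp8_fp8_fp32_nt accumulates into out rather than overwriting it. A plain-numpy sketch of that contract, using dequantized stand-ins for the FP8 operands and their scales (an assumption for illustration, not the kernel's real interface):

    import numpy

    # Dequantized stand-ins for (x_fp8, x_scale) and (w_fp8, w_scale).
    x = numpy.random.rand(4, 8).astype("float32")
    w = numpy.random.rand(6, 8).astype("float32")

    # Accumulator shaped [x.shape[0], w.shape[0]], zeroed like the branch above.
    out = numpy.zeros([x.shape[0], w.shape[0]], dtype="float32")

    if numpy.prod(x.shape) != 0 and numpy.prod(w.shape) != 0:
        # "nt" layout: the second operand is consumed transposed.
        out += x @ w.T  # each call accumulates into the same buffer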