diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 5507702fdcd8..f8b6695832f6 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -140,6 +140,8 @@ def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor: paddle.Tensor: The value of sequence auxiliary loss. """ batch_size, seq_len, _ = gates.shape + gates = gates / (gates.sum(axis=-1, keepdim=True) + 1e-20) + _, topk_idx = paddle.topk(gates, top_k, axis=-1) ce = paddle.zeros([batch_size, self.num_experts]) topk_idx = topk_idx.reshape([batch_size, -1]) ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add")