@@ -136,12 +136,12 @@ def compute_unreduced_loss(
136136 # reversed_exp = ops.flip(exp_logits, axis=1)
137137 # reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
138138 # cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
139- #cumsum_forward = ops.cumsum(exp_logits, axis=1)
140- #total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141- #cumsum_from_right = total_sum - cumsum_forward + exp_logits
142- reversed_exp = ops .flip (exp_logits , axis = 1 )
139+ # cumsum_forward = ops.cumsum(exp_logits, axis=1)
140+ # total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141+ # cumsum_from_right = total_sum - cumsum_forward + exp_logits
142+ reversed_exp = ops .flip (exp_logits , axis = [ 1 ] )
143143 reversed_cumsum = ops .cumsum (reversed_exp , axis = 1 )
144- cumsum_from_right = ops .flip (reversed_cumsum , axis = 1 )
144+ cumsum_from_right = ops .flip (reversed_cumsum , axis = [ 1 ] )
145145
146146 log_normalizers = ops .log (cumsum_from_right + self ._epsilon )
147147 log_probs = ops .subtract (sorted_logits , log_normalizers )
@@ -211,6 +211,7 @@ def call(
211211 losses = ops .multiply (losses , weights )
212212 losses = ops .squeeze (losses , axis = - 1 )
213213 return losses
214+
214215 # getting config
215216 def get_config (self ) -> dict [str , Any ]:
216217 config : dict [str , Any ] = super ().get_config ()
0 commit comments