@@ -105,15 +105,13 @@ def compute_unreduced_loss(
105105 logits_masked = ops .where (
106106 valid_mask , logits , ops .full_like (logits , - 1e9 )
107107 )
108-
109108 sorted_logits , sorted_valid_mask = sort_by_scores (
110109 tensors_to_sort = [logits_masked , valid_mask ],
111110 scores = labels_for_sorting ,
112111 mask = None ,
113112 shuffle_ties = False ,
114113 seed = None ,
115114 )
116-
117115 sorted_logits = ops .divide (
118116 sorted_logits , ops .cast (self .temperature , dtype = sorted_logits .dtype )
119117 )
@@ -139,9 +137,9 @@ def compute_unreduced_loss(
139137 # cumsum_forward = ops.cumsum(exp_logits, axis=1)
140138 # total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141139 # cumsum_from_right = total_sum - cumsum_forward + exp_logits
142- reversed_exp = ops .flip (exp_logits , axis = [ 1 ] )
140+ reversed_exp = ops .flip (exp_logits , axis = 1 )
143141 reversed_cumsum = ops .cumsum (reversed_exp , axis = 1 )
144- cumsum_from_right = ops .flip (reversed_cumsum , axis = [ 1 ] )
142+ cumsum_from_right = ops .flip (reversed_cumsum , axis = 1 )
145143
146144 log_normalizers = ops .log (cumsum_from_right + self ._epsilon )
147145 log_probs = ops .subtract (sorted_logits , log_normalizers )
0 commit comments