Skip to content

Commit 8a3c2f9

Browse files
Updated code with 'cumsum' function
1 parent bd63e54 commit 8a3c2f9

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

keras_rs/src/losses/list_mle_loss.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,12 @@ def compute_unreduced_loss(
136136
# reversed_exp = ops.flip(exp_logits, axis=1)
137137
# reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
138138
# cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
139-
cumsum_forward = ops.cumsum(exp_logits, axis=1)
140-
total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141-
cumsum_from_right = total_sum - cumsum_forward + exp_logits
139+
#cumsum_forward = ops.cumsum(exp_logits, axis=1)
140+
#total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141+
#cumsum_from_right = total_sum - cumsum_forward + exp_logits
142+
reversed_exp = ops.flip(exp_logits, axis=1)
143+
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
144+
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
142145

143146
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
144147
log_probs = ops.subtract(sorted_logits, log_normalizers)

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def setUp(self):
1616
self.unbatched_labels = ops.array(
1717
[1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32"
1818
)
19-
2019
self.batched_scores = ops.array(
2120
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]],
2221
dtype="float32",

0 commit comments

Comments
 (0)