
Commit 3a402e3

Modified code with axis
1 parent 8a3c2f9 commit 3a402e3

2 files changed: 6 additions & 6 deletions


keras_rs/src/losses/list_mle_loss.py

Lines changed: 6 additions & 5 deletions
@@ -136,12 +136,12 @@ def compute_unreduced_loss(
         # reversed_exp = ops.flip(exp_logits, axis=1)
         # reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
         # cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
-        #cumsum_forward = ops.cumsum(exp_logits, axis=1)
-        #total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
-        #cumsum_from_right = total_sum - cumsum_forward + exp_logits
-        reversed_exp = ops.flip(exp_logits, axis=1)
+        # cumsum_forward = ops.cumsum(exp_logits, axis=1)
+        # total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
+        # cumsum_from_right = total_sum - cumsum_forward + exp_logits
+        reversed_exp = ops.flip(exp_logits, axis=[1])
         reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
-        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
+        cumsum_from_right = ops.flip(reversed_cumsum, axis=[1])

         log_normalizers = ops.log(cumsum_from_right + self._epsilon)
         log_probs = ops.subtract(sorted_logits, log_normalizers)
@@ -211,6 +211,7 @@ def call(
         losses = ops.multiply(losses, weights)
         losses = ops.squeeze(losses, axis=-1)
         return losses
+
     # getting config
     def get_config(self) -> dict[str, Any]:
         config: dict[str, Any] = super().get_config()
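
Both formulations in the first hunk compute the same quantity: a right-to-left (suffix) cumulative sum of exp_logits along the list axis, which feeds the per-position log-normalizers of the ListMLE loss. A minimal sketch of that equivalence, assuming ops is keras.ops, a 2-D [batch, list_size] input, and the integer axis form; the toy values and the NumPy check are illustrative only and not part of the commit:

    import numpy as np
    from keras import ops

    # Toy input: one list of 4 items, shape [batch, list_size].
    exp_logits = ops.exp(ops.array([[2.0, 1.0, 0.5, -1.0]]))

    # Kept path: reverse, cumulative-sum, reverse back -> suffix sums.
    reversed_exp = ops.flip(exp_logits, axis=1)
    reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
    cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

    # Commented-out alternative from the same hunk:
    # total minus prefix sum plus the current element.
    cumsum_forward = ops.cumsum(exp_logits, axis=1)
    total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
    cumsum_from_right_alt = total_sum - cumsum_forward + exp_logits

    # Both give sum(exp_logits[i:]) at every position i.
    np.testing.assert_allclose(
        ops.convert_to_numpy(cumsum_from_right),
        ops.convert_to_numpy(cumsum_from_right_alt),
        rtol=1e-5,
    )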

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 from keras_rs.src import testing
 from keras_rs.src.losses.list_mle_loss import ListMLELoss

-
 class ListMLELossTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
         self.unbatched_scores = ops.array(
