Skip to content

Commit ab5da72

Browse files
Updated code
1 parent b63ec12 commit ab5da72

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

keras_rs/src/losses/list_mle_loss.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,19 @@ def compute_unreduced_loss(
127127
)
128128
sorted_logits = ops.subtract(sorted_logits, raw_max)
129129

130-
exp_logits = ops.exp(sorted_logits)
131-
exp_logits = ops.where(
132-
sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)
130+
# Set invalid positions to very negative BEFORE exp
131+
sorted_logits = ops.where(
132+
sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
133133
)
134+
exp_logits = ops.exp(sorted_logits)
134135

135136
reversed_exp = ops.flip(exp_logits, axis=1)
136137
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
137138
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
138139

139140
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
140141
log_probs = ops.subtract(sorted_logits, log_normalizers)
142+
141143

142144
log_probs = ops.where(
143145
sorted_valid_mask, log_probs, ops.zeros_like(log_probs)

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010

1111
class ListMLELossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13-
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
14-
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
13+
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32")
14+
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32")
1515

1616
self.batched_scores = ops.array(
17-
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]]
17+
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], dtype="float32"
1818
)
1919
self.batched_labels = ops.array(
20-
[[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]]
20+
[[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], dtype="float32"
2121
)
22-
self.expected_output = ops.array([6.865693, 3.088192])
22+
self.expected_output = ops.array([6.865693, 3.088192], dtype="float32")
2323

2424
def test_unbatched_input(self):
2525
loss = ListMLELoss(reduction="none")
@@ -43,7 +43,6 @@ def test_temperature(self):
4343
output_temp = loss_temp(
4444
y_true=self.batched_labels, y_pred=self.batched_scores
4545
)
46-
4746
self.assertAllClose(
4847
output_temp,
4948
[10.969891, 2.1283305],
@@ -60,7 +59,6 @@ def test_invalid_input_rank(self):
6059
def test_loss_reduction(self):
6160
loss = ListMLELoss(reduction="sum_over_batch_size")
6261
output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
63-
6462
self.assertAlmostEqual(
6563
ops.convert_to_numpy(output), 4.9769425, places=5
6664
)

0 commit comments

Comments
 (0)