Skip to content

Commit 2c3f2e6

Browse files
Handled ops.flip for torch backend
1 parent ab5da72 commit 2c3f2e6

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

keras_rs/src/losses/list_mle_loss.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,19 @@ def compute_unreduced_loss(
129129

130130
# Set invalid positions to very negative BEFORE exp
131131
sorted_logits = ops.where(
132-
sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
132+
sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
133133
)
134134
exp_logits = ops.exp(sorted_logits)
135135

136-
reversed_exp = ops.flip(exp_logits, axis=1)
137-
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
138-
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
136+
# reversed_exp = ops.flip(exp_logits, axis=1)
137+
# reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
138+
# cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
139+
cumsum_forward = ops.cumsum(exp_logits, axis=1)
140+
total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141+
cumsum_from_right = total_sum - cumsum_forward + exp_logits
139142

140143
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
141144
log_probs = ops.subtract(sorted_logits, log_normalizers)
142-
143145

144146
log_probs = ops.where(
145147
sorted_valid_mask, log_probs, ops.zeros_like(log_probs)

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,20 @@
1010

1111
class ListMLELossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13-
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32")
14-
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32")
13+
self.unbatched_scores = ops.array(
14+
[1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32"
15+
)
16+
self.unbatched_labels = ops.array(
17+
[1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32"
18+
)
1519

1620
self.batched_scores = ops.array(
17-
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], dtype="float32"
21+
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]],
22+
dtype="float32",
1823
)
1924
self.batched_labels = ops.array(
20-
[[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], dtype="float32"
25+
[[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]],
26+
dtype="float32",
2127
)
2228
self.expected_output = ops.array([6.865693, 3.088192], dtype="float32")
2329

0 commit comments

Comments
 (0)