
Commit 04fb241

Add ListMLE Loss (#130)
* Add STU layer
* Update keras_rs/src/layers/stu.py
* Update keras_rs/src/layers/common.py
* Update keras_rs/src/layers/jagged_tensors.py
* Update keras_rs/src/layers/hstu_mha_attention.py
* Update keras_rs/src/layers/stu.py
* Update keras_rs/src/layers/stu.py
* Update keras_rs/src/layers/jagged_tensors.py
* Update keras_rs/src/layers/hstu_compute_output.py
* Add ListMLE loss
* Add debug statements
* Add pytest coverage
* Add stable-offset code to label sorting
* Use the stable offset to disambiguate equal labels
* Handle label sorting with the stable offset
* Save local changes
* Make a few code updates
* Fix lint errors
* Delete the keras backend import statement

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 51d5c82 commit 04fb241

File tree

4 files changed: +331 −0 lines changed

keras_rs/api/losses/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 since your modifications would be overwritten.
 """

+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
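With this export in place, the loss is reachable from the public `keras_rs.losses` namespace (the path matches the `keras_rs_export` decorator in the implementation below). A minimal usage sketch, assuming an installed `keras-rs` and made-up data:

```python
import keras_rs

# Public import path added by this commit.
loss_fn = keras_rs.losses.ListMLELoss(temperature=1.0)

y_true = [[3.0, 2.0, 1.0, 0.0]]  # graded relevance labels
y_pred = [[0.8, 0.6, 0.4, 0.2]]  # model scores
loss = loss_fn(y_true, y_pred)   # scalar after "sum_over_batch_size" reduction
```
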
keras_rs/src/losses/list_mle_loss.py

Lines changed: 212 additions & 0 deletions

@@ -0,0 +1,212 @@
from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks


@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels).
    2. Computing the probability of observing this ranking given the
       predicted scores.
    3. Maximizing this likelihood (i.e. minimizing the negative
       log-likelihood).

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where `s_i` is the predicted score for item `i` in the sorted order.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage.
    loss_fn = ListMLELoss()

    # With temperature scaling.
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data.
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        self._epsilon = 1e-10

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of shape
                `[batch_size, list_size]`.
            logits: Predicted scores of shape `[batch_size, list_size]`.
            mask: Optional mask of shape `[batch_size, list_size]`.

        Returns:
            Tuple of `(losses, weights)`, where `losses` has shape
            `[batch_size, 1]` and `weights` has the same shape.
        """
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )

        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
        sorted_logits = ops.divide(
            sorted_logits,
            ops.cast(self.temperature, dtype=sorted_logits.dtype),
        )

        valid_logits_for_max = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, -1e9),
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = ops.subtract(sorted_logits, raw_max)

        # Set invalid positions to a very negative value *before* `exp`, so
        # that they contribute (close to) zero to the cumulative sums.
        sorted_logits = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, -1e9),
        )
        exp_logits = ops.exp(sorted_logits)

        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
        log_probs = ops.subtract(sorted_logits, log_normalizers)

        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = ops.negative(
            ops.sum(log_probs, axis=1, keepdims=True)
        )

        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or
                `(batch_size, list_size)` for batched inputs. If an item has
                a label of -1, it is ignored in loss computation. If it is a
                dictionary, it should have two keys: `"labels"` and `"mask"`.
                `"mask"` can be used to ignore elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape `[batch_size]`.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config
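To make the docstring formula concrete, here is a sanity check in plain Python (not part of the commit) that recomputes `loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))` for the docstring's synthetic example. With labels `[3, 2, 1, 0]` already in descending order, sorting by labels leaves the scores unchanged:

```python
import math

# Scores in ground-truth order (labels [3, 2, 1, 0] are already descending).
scores = [0.8, 0.6, 0.4, 0.2]

loss = 0.0
for i in range(len(scores)):
    # Probability that item i outranks every remaining item in the list.
    denom = sum(math.exp(s) for s in scores[i:])
    loss -= math.log(math.exp(scores[i]) / denom)

print(round(loss, 4))  # 2.6212 (per-list loss, before any reduction)
```

Items with a label of -1 (or excluded via the `"mask"` key) are dropped before this computation, and `temperature` divides the scores before the softmax-style normalizer, so values below 1.0 sharpen the distribution.
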
keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.losses import deserialize
from keras.losses import serialize

from keras_rs.src import testing
from keras_rs.src.losses.list_mle_loss import ListMLELoss


class ListMLELossTest(testing.TestCase, parameterized.TestCase):
    def setUp(self):
        self.unbatched_scores = ops.array(
            [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32"
        )
        self.unbatched_labels = ops.array(
            [1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32"
        )
        self.batched_scores = ops.array(
            [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]],
            dtype="float32",
        )
        self.batched_labels = ops.array(
            [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]],
            dtype="float32",
        )
        self.expected_output = ops.array([6.865693, 3.088192], dtype="float32")

    def test_unbatched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(
            y_true=self.unbatched_labels, y_pred=self.unbatched_scores
        )
        self.assertEqual(output.shape, (1,))
        self.assertTrue(ops.convert_to_numpy(output[0]) > 0)
        self.assertAllClose(output, [self.expected_output[0]], atol=1e-5)

    def test_batched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
        self.assertEqual(output.shape, (2,))
        self.assertTrue(ops.convert_to_numpy(output[0]) > 0)
        self.assertTrue(ops.convert_to_numpy(output[1]) > 0)
        self.assertAllClose(output, self.expected_output, atol=1e-5)

    def test_temperature(self):
        loss_temp = ListMLELoss(temperature=0.5, reduction="none")
        output_temp = loss_temp(
            y_true=self.batched_labels, y_pred=self.batched_scores
        )
        self.assertAllClose(
            output_temp,
            [10.969891, 2.1283305],
            atol=1e-5,
        )

    def test_invalid_input_rank(self):
        rank_3_input = ops.ones((2, 3, 4))

        loss = ListMLELoss()
        with self.assertRaises(ValueError):
            loss(y_true=rank_3_input, y_pred=rank_3_input)

    def test_loss_reduction(self):
        loss = ListMLELoss(reduction="sum_over_batch_size")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
        self.assertAlmostEqual(
            ops.convert_to_numpy(output), 4.9769425, places=5
        )

    def test_scalar_sample_weight(self):
        sample_weight = ops.array(5.0)
        loss = ListMLELoss(reduction="none")

        output = loss(
            y_true=self.batched_labels,
            y_pred=self.batched_scores,
            sample_weight=sample_weight,
        )

        self.assertAllClose(
            output, self.expected_output * sample_weight, atol=1e-5
        )

    def test_model_fit(self):
        inputs = keras.Input(shape=(20,), dtype="float32")
        outputs = keras.layers.Dense(5)(inputs)
        model = keras.Model(inputs=inputs, outputs=outputs)

        model.compile(loss=ListMLELoss(), optimizer="adam")
        model.fit(
            x=keras.random.normal((2, 20)),
            y=keras.random.randint((2, 5), minval=0, maxval=2),
        )

    def test_serialization(self):
        loss = ListMLELoss(temperature=0.8)
        restored = deserialize(serialize(loss))
        self.assertDictEqual(loss.get_config(), restored.get_config())
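As a quick cross-check of `test_loss_reduction` against the per-list values in `expected_output` (plain Python, values copied from the test above): with no sample weights, `"sum_over_batch_size"` amounts to averaging the per-list losses.

```python
expected_output = [6.865693, 3.088192]

# "sum_over_batch_size" sums the per-list losses and divides by batch size.
reduced = sum(expected_output) / len(expected_output)
print(reduced)  # 4.9769425, the assertAlmostEqual target
```
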

keras_rs/src/metrics/ranking_metrics_utils.py

Lines changed: 19 additions & 0 deletions

@@ -85,6 +85,25 @@ def sort_by_scores(
     else:
         k = ops.minimum(k, max_possible_k)

+    # --- Workaround for PyTorch instability ---
+    # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
+    # See:
+    # - https://github.com/pytorch/pytorch/issues/27542
+    # - https://github.com/pytorch/pytorch/issues/88227
+    #
+    # This small "stable offset" ensures deterministic tie-breaking for
+    # equal scores. We can remove this workaround once PyTorch adds a
+    # `stable=True` flag for topk.
+
+    if keras.backend.backend() == "torch" and not shuffle_ties:
+        list_size = ops.shape(scores)[1]
+        indices = ops.arange(list_size)
+        indices = ops.expand_dims(indices, axis=0)
+        indices = ops.broadcast_to(indices, ops.shape(scores))
+        stable_offset = ops.cast(indices, scores.dtype) * 1e-6
+        scores = ops.subtract(scores, stable_offset)
+    # --- End workaround ---
+
     # Shuffle ties randomly, and push masked values to the beginning.
     shuffled_indices = None
     if shuffle_ties or mask is not None:
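A small sketch of why the offset gives deterministic tie-breaking (NumPy stand-in, not the commit's code): subtracting a position-scaled epsilon makes earlier positions win ties under a descending sort, which is exactly what a stable sort would do.

```python
import numpy as np

# Scores with ties: without a tie-break rule, the order of the equal
# entries under topk/argsort is backend-dependent (torch's topk gives
# no stability guarantee when sorted=True).
scores = np.array([[2.0, 1.0, 2.0, 1.0]])

list_size = scores.shape[1]
stable_offset = np.arange(list_size, dtype=scores.dtype) * 1e-6
adjusted = scores - stable_offset  # [[2.0, 0.999999, 1.999998, 0.999997]]

# Descending sort now has a unique answer: index 0 before index 2,
# and index 1 before index 3.
order = np.argsort(-adjusted, axis=1)
print(order)  # [[0 2 1 3]]
```

Note the implicit assumption: meaningful score gaps stay well above `list_size * 1e-6`, so the offset only affects exact ties. Very long lists or low-precision score dtypes would need a smaller or dtype-aware offset.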
