from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks


@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels).
    2. Computing the probability of observing this ranking given the
       predicted scores.
    3. Maximizing this likelihood (minimizing the negative log-likelihood).

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where `s_i` is the predicted score for item `i` in the sorted order.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to
            `1.0`.
        reduction: Type of reduction to apply to the loss. In almost all
            cases this should be `"sum_over_batch_size"`. Supported options
            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
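
    # Labels can also be a dict carrying a mask; masked items are ignored
    # in the loss computation (illustrative values only).
    y_true = {
        "labels": [[3, 2, 1, 0]],
        "mask": [[True, True, True, False]],
    }
    loss = loss_fn(y_true, y_pred)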
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        self._epsilon = 1e-10

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of shape
                `(batch_size, list_size)`.
            logits: Predicted scores of shape `(batch_size, list_size)`.
            mask: Optional mask of shape `(batch_size, list_size)`.

        Returns:
            Tuple of `(losses, weights)`, where `losses` has shape
            `(batch_size, 1)` and `weights` has the same shape.
        """

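        # Items with negative labels (e.g. the -1 padding described in
        # `call`) are treated as invalid and excluded from the loss.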
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )

        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

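        # Sort the logits into the ground-truth order by sorting on the
        # labels in descending order, so position i holds the i-th most
        # relevant item.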
        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
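        # Scale by temperature: values above 1.0 flatten the resulting
        # distribution, values below 1.0 sharpen it.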
        sorted_logits = ops.divide(
            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
        )

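        # Subtract the per-list max before exponentiating (log-sum-exp
        # trick) to keep `exp` numerically stable.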
        valid_logits_for_max = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = ops.subtract(sorted_logits, raw_max)

        # Set invalid positions to a very negative value *before* `exp`, so
        # they contribute effectively zero to the cumulative sums below.
        sorted_logits = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        exp_logits = ops.exp(sorted_logits)

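        # flip + cumsum + flip computes suffix sums, i.e. the ListMLE
        # denominator sum(exp(s_j) for j >= i) at every position i.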
        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
        log_probs = ops.subtract(sorted_logits, log_normalizers)

        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = ops.negative(
            ops.sum(log_probs, axis=1, keepdims=True)
        )

        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

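        # Uniform weight of 1 per list; `call` multiplies these into the
        # losses.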
        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or
                `(batch_size, list_size)` for batched inputs. If an item has
                a label of -1, it is ignored in loss computation. If it is a
                dictionary, it should have two keys: `"labels"` and
                `"mask"`. `"mask"` can be used to ignore elements in loss
                computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape `(batch_size,)`.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config