1+ from typing import Any
2+
3+ import keras
4+ from keras import ops
5+
6+ from keras_rs .src import types
7+ from keras_rs .src .metrics .utils import standardize_call_inputs_ranks
8+ from keras_rs .src .api_export import keras_rs_export
9+ from keras_rs .src .metrics .ranking_metrics_utils import sort_by_scores
10+
11+
@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:

    1. Sorting items by their relevance scores (labels).
    2. Computing the probability of observing this ranking given the
       predicted scores.
    3. Maximizing this likelihood (minimizing negative log-likelihood).

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where s_i is the predicted score for item i in the sorted order.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Must be strictly
            positive. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Raises:
        ValueError: If `temperature` is not strictly positive.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    # Fill value written over the labels/logits of ignored items so that they
    # can never win a sort or a max against a valid (non-negative label) item.
    _MASK_VALUE = -1e9

    # Lower bound applied to the softmax normalizer before taking its log.
    # It is a floor, not an additive epsilon: an additive term would bias the
    # normalizer of every low-scoring valid item, whereas a floor only kicks
    # in where the sum is (numerically) zero, i.e. padded positions or
    # `exp` underflow. 1e-37 is still a normal float32 value.
    _MIN_NORMALIZER = 1e-37

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature} ."
            )

        self.temperature = temperature

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of
                shape [batch_size, list_size]. Items with a negative label
                are ignored.
            logits: Predicted scores of shape [batch_size, list_size].
            mask: Optional mask of shape [batch_size, list_size]. It is cast
                to bool; items where it is false are ignored.

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights has the same shape. Lists without any valid item get
            a loss of 0. Weights are all ones.
        """
        # An item takes part in the loss only if its label is non-negative
        # and (when a mask is given) the mask keeps it.
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )
        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, self._MASK_VALUE)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, self._MASK_VALUE)
        )

        # Reorder logits (and the validity mask alongside them) by label.
        # NOTE(review): `sort_by_scores` is expected to sort by descending
        # score, which puts ignored items last - confirm against the helper.
        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )

        sorted_logits = ops.divide(
            sorted_logits,
            ops.cast(self.temperature, dtype=sorted_logits.dtype),
        )

        # Shift by the per-list max over valid items for numerical stability.
        # The mask is re-applied after the temperature division because a
        # large temperature can shrink the sentinel of ignored items towards
        # zero, where it could otherwise exceed the valid logits.
        valid_logits_for_max = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, self._MASK_VALUE),
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        # Lists with no valid item are left unshifted.
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = sorted_logits - raw_max

        exp_logits = ops.exp(sorted_logits)
        exp_logits = ops.where(
            sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)
        )

        # Suffix sums: entry i holds sum_{j >= i} exp(s_j) over valid items,
        # obtained as a cumulative sum over the reversed list.
        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        # Floor the normalizer instead of adding an epsilon to it. For a
        # valid item the suffix sum already contains exp(s_i) > 0, so the
        # floor leaves it untouched; it only keeps `log` finite where the sum
        # is zero (ignored positions, or underflow of `exp`).
        log_normalizers = ops.log(
            ops.maximum(cumsum_from_right, self._MIN_NORMALIZER)
        )
        log_probs = sorted_logits - log_normalizers
        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].

        Raises:
            ValueError: If `y_true` is a dict without a `"labels"` key.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true} "
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        # Drop the trailing axis of size 1 left by `keepdims=True`.
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        """Return the base loss config extended with `temperature`."""
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config
0 commit comments