diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py
index e1cd3a61eb8..9e8f13bdc15 100644
--- a/test/llm/test_objectives.py
+++ b/test/llm/test_objectives.py
@@ -23,6 +23,7 @@
     GRPOLossOutput,
     MCAdvantage,
 )
+from torchrl._utils import logger
 from torchrl.objectives.llm.sft import SFTLoss
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -200,7 +201,7 @@ def test_grpo(self, mock_transformer_model, dapo):
         )
 
         # Create loss module
-        loss_fn = GRPOLoss(actor_network, eps=eps)
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)
 
         # Create fake data
         data = _mock_data_grpo(vocab_size=vocab_size, device=device)
@@ -245,6 +246,124 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
 
+    def test_kl_mask_threshold(self, mock_transformer_model):
+        """Test that kl_mask_threshold properly filters out high-KL tokens."""
+        torch.manual_seed(42)
+        vocab_size = 1024
+        device = (
+            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # First, test that the data works without any threshold
+        loss_fn_baseline = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=None
+        )
+
+        data_baseline = data.clone()
+        loss_baseline = loss_fn_baseline(data_baseline)
+        logger.info(f"Baseline loss (no threshold): {loss_baseline.loss_objective}")
+        logger.info(f"Baseline ESS: {loss_baseline.ESS}")
+
+        # Check that the baseline is valid
+        if not torch.isfinite(loss_baseline.loss_objective):
+            raise ValueError(
+                f"Baseline loss is not finite: {loss_baseline.loss_objective}"
+            )
+
+        # Now test with kl_mask_threshold enabled
+        # Use a very high threshold that should not mask any tokens
+        kl_threshold = 100.0  # Extremely high threshold to ensure no masking
+        loss_fn_with_threshold = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=kl_threshold
+        )
+
+        data_with_threshold = data.clone()
+        loss_with_threshold = loss_fn_with_threshold(data_with_threshold)
+
+        # Should produce valid output
+        assert isinstance(loss_with_threshold, GRPOLossOutput)
+
+        # Check that the loss is finite (with such a high threshold, it should be)
+        assert torch.isfinite(
+            loss_with_threshold.loss_objective
+        ), f"loss_with_threshold is not finite: {loss_with_threshold.loss_objective}"
+        assert torch.isfinite(
+            loss_with_threshold.ESS
+        ), f"ESS with threshold is not finite: {loss_with_threshold.ESS}"
+
+        logger.info(
+            f"Loss with high threshold (100.0): {loss_with_threshold.loss_objective}"
+        )
+        logger.info(f"ESS with high threshold: {loss_with_threshold.ESS}")
+
+        # The losses should be identical or very similar since nothing is masked;
+        # any difference comes only from numerical precision
+        assert torch.isclose(
+            loss_baseline.loss_objective, loss_with_threshold.loss_objective, rtol=1e-3
+        ), f"Losses differ too much with high threshold: {loss_baseline.loss_objective} vs {loss_with_threshold.loss_objective}"
+
+    def test_failure_missing_entries(self, mock_transformer_model):
+        """Test that GRPO fails when required keys are missing but works without optional keys."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=0.2)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Test 1: Missing sample_log_prob (required) should fail
+        data_missing_sample_log_prob = data.clone()
+        data_missing_sample_log_prob.exclude(("log_probs", "full"), inplace=True)
+
+        with pytest.raises(KeyError, match="Couldn't find the log-prob"):
+            loss_fn(data_missing_sample_log_prob)
+
+        # Test 2: Missing ref_log_probs (optional when kl_to_ref_coeff is None) should work
+        data_missing_ref = data.clone()
+        # Remove the ref_log_probs key if it exists
+        if ("next", "ref_log_probs", "full") in data_missing_ref.keys(True):
+            data_missing_ref.exclude(("next", "ref_log_probs", "full"), inplace=True)
+
+        # Should work fine without ref_log_probs when kl_to_ref_coeff is None
+        loss_vals = loss_fn(data_missing_ref)
+        assert isinstance(loss_vals, GRPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+
+        # Test 3: Missing ref_log_probs when kl_to_ref_coeff is set should fail
+        loss_fn_with_kl = GRPOLoss(actor_network, clip_epsilon=0.2, kl_to_ref_coeff=0.1)
+
+        data_missing_ref_for_kl = data.clone()
+        if ("next", "ref_log_probs", "full") in data_missing_ref_for_kl.keys(True):
+            data_missing_ref_for_kl.exclude(
+                ("next", "ref_log_probs", "full"), inplace=True
+            )
+
+        with pytest.raises(KeyError, match="Couldn't find the ref log-prob"):
+            loss_fn_with_kl(data_missing_ref_for_kl)
+
     def test_cispo(self, mock_transformer_model):
         """Test CISPO loss computation with mock models."""
         vocab_size = 1024
diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index e11f04509e9..8d5879033f3 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -101,6 +101,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
+            When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out of the loss,
+            where pi_ref is the policy the tokens were sampled from. This stabilizes updates by skipping tokens
+            that drifted too far from the sampling distribution, enforcing a per-token trust region.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -189,6 +193,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -208,6 +213,7 @@
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
+        self.kl_mask_threshold = kl_mask_threshold
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -382,6 +388,32 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
                 tensordict, adv_shape=advantage.shape[:-1]
             )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs the sampling policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                inference_log_prob = tensordict.get(
+                    self.tensor_keys.sample_log_prob,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                inference_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (inference_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                inference_log_prob_masked = torch.where(
+                    expand_as_right(mask, inference_log_prob), inference_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - inference_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with the attention mask
+                mask = mask & tr_mask
 
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
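
For reference, the KL-Mask rule added to GRPOLoss.forward above reduces to a few tensor ops. The following is a minimal, self-contained sketch of the same computation on toy tensors; the tensor names, shapes, and values are illustrative assumptions, not the module's actual tensordict keys.

# Standalone sketch of the per-token KL-Mask (trust-region) rule above.
# All tensors here are hypothetical toy (batch, seq) inputs; only torch is needed.
import torch

kl_mask_threshold = 0.5  # hypothetical threshold for this sketch

# Log-probs of the same tokens under the current policy and under the policy
# that generated them; False in attention_mask marks padding positions.
cur_log_prob = torch.tensor([[-0.9, -1.2, -4.0], [-0.3, -0.8, -1.1]])
inference_log_prob = torch.tensor([[-1.0, -1.1, -1.0], [-0.4, -0.7, -1.0]])
attention_mask = torch.tensor([[True, True, True], [True, True, False]])

# Zero out padding so it can never trip the threshold
cur = torch.where(attention_mask, cur_log_prob, 0.0)
old = torch.where(attention_mask, inference_log_prob, 0.0)

# 0.5 * (log pi_theta - log pi_ref)^2: per-token squared log-ratio drift
kl_token = 0.5 * (cur - old) ** 2

# Keep tokens inside the trust region and combine with the attention mask
mask = attention_mask & (kl_token <= kl_mask_threshold)
print(mask)  # tensor([[ True,  True, False], [ True,  True, False]])

At the call site the feature is opt-in: GRPOLoss(actor_network, clip_epsilon=0.2, kl_mask_threshold=0.5) reduces to the baseline loss whenever no token exceeds the threshold, which is exactly what test_kl_mask_threshold verifies with its deliberately huge threshold of 100.0.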