Commit 4391cfd

perf: Optimization for Min-p sampling implementation (#42248)
* refactor(MinPLogitsWarper): optimizing min_tokens_to_keep
* Fix(MinPLogitsWarper): edge case when min_tokens_to_keep > vocab_size
1 parent 454c0a7 commit 4391cfd

1 file changed: +6 −7 lines

src/transformers/generation/logits_process.py

```diff
@@ -748,19 +748,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Convert logits to probabilities
         probs = torch.softmax(scores, dim=-1)
         # Get the probability of the top token for each sequence in the batch
-        top_probs, _ = probs.max(dim=-1, keepdim=True)
+        top_probs = probs.amax(dim=-1, keepdim=True)
         # Calculate the actual min_p threshold by scaling min_p with the top token's probability
         scaled_min_p = self.min_p * top_probs
         # Create a mask for tokens that have a probability less than the scaled min_p
         tokens_to_remove = probs < scaled_min_p
 
-        sorted_indices = torch.argsort(scores, descending=True, dim=-1)
-        sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)
-        # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False
+        # Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed, avoids index out of range)
+        k = min(self.min_tokens_to_keep, probs.shape[-1])
+        sorted_indices = torch.topk(probs, k, dim=-1).indices
+        tokens_to_remove.scatter_(-1, sorted_indices, False)
 
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
+        scores_processed = scores.masked_fill(tokens_to_remove, self.filter_value)
         return scores_processed
```
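In short, the commit swaps `probs.max(dim=-1)` (which also materializes argmax indices) for `probs.amax(dim=-1)`, and replaces a full `argsort` over the vocabulary plus a gather/scatter round trip with a single `torch.topk` over just `min_tokens_to_keep` entries, re-admitting those tokens in place via `scatter_`. Below is a minimal, self-contained sketch of the new filtering step; the function name `min_p_filter` and the demo values (`min_p=0.1`, a vocab of 8) are illustrative choices, not part of the commit.

```python
# Standalone sketch of the updated MinPLogitsWarper filtering step.
# min_p_filter and the toy inputs below are illustrative, not from the commit.
import torch

def min_p_filter(scores: torch.Tensor, min_p: float, min_tokens_to_keep: int,
                 filter_value: float = -float("inf")) -> torch.Tensor:
    probs = torch.softmax(scores, dim=-1)
    # amax returns only the max values; max(dim=...) would also compute indices
    top_probs = probs.amax(dim=-1, keepdim=True)
    # Threshold scales with the top token's probability
    scaled_min_p = min_p * top_probs
    tokens_to_remove = probs < scaled_min_p
    # Keep at least min_tokens_to_keep tokens; clip k so topk never exceeds vocab size
    k = min(min_tokens_to_keep, probs.shape[-1])
    sorted_indices = torch.topk(probs, k, dim=-1).indices
    tokens_to_remove.scatter_(-1, sorted_indices, False)
    return scores.masked_fill(tokens_to_remove, filter_value)

scores = torch.randn(2, 8)  # (batch, vocab) toy logits
print(min_p_filter(scores, min_p=0.1, min_tokens_to_keep=2))
# Edge case from the commit message: min_tokens_to_keep > vocab_size no longer errors.
print(min_p_filter(scores, min_p=0.9, min_tokens_to_keep=100))
```

Clipping `k` to `probs.shape[-1]` is the edge-case fix from the commit message: `torch.topk` raises an error when asked for more elements than the dimension holds, whereas the old slice assignment `[..., : self.min_tokens_to_keep]` tolerated an oversized bound silently.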
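For context, this warper is exercised through the standard generation API whenever `min_p` is set with sampling enabled. A hedged usage sketch; the model name and parameter values are illustrative, not taken from the commit:

```python
# Hedged example: "gpt2" and the sampling values are illustrative choices.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# Setting min_p with do_sample=True routes logits through MinPLogitsWarper.
output_ids = model.generate(**inputs, do_sample=True, min_p=0.05, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```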