zcin commented Oct 12, 2025

📌 Description

Sampling functions such as top_k_mask_logits currently require the input tensor to be fully contiguous, but only the last dimension actually needs to be contiguous. This PR adds support for tensors that are contiguous in the last dimension but not in the batch dimension.
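To illustrate the distinction, here is a minimal pure-Python sketch, not the FlashInfer kernel itself. It models a 2-D logits view over a flat buffer, such as the kind of view a column slice like `logits_full[:, :3]` produces in PyTorch. The helper names (`is_fully_contiguous`, `is_last_dim_contiguous`, `top_k_mask_rows`) and the `row_stride` parameter are hypothetical and chosen only for this example; the tie handling is also an assumption of the sketch.

```python
import math


def is_fully_contiguous(shape, strides):
    """Row-major check: each stride equals the product of the later dimension sizes."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


def is_last_dim_contiguous(strides):
    """The weaker requirement: elements within one row are adjacent in memory."""
    return strides[-1] == 1


def top_k_mask_rows(buffer, batch_size, vocab_size, row_stride, k):
    """Keep the k largest logits in each row and set the rest to -inf.

    Row i starts at i * row_stride in the flat buffer, not at i * vocab_size,
    so padding between rows is skipped. Assumes 1 <= k <= vocab_size.
    Values tied with the k-th largest are kept (a choice made for this sketch).
    """
    masked = []
    for i in range(batch_size):
        start = i * row_stride
        row = buffer[start:start + vocab_size]
        threshold = sorted(row, reverse=True)[k - 1]
        masked.append([x if x >= threshold else -math.inf for x in row])
    return masked


# A (2, 3) view over the first three columns of a (2, 5) row-major buffer.
buffer = [0.1, 0.9, 0.5, 7.0, 7.0,
          0.4, 0.2, 0.8, 7.0, 7.0]
shape, strides = (2, 3), (5, 1)

masked = top_k_mask_rows(buffer, batch_size=2, vocab_size=3, row_stride=strides[0], k=2)
```

In this sketch the view fails the full contiguity check but passes the last-dimension check, and passing the batch stride explicitly is enough to process each row correctly without copying the data first.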

🔍 Related Issues

closes #1866

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

zcin added 7 commits October 11, 2025 17:55
Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
…tride

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
@zcin zcin marked this pull request as ready for review October 12, 2025 19:04
zcin (Author) commented Oct 12, 2025

Hi @yzh119, I added support for top_k_mask_logits and added test cases for it. If the approach looks good, I can also make the same changes to other relevant sampling kernels?

yzh119 (Collaborator) commented Oct 13, 2025

> Hi @yzh119, I added support for top_k_mask_logits and added test cases for it. If the approach looks good, I can also make the same changes to other relevant sampling kernels?

Yes, it should be applicable to all of these kernels.

Development

Successfully merging this pull request may close these issues.

RuntimeError: logits must be contiguous in flashinfer/sampling.py:375
