[ROCm][BugFix] Remove the usage of device_info from aiter
#28383
Changes from 1 commit
94b490f
6455831
2d61722
747f602
06aea92
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -31,16 +31,12 @@` |
| | | ` if current_platform.is_rocm():` |
| | | `     import aiter` |
| | | `-    from aiter.ops.triton.utils.device_info import get_num_sms` |
| | | `     from vllm.triton_utils import tl, triton` |
| | | `     def block_size(x, head_dim):` |
| | | `         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))` |
| | | `-    def num_programs(head_dim):` |
| | | `-        return min(head_dim, get_num_sms())` |
| | | `     @triton.jit` |
| | | `     def cp_mha_gather_cache_kernel(` |
| | | `         key_cache_ptr,  # [num_blocks, page_size, num_head, head_size]` |
| | | `@@ -143,7 +139,7 @@ def cp_mha_gather_cache(` |
| | | `         page_size = key_cache.shape[1]` |
| | | `         num_heads = key_cache.shape[2]` |
| | | `-        NUM_PRGMS = num_programs(total_tokens)` |
| | | `+        NUM_PRGMS = total_tokens` |
| | | `         BLOCK_SIZE = block_size(key_cache, head_dim)` |
| | | `         grid = lambda meta: (NUM_PRGMS,)` |
| | | `         cp_mha_gather_cache_kernel[grid](` |
While removing the dependency on `aiter.ops.triton.utils.device_info.get_num_sms` fixes the `ModuleNotFoundError`, changing `NUM_PRGMS` to `total_tokens` could lead to a significant performance regression. The original logic, `min(total_tokens, get_num_sms())`, capped the number of Triton programs at the number of streaming multiprocessors (SMs) or compute units (CUs) to optimize execution. By setting `NUM_PRGMS = total_tokens`, you might be launching an excessive number of programs (e.g., one per token), which can be inefficient.

A better approach would be to use vLLM's platform abstraction to get the number of compute units. You can replace `get_num_sms()` with `current_platform.get_cu_count()` to preserve the optimization.
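The capping logic discussed in this comment can be sketched in plain Python. This is only an illustration, not the code merged in this PR: the `get_cu_count` parameter is a hypothetical stand-in for `current_platform.get_cu_count`, and the value 304 is a made-up compute unit count chosen so the example runs without a GPU.

```python
from typing import Callable


def num_programs(total_tokens: int, get_cu_count: Callable[[], int]) -> int:
    """Cap the launch grid at the device's compute unit count.

    `get_cu_count` is injected so the sketch runs anywhere; in vLLM on
    ROCm it would be `current_platform.get_cu_count` (per the suggestion).
    """
    return min(total_tokens, get_cu_count())


def fake_cu_count() -> int:
    # Hypothetical device with 304 compute units (assumption for the demo).
    return 304


# Few tokens: one program per token.
small_grid = num_programs(16, fake_cu_count)
# Many tokens: the grid is capped at the compute unit count.
large_grid = num_programs(100_000, fake_cu_count)
print(small_grid, large_grid)
```

With the uncapped variant from the first commit (`NUM_PRGMS = total_tokens`), the second case would launch 100,000 programs instead of 304.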
@ganyi1996ppo
Does this advice help? If it doesn't, overall it looks good to me.

It seems Gemini's suggestion is correct. I have double-checked `get_num_sms()` from AITER and vLLM's `get_cu_count()`; they are the same.
vLLM:
vllm/vllm/platforms/rocm.py
Line 454 in d0e186c
and AITER:
https://github.com/ROCm/aiter/blob/de14bec0ca5a9de94e10f5cad4dc1541ac558689/aiter/ops/triton/utils/device_info.py#L4-L9
Thanks for the comments, that's better indeed!
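For a capped grid like the one agreed on in this thread to stay correct, each program has to process more than one token. The kernel body is not shown in this diff, so the following is a hedged sketch that assumes a strided loop over tokens; `tokens_for_program` is a hypothetical helper that emulates this distribution on the host.

```python
def tokens_for_program(pid: int, num_prgms: int, total_tokens: int) -> list[int]:
    """Token indices that program `pid` would visit with a strided loop.

    Emulates `for token in range(pid, total_tokens, num_prgms)`; this loop
    shape is an assumption, not taken from the kernel in this PR.
    """
    return list(range(pid, total_tokens, num_prgms))


total_tokens = 10
num_prgms = 4  # pretend the device only has 4 compute units

assignment = [
    tokens_for_program(pid, num_prgms, total_tokens) for pid in range(num_prgms)
]
for pid, tokens in enumerate(assignment):
    print(f"program {pid}: tokens {tokens}")

# Every token is covered exactly once across all programs.
covered = sorted(t for tokens in assignment for t in tokens)
```

Under this assumption, capping `NUM_PRGMS` changes how many tokens each program handles, not which tokens get processed.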