[None][feat] Add SM-level disaggregation support #9020
base: main
Conversation
📝 Walkthrough

This PR introduces SM-disaggregated context processing by adding configuration structures, a separate context model engine, disaggregated schedulers, and context-aware request filtering to enable independent execution of the context and generation phases across multiple engines.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Executor as PyExecutor
    participant ReqQueue as RequestQueue
    participant MainSched as Scheduler
    participant SmSched as SmDisaggCtxScheduler
    participant CtxEngine as CtxModelEngine
    participant MainEngine as MainModelEngine
    Client->>Executor: enqueue_requests(requests)
    Executor->>ReqQueue: fetch_new_requests(num_active_requests_on_engine)
    alt SM Disaggregation Enabled
        ReqQueue->>ReqQueue: _fetch_new_requests_sm_disagg()
        ReqQueue->>MainSched: schedule_request()
        MainSched->>SmSched: filter & process via SmDisaggCtxScheduler
        SmSched->>SmSched: split context vs generation
        Note over SmSched: Context phase
        SmSched->>CtxEngine: forward context requests
        Note over SmSched: Generation phase
        SmSched->>MainEngine: forward generation requests
        SmSched->>ReqQueue: return merged results
    else Standard Path
        ReqQueue->>MainSched: schedule_request()
        MainSched->>MainEngine: forward all requests
    end
    ReqQueue-->>Executor: scheduled requests
    Executor-->>Client: await_responses()
```
```mermaid
sequenceDiagram
    participant Warmup
    participant Engine1 as ModelEngine
    participant Engine2 as CtxModelEngine
    participant Engine3 as DraftEngine
    participant Prof as Profiler
    Warmup->>Engine1: set_is_warmup(True)
    Warmup->>Engine2: set_is_warmup(True)
    Warmup->>Engine3: set_is_warmup(True)
    rect rgb(200, 220, 240)
        Note over Prof: Profile Main Engine
        Warmup->>Engine1: forward_step()
        Engine1->>Prof: record metrics
    end
    rect rgb(220, 240, 200)
        Note over Prof: Profile Context Engine
        Warmup->>Engine2: forward_step()
        Engine2->>Prof: record metrics
    end
    rect rgb(240, 220, 200)
        Note over Prof: Profile Draft Engine
        Warmup->>Engine3: forward_step()
        Engine3->>Prof: record metrics
    end
    Warmup->>Engine1: set_is_warmup(False)
    Warmup->>Engine2: set_is_warmup(False)
    Warmup->>Engine3: set_is_warmup(False)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025). All source files must start with the NVIDIA Apache-2.0 header per coding guidelines.

Apply at top of file:

```diff
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
```

tensorrt_llm/llmapi/llm_args.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025). Please prepend the standard NVIDIA Apache-2.0 header (same diff as above).

2843-2848: Message nit: match the check. The validator allows 0; the error message says "greater than 0". Suggest ">= 0".

```diff
-        if self.batch_wait_timeout_ms < 0:
-            raise ValueError("batch_wait_timeout_ms must be greater than 0")
+        if self.batch_wait_timeout_ms < 0:
+            raise ValueError("batch_wait_timeout_ms must be >= 0")
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025). Please prepend the standard header (same diff as above).
🧹 Nitpick comments (6)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
53-55: SM-disagg wiring looks fine; document semantics. Constructor and the self.is_sm_disagg flag are clear. Please document in the class docstring what "SM-disaggregation mode" implies for fetch capacity and scheduling so future readers don't confuse it with network disaggregation.
Also applies to: 64-64
tensorrt_llm/llmapi/llm_args.py (2)
318-335: SmDisaggConfig: clarify bounds and behavior in docstring. Suggest noting the valid range for context_sm_percent and that non-positive ctx limits inherit generation limits, as sketched below. No functional change needed.
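A minimal illustration of how those semantics could be documented on the model itself (the field names come from this PR; the pydantic defaults and Field constraints shown here are assumptions, not the actual definition):

```python
from pydantic import BaseModel, Field


class SmDisaggConfig(BaseModel):
    """Configuration for SM-level disaggregation.

    Attributes:
        context_sm_percent: Fraction of SMs reserved for the context phase,
            expected to lie in the open interval (0, 1).
        context_max_num_tokens: Token budget for the context engine; values
            <= 0 inherit the engine-wide max_num_tokens.
        context_max_batch_size: Batch-size cap for the context engine; values
            <= 0 inherit the engine-wide max_batch_size.
    """
    context_sm_percent: float = Field(default=0.5, gt=0.0, lt=1.0)
    context_max_num_tokens: int = 0
    context_max_batch_size: int = 0
```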
2888-2900: Support YAML ingestion for sm_disagg_config. Add SmDisaggConfig to field_mapping so dicts in extra_llm_api_options are parsed consistently (mirrors other nested configs).

```diff
 field_mapping = {
     "quant_config": QuantConfig,
     "calib_config": CalibConfig,
     "build_config": BuildConfig,
     "decoding_config": DecodingConfig,
     "enable_build_cache": BuildCacheConfig,
     "speculative_config": DecodingBaseConfig,
     "lora_config": LoraConfig,
     "moe_config": MoeConfig,
     "attention_dp_config": AttentionDpConfig,
     "sparse_attention_config": BaseSparseAttentionConfig,
+    "sm_disagg_config": SmDisaggConfig,
 }
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
1363-1371: Operator precedence: parenthesize 'and' for clarity (RUF021). Make the new SM-disagg condition explicit.

```diff
-        if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-                or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+        if (next_draft_tokens_device is None or request.is_dummy
+                or request.py_batch_idx is None
+                or (self.sm_disagg_enabled
+                    and request.max_num_generated_tokens == 0)):
```

1470-1472: Operator precedence: parenthesize 'and' for clarity (RUF021). Same as above in the generation path.

```diff
-        if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-                or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+        if (new_tokens_device is None or request.is_dummy
+                or request.py_batch_idx is None
+                or (self.sm_disagg_enabled
+                    and request.max_num_generated_tokens == 0)):
```

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
352-363: Consider extracting validation logic to a helper function. The validation checks are correct, but maintainability would improve by extracting them to a helper such as _validate_sm_disagg_config(llm_args); see the sketch below. This would make the main flow cleaner and the validation logic more testable.
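A rough sketch of that helper, based on the checks visible later in the py_executor_creator.py diff (the function name, the explicit scheduler_config parameter, and the exact messages are illustrative):

```python
def _validate_sm_disagg_config(llm_args, scheduler_config) -> None:
    """Reject configurations that SM-level disaggregation does not support yet."""
    if llm_args.sm_disagg_config is None:
        return
    if llm_args.cache_transceiver_config is not None:
        raise ValueError(
            "SM-level disaggregation is not compatible with disaggregated serving.")
    if llm_args.parallel_config.world_size > 1:
        raise NotImplementedError(
            "SM-level disaggregation is not supported with parallelism.")
    if (scheduler_config.capacity_scheduler_policy !=
            CapacitySchedulerPolicy.GUARANTEED_NO_EVICT):
        raise NotImplementedError(
            "SM-level disaggregation is only supported with the "
            "guaranteed-no-evict scheduler policy.")
```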
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
tensorrt_llm/_torch/pyexecutor/_util.py (5 hunks), tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (3 hunks), tensorrt_llm/_torch/pyexecutor/llm_request.py (1 hunk), tensorrt_llm/_torch/pyexecutor/model_engine.py (6 hunks), tensorrt_llm/_torch/pyexecutor/model_loader.py (3 hunks), tensorrt_llm/_torch/pyexecutor/py_executor.py (24 hunks), tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (7 hunks), tensorrt_llm/_torch/pyexecutor/scheduler.py (2 hunks), tensorrt_llm/_torch/virtual_memory.py (1 hunk), tensorrt_llm/llmapi/__init__.py (2 hunks), tensorrt_llm/llmapi/llm_args.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧠 Learnings (9)
📚 Learning: 2025-07-22T09:22:14.726Z
Learnt from: yechank-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/pyexecutor/llm_request.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-28T10:22:02.288Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:1191-1197
Timestamp: 2025-08-28T10:22:02.288Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the object identity comparison `softmax_req_indices is not group_req_indices_cuda` on line ~1191 is intentional and used as an optimization to determine whether to reuse an existing indexer or create a new one, based on which code path was taken during tensor assignment.
Applied to files:
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
📚 Learning: 2025-08-26T06:07:02.166Z
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-09-03T13:16:06.824Z
Learnt from: nvpohanh
Repo: NVIDIA/TensorRT-LLM PR: 7478
File: tensorrt_llm/_torch/models/modeling_llama.py:1315-1315
Timestamp: 2025-09-03T13:16:06.824Z
Learning: The Llama4VisionEncoder.load_weights method signature is `def load_weights(self, weights: Dict)` and should not be confused with Llama4ForConditionalGeneration.load_weights which has a different signature including weight_mapper parameter.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_loader.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/pyexecutor/_util.py (4)
  tensorrt_llm/llmapi/llm_args.py (1): SmDisaggConfig (318-338)
  tensorrt_llm/_torch/pyexecutor/scheduler.py (4): SimpleScheduler (198-218), SmDisaggCtxScheduler (221-243), BindCapacityScheduler (72-99), BindMicroBatchScheduler (171-195)
  tensorrt_llm/_torch/pyexecutor/model_engine.py (1): PyTorchModelEngine (128-2549)
  tensorrt_llm/mapping.py (1): has_pp (254-255)

tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
  tensorrt_llm/_torch/pyexecutor/llm_request.py (2): LlmRequest (423-643), get_context_requests (802-803)
  tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py (1): SchedulerOutput (76-81)

tensorrt_llm/llmapi/__init__.py (1)
  tensorrt_llm/llmapi/llm_args.py (2): SmDisaggConfig (318-338), TorchLlmArgs (2427-2880)

tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
  tensorrt_llm/_torch/pyexecutor/llm_request.py (5): get_context_requests (802-803), get_draft_token_length (788-799), get_generation_requests (806-807), get (129-141), LlmRequest (423-643)
  tensorrt_llm/_torch/pyexecutor/scheduler.py (7): RequestScheduler (44-55), schedule_request (47-55), schedule_request (61-69), schedule_request (95-99), schedule_request (112-153), schedule_request (206-218), schedule_request (229-243)
  tensorrt_llm/_torch/pyexecutor/model_engine.py (3): ModelEngine (69-92), forward (76-84), forward (2286-2387)
  tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1): fetch_new_requests (337-347)
  tensorrt_llm/_torch/pyexecutor/resource_manager.py (4): prepare_resources (81-82), prepare_resources (407-447), prepare_resources (1307-1310), prepare_resources (1432-1448)

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
  tensorrt_llm/_torch/pyexecutor/llm_request.py (1): LlmRequest (423-643)

tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
  examples/models/core/enc_dec/convert_checkpoint.py (1): state_dict (1629-1630)

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1): drafting_loop_wrapper (394-400)
  tensorrt_llm/_torch/pyexecutor/llm_request.py (1): is_dummy (611-612)

tensorrt_llm/llmapi/llm_args.py (4)
  tensorrt_llm/builder.py (1): default (45-50)
  tensorrt_llm/models/modeling_utils.py (3): from_dict (253-263), from_dict (325-334), from_dict (487-492)
  tests/unittest/api_stability/api_stability_core.py (3): from_dict (116-123), from_dict (172-178), from_dict (319-328)
  tensorrt_llm/mapping.py (1): from_dict (314-315)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)
  tensorrt_llm/_torch/virtual_memory.py (1): ExecutorMemoryType (70-82)
  tensorrt_llm/llmapi/llm_args.py (4): parallel_config (1775-1776), world_size (372-373), world_size (382-386), CapacitySchedulerPolicy (1077-1083)
  tensorrt_llm/_torch/pyexecutor/model_engine.py (1): PyTorchModelEngine (128-2549)
  tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): attn_metadata (131-132)
🪛 Ruff (0.14.3)
tensorrt_llm/_torch/pyexecutor/model_engine.py
1365-1365: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear. Parenthesize the `and` subexpression. (RUF021)

1471-1471: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear. Parenthesize the `and` subexpression. (RUF021)
tensorrt_llm/llmapi/llm_args.py
2833-2835: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
354-356: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
205-206: Weight sharing passthrough: LGTM. Passing weight_sharing_model into ModelLoader aligns with shared-weights engines.
154-157: Review comment references incorrect line numbers; underlying concern is unfounded. The review points to lines 154-157 and 175, which do not contain `max_num_generated_tokens` access. Lines 154-157 initialize SM disagg config values, and line 175 checks `sm_disagg_enabled` (a boolean flag). The actual `max_num_generated_tokens` access occurs at lines 1365 and 1471, where it is accessed directly without `getattr()`.

Evidence shows the attribute is guaranteed to exist:
- Accessed directly throughout the codebase (lines 1365, 1471, plus test cases at 367, 379) without errors or defensive checks
- Inherited from C++ bindings (`tensorrt_llm.bindings.internal.batch_manager.LlmRequest`) where it is always initialized
- Tests confirm consistent availability

The existing code pattern validates that no `getattr()` wrapper is needed.

Likely an incorrect or invalid review comment.
tensorrt_llm/llmapi/llm_args.py (1)
2464-2468: No action needed—SmDisaggConfig is already properly exported. The verification confirms SmDisaggConfig is imported on line 17 and added to the `__all__` export list on line 65, making it available for downstream code to import directly from tensorrt_llm/llmapi.

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)

338-348: All call sites verified—no issues found. The single call site at `py_executor.py:1819` correctly passes both required parameters (`activate_requests` and `num_active_requests_on_engine`) with matching types and order. No runtime errors will occur from missing parameters.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)
81-104: LGTM! Memory monitoring support for the context model is well-integrated. The additions for `MODEL_ENGINE_CTX` follow the existing patterns and provide helpful tuning guidance consistent with the main model engine.

474-476: LGTM! Safe iteration with proper None checks. The loop correctly checks for `None` before accessing engine attributes.

666-667: LGTM! Consistent propagation of SM-disagg parameters. Both calls to `create_py_executor_instance` correctly pass `ctx_model_engine` and `sm_disagg_config`. Also applies to: 725-726

696-702: LGTM! Proper resource cleanup with safe guards. The loop correctly handles `None` engines and checks both `attn_metadata` and `cuda_graph_runner.enabled` before releasing resources.

tensorrt_llm/_torch/pyexecutor/py_executor.py (8)
50-51: LGTM! Clean imports and well-documented configuration. The new imports and environment variable follow existing conventions and are clearly documented. Also applies to: 61-63

124-125: LGTM! Proper initialization with correct synchronization primitives. The new parameters, fields, and synchronization primitives (lock and condition variables) are correctly initialized. Using the same lock (`sm_disagg_lock`) for both condition variables is the right approach for coordinating the two executor loops. Also applies to: 164-168, 186-187, 217-219

254-260: LGTM! Consistent multi-engine handling with proper guards. The warmup, property setter, and cleanup operations correctly iterate over all engines with appropriate `None` checks. Also applies to: 373-378, 473-474

537-647: LGTM! Well-designed profiler enhancements for multi-engine support. The `_profiler` context manager is properly extended to support:
- Multiple engines via the `model_engine` parameter with a safe default
- Per-stream profiling via the `stream` parameter
- Phase identification via `phase_name` for clearer logs
- A profiling toggle via the `enable_profiler` parameter

The changes are backward compatible and follow the existing patterns.
1513-1694: Verify synchronization correctness under edge cases. The synchronization between the context and generation loops uses condition variables correctly for the common case:
- Context waits on `gen_request_cv` when it has pending work but can't schedule (line 1555)
- Generation waits on `ctx_request_cv` when it has no work (line 1641)
- Each notifies the other after processing

However, verify behavior in these edge cases:

Startup: When both loops start with no scheduled work, does generation block indefinitely on `ctx_request_cv`? The context loop will continue and eventually fetch requests, but there's a window where generation is blocked.

Resource starvation: If the context loop can't schedule due to resource constraints (line 1554 condition is true), and generation is also blocked, ensure one will eventually make progress.

Consider adding a timeout to the condition variable waits or additional logging to help diagnose potential hangs in production.

To verify the synchronization logic is sound, you could add assertion checks or logging:

```python
# In context loop before waiting
if scheduled_batch.batch_size == 0 and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0):
    logger.debug(f"Context loop waiting for generation to finish. Pending context requests: {len(ctx_requests)}")
    self.gen_request_cv.wait()

# In generation loop before waiting
if scheduled_batch.batch_size == 0:
    logger.debug(f"Generation loop waiting for context to finish. Active requests: {len(self.active_requests)}")
    self.ctx_request_cv.wait()
```
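If the timeout route is taken, a minimal sketch of what a bounded wait could look like (the 1-second value and the log message are placeholders, not part of the PR):

```python
# Bound the wait so a missed notify degrades into a periodic re-check
# instead of an indefinite hang; Condition.wait() returns False on timeout.
with self.gen_request_cv:
    notified = self.gen_request_cv.wait(timeout=1.0)
if not notified:
    logger.debug("Context loop wait timed out; re-checking for schedulable work.")
```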
1814-1820: LGTM! Correct capacity calculation for SM-disaggregation. When `ctx_model_engine` is present, the calculation correctly counts only context requests since that's what's relevant for the context engine's capacity. This aligns with the SM-disaggregation design where context and generation phases have separate capacity limits.

1914-1922: LGTM! Flexible scheduling interface with safe defaults. The optional `scheduler` and `active_requests` parameters enable disaggregated scheduling while maintaining backward compatibility through sensible defaults.

2154-2171: LGTM! Flexible forward step supporting multiple engines. The optional `model_engine` parameter enables forwarding through either the main or context engine while preserving backward compatibility. The implementation correctly defaults to `self.model_engine` and uses the provided engine consistently throughout.
```python
        is_sm_disagg_ctx_phase: bool = False,
        is_draft_model: bool = False,
        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                 torch.nn.Module]] = None,
        model: Optional[torch.nn.Module] = None,
        weight_sharing_model: Optional[torch.nn.Module] = None,
    ):
```
🛠️ Refactor suggestion | 🟠 Major
Constructor additions look good; assert config presence when ctx-phase is enabled.
Prevent AttributeError if is_sm_disagg_ctx_phase=True but sm_disagg_config is None.
```diff
         spec_config: Optional["DecodingBaseConfig"] = None,
         is_sm_disagg_ctx_phase: bool = False,
         is_draft_model: bool = False,
 @@
         ) = llm_args.get_runtime_sizes()
-        if is_sm_disagg_ctx_phase:
-            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
-            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
+        if is_sm_disagg_ctx_phase:
+            if llm_args.sm_disagg_config is None:
+                raise ValueError(
+                    "is_sm_disagg_ctx_phase=True requires sm_disagg_config"
+                )
+            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 139 to 145, when
is_sm_disagg_ctx_phase=True the constructor may later access sm_disagg_config
and raise AttributeError if it's None; add an explicit check at construction
start that if is_sm_disagg_ctx_phase is True then sm_disagg_config is not None,
and raise a clear ValueError or use assert with a descriptive message indicating
sm_disagg_config is required for SM disaggregation context phase so callers get
an immediate, informative failure.
```python
        if self.weight_sharing_model is not None:
            model.load_state_dict(self.weight_sharing_model.state_dict(),
                                  assign=True)
            # Free up duplicate model weights allocated before weight sharing
            torch.cuda.empty_cache()
```
Keep shared weights on-device when assigning
state_dict() without keep_vars=True produces detached CPU tensors. With assign=True, those CPU tensors replace the module’s CUDA parameters, so this branch forces the newly built engine to run with CPU weights and immediately triggers device-mismatch failures instead of sharing memory. Please grab the on-device Parameter objects before assigning.
```diff
-            model.load_state_dict(self.weight_sharing_model.state_dict(),
-                                  assign=True)
+            shared_state = self.weight_sharing_model.state_dict(
+                keep_vars=True)
+            model.load_state_dict(shared_state, assign=True)
```
+ model.load_state_dict(shared_state, assign=True)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
         if self.weight_sharing_model is not None:
-            model.load_state_dict(self.weight_sharing_model.state_dict(),
-                                  assign=True)
+            shared_state = self.weight_sharing_model.state_dict(
+                keep_vars=True)
+            model.load_state_dict(shared_state, assign=True)
             # Free up duplicate model weights allocated before weight sharing
             torch.cuda.empty_cache()
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_loader.py around lines 312 to 316, the
code calls self.weight_sharing_model.state_dict() which returns detached CPU
tensors and then uses assign=True, causing CPU tensors to replace CUDA
parameters; instead obtain the on-device Parameter objects by calling
state_dict(keep_vars=True) (or otherwise capture the
weight_sharing_model.parameters()/buffers as Variables on their current device)
and pass that mapping into model.load_state_dict(..., assign=True); ensure any
torch.cuda.empty_cache() call happens after assignment if needed.
```python
    if llm_args.sm_disagg_config is not None:
        if llm_args.cache_transceiver_config is not None:
            raise ValueError(
                "SM-level disaggregation is not compatible with disaggregated serving."
            )
        if llm_args.parallel_config.world_size > 1:
            raise NotImplementedError(
                "SM-level disaggregation is not supported with parallelism.")
        if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
            raise NotImplementedError(
                "SM-level disaggregation is only supported with guaranteed no evict scheduler policy."
            )

        with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX,
                              RestoreMode.PINNED):
            ctx_llm_args = copy.copy(llm_args)
            ctx_llm_args.cuda_graph_config = None
            ctx_model_engine = PyTorchModelEngine(
                model_path=checkpoint_dir,
                llm_args=ctx_llm_args,
                mapping=mapping,
                attn_runtime_features=attn_runtime_features,
                dist=dist,
                spec_config=spec_config,
                weight_sharing_model=model_engine.model,
            )
    else:
        ctx_model_engine = None
```
Critical: Missing is_sm_disagg_ctx_phase=True flag and potential config sharing issue.
Two critical issues:
- Missing phase flag: Lines 369-376 should pass `is_sm_disagg_ctx_phase=True` to `PyTorchModelEngine`. According to `model_engine.py` (lines 127-210 in snippets), this flag controls whether to use `sm_disagg_config.context_max_num_tokens` and `sm_disagg_config.context_max_batch_size`. Without it, the context model will use the wrong capacity limits.
- Shallow copy risk: Line 368 uses `copy.copy(llm_args)`, which creates a shallow copy. If `llm_args` contains mutable objects (e.g., nested configs), modifications to those objects will affect both `llm_args` and `ctx_llm_args`. While setting `cuda_graph_config = None` on line 369 is safe if it replaces a reference, other mutable fields could be problematic.
ctx_llm_args = copy.copy(llm_args)
ctx_llm_args.cuda_graph_config = None
ctx_model_engine = PyTorchModelEngine(
model_path=checkpoint_dir,
llm_args=ctx_llm_args,
mapping=mapping,
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
+ is_sm_disagg_ctx_phase=True,
weight_sharing_model=model_engine.model,
    )

Consider using copy.deepcopy(llm_args) instead of copy.copy(llm_args) at line 368 to avoid unintended config sharing between the main and context engines.
🧰 Tools
🪛 Ruff (0.14.3)
354-356: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py around lines 352 to 380:
the context model engine is created without the is_sm_disagg_ctx_phase=True flag
and uses a shallow copy of llm_args; change copy.copy(llm_args) to
copy.deepcopy(llm_args) to avoid shared mutable state, set
ctx_llm_args.cuda_graph_config = None as before, and pass
is_sm_disagg_ctx_phase=True into the PyTorchModelEngine constructor so the
context engine uses sm_disagg_config context limits.
Thanks for catching it! I added the missing phase flag.
@qiangxu1996, great! Thanks for adding the is_sm_disagg_ctx_phase=True flag.
Just to follow up—did you also have a chance to address the shallow copy concern? The original comment suggested using copy.deepcopy(llm_args) instead of copy.copy(llm_args) at line 368 to avoid potential issues with shared mutable state between the main and context engines.
If the shallow copy is intentional or safe in this case, that's perfectly fine—just wanted to make sure we've covered both points!
🧠 Learnings used
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
```python
    def split_device_green_ctx(self):
        device = torch.device("cuda", self.device_id)
        device_properties = torch.cuda.get_device_properties(device)
        sm_count = device_properties.multi_processor_count
        if device_properties.major >= 9:
            sm_min = 8
            sm_align = 8
        else:
            sm_min = 4 if device_properties.major == 8 else 2
            sm_align = 2

        from flashinfer import green_ctx

        def split_device_green_ctx_aligned(sm_s1):
            sm_s1 = round(sm_s1 / sm_align) * sm_align
            sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min)
            return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1])

        sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent)
        sm_gen = sm_count - sm_ctx
        # Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align
        sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align))
        sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align))
        if sm_gen_dist < sm_ctx_dist:
            (stream_gen,
             stream_ctx), (res_gen,
                           res_ctx) = split_device_green_ctx_aligned(sm_gen)
        else:
            (stream_ctx,
             stream_gen), (res_ctx,
                           res_gen) = split_device_green_ctx_aligned(sm_ctx)
        logger.info(
            f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase."
        )
        return stream_ctx, stream_gen
```
🛠️ Refactor suggestion | 🟠 Major
Move flashinfer import to module level for consistency and early failure detection.
Line 1718 imports flashinfer.green_ctx inside the split_device_green_ctx method. This is problematic because:
- Late failure: If flashinfer is not available or incompatible, the error won't surface until this method is called during execution, making debugging harder.
- Inconsistent style: Other imports in this file are at the module level.
- Repeated import overhead: Although Python caches imports, it's better to import once at module level.
Move the import to the top of the file:
```diff
 from tensorrt_llm._torch.pyexecutor.resource_manager import (
     ResourceManagerType, request_context)
+try:
+    from flashinfer import green_ctx
+except ImportError:
+    green_ctx = None
 from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
                                  mpi_disabled, nvtx_range, trace_func)
```

Then add a validation check in __init__ or at the start of split_device_green_ctx:
```python
    def split_device_green_ctx(self):
        if green_ctx is None:
            raise ImportError(
                "flashinfer is required for SM-level disaggregation but is not available")
        device = torch.device("cuda", self.device_id)
        # ... rest of the method
```

🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1707 to 1741: the
function currently imports flashinfer.green_ctx inside split_device_green_ctx
causing late failures and inconsistent style; move "from flashinfer import
green_ctx" to the module-level imports at top of the file, then add a validation
(in __init__ or at the start of split_device_green_ctx) that raises ImportError
with a clear message if green_ctx is None or missing so callers fail fast;
ensure no other behavior changes and remove the local import from this method.
FlashInfer is no longer a dependency after refactor.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for flashinfer imports in the py_executor.py file
rg -n "flashinfer" tensorrt_llm/_torch/pyexecutor/py_executor.py
# Search for green_ctx_split_percent function definition
rg -n "def green_ctx_split_percent" --type=py
# Search for green_ctx_split_percent usage
rg -n "green_ctx_split_percent" tensorrt_llm/_torch/pyexecutor/py_executor.py
Length of output: 249
@qiangxu1996, thank you for the clarification! I can confirm that the refactor has indeed resolved the concern. The flashinfer import is no longer in py_executor.py — it's now properly encapsulated in the green_ctx module, and py_executor.py imports green_ctx_split_percent from .green_ctx at line 44. This is a cleaner design that addresses the original review comment about import location and early failure detection. ✅
```python
    @model_validator(mode='after')
    def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
        """Validate SM-level disaggregation configuration."""
        if self.sm_disagg_config is None:
            return self

        config = self.sm_disagg_config
        if not 0 < config.context_sm_percent < 1:
            raise ValueError(
                "sm_disagg_config.context_sm_percent must be in the range (0, 1)"
            )
        if config.context_max_num_tokens <= 0:
            config.context_max_num_tokens = self.max_num_tokens
        if config.context_max_batch_size <= 0:
            config.context_max_batch_size = self.max_batch_size
        return self
```
Allow 1.0 for context_sm_percent and clamp ctx limits to runtime maxima.
- The current check excludes 1.0; allowing full SM allocation to context is reasonable.
- Also clamp context_max_num_tokens/batch_size to runtime maxima to avoid illegal overprovisioning.
Apply:
def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
"""Validate SM-level disaggregation configuration."""
if self.sm_disagg_config is None:
return self
config = self.sm_disagg_config
- if not 0 < config.context_sm_percent < 1:
+ if not (0 < config.context_sm_percent <= 1):
raise ValueError(
- "sm_disagg_config.context_sm_percent must be in the range (0, 1)"
+ "sm_disagg_config.context_sm_percent must be in the range (0, 1]"
)
if config.context_max_num_tokens <= 0:
config.context_max_num_tokens = self.max_num_tokens
if config.context_max_batch_size <= 0:
config.context_max_batch_size = self.max_batch_size
+ # Clamp to runtime maxima if provided
+ if self.max_num_tokens is not None and config.context_max_num_tokens is not None:
+ config.context_max_num_tokens = min(config.context_max_num_tokens,
+ self.max_num_tokens)
+ if self.max_batch_size is not None and config.context_max_batch_size is not None:
+ config.context_max_batch_size = min(config.context_max_batch_size,
+ self.max_batch_size)
        return self

🧰 Tools
🪛 Ruff (0.14.3)
2833-2835: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
tensorrt_llm/llmapi/llm_args.py lines 2825-2841: update the validator to allow
context_sm_percent == 1.0 (change the range check to require 0 <
context_sm_percent <= 1) and ensure context_max_num_tokens and
context_max_batch_size are clamped to runtime maxima — if each is <= 0 set it to
self.max_num_tokens / self.max_batch_size respectively, otherwise set it to
min(their value, self.max_num_tokens/self.max_batch_size) to prevent
overprovisioning.
/bot run

PR_Github #23881 [ run ] triggered by Bot.
Does this PR support dynamically allocating SM resources between the ctx and gen engines based on the per-batch workload? If so, is there a cost model used to determine the SM partitioning?

No. A fixed SM ratio is passed through llm args. The fixed ratio seems to be good enough for the workloads we looked into (including OpenOrca which has varying ISL/OSL per request). So dynamic SM allocation is not planned unless there's a compelling use case.

PR_Github #23881 [ run ] completed with state
```python
            sm_min = 4 if device_properties.major == 8 else 2
            sm_align = 2

        from flashinfer import green_ctx
```
Though TRTLLM already integrates FlashInfer, is it a better idea to directly integrate the low-level green_context API provided by CUDA to reduce the abstraction level here?
The green contexts utils are now implemented in a separate file directly on top of cuda-python bindings.
/bot run

PR_Github #24335 [ run ] triggered by Bot.

PR_Github #24335 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #24504 [ run ] triggered by Bot.

PR_Github #24504 [ run ] completed with state

Force-pushed cd20a27 to 990a8fa

/bot run --disable-fail-fast

PR_Github #24847 [ run ] triggered by Bot.

PR_Github #24847 [ run ] completed with state
```python
        stats.cpu_mem_usage = 0
        stats.pinned_mem_usage = 0

        stats.iter = self.iter_counter
```
Why is this being removed?
Ctx and gen executor loops need to maintain separate iter_counter and iter_stats instances. The tracking of iter_stats.iter is moved to the end of the executor loops.
```python
        for res in res_list:
            desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0]
            green_ctx = CUASSERT(
                driver.cuGreenCtxCreate(
```
Don't we need to destroy the green context too? And same with the cuda stream. It seems like we should have a GreenContext class that manages the lifecycle of the green context and that allows to create streams. You can use Cursor to help with that but something along the lines of:
```python
class GreenContext:
    """Manages the lifecycle of a CUDA Green Context and its associated streams.

    A Green Context is a CUDA resource management abstraction that allows for
    isolation and control of GPU resources (particularly SMs) for concurrent execution.
    """

    def __init__(self, resource, device):
        """Initialize a Green Context from a CUDA device resource.

        Args:
            resource: A CUDA device resource (e.g., from cuDevSmResourceSplitByCount)
            device: CUDA device handle
        """
        self.resource = resource
        self.device = device
        self.green_ctx = None
        self.streams = []
        self._create_context()

    def _create_context(self):
        """Create the underlying CUDA Green Context."""
        desc = CUASSERT(driver.cuDevResourceGenerateDesc([self.resource], 1))[0]
        self.green_ctx = CUASSERT(
            driver.cuGreenCtxCreate(
                desc, self.device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
            )
        )[0]

    def create_stream(self):
        """Create a new CUDA stream within this Green Context.

        Returns:
            torch.cuda.Stream: A PyTorch CUDA stream backed by the Green Context
        """
        if self.green_ctx is None:
            raise RuntimeError("Green Context has not been created or has been destroyed")
        stream = CUASSERT(
            driver.cuGreenCtxStreamCreate(
                self.green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0
            )
        )[0]
        torch_stream = torch.cuda.get_stream_from_external(stream, self.device)
        self.streams.append(torch_stream)
        return torch_stream

    def destroy(self):
        """Clean up the Green Context and release resources."""
        if self.green_ctx is not None:
            CUASSERT(driver.cuGreenCtxDestroy(self.green_ctx))
            self.green_ctx = None
        self.streams.clear()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - clean up resources."""
        self.destroy()
        return False
```
And with methods to split based on percent, etc...
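One possible shape for that percent-based split, mirroring the rounding/alignment logic already used by split_device_green_ctx in this PR (the function name and default values here are illustrative, not from the PR):

```python
def compute_sm_split(sm_count, ctx_percent, sm_min=8, sm_align=8):
    """Return (ctx_sms, gen_sms), aligned to sm_align and each at least sm_min."""
    sm_ctx = round(sm_count * ctx_percent / sm_align) * sm_align
    sm_ctx = min(max(sm_ctx, sm_min), sm_count - sm_min)
    return sm_ctx, sm_count - sm_ctx

# e.g. on a 132-SM GPU with 30% for context: compute_sm_split(132, 0.3) -> (40, 92)
```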
Things turn out to get trickier when we try to figure out at which point it is safe to destroy the streams and green contexts as some components in the executor use pinned memory. Pinned memory in PyTorch records all streams it encountered and needs to go through all of them when it finally gets freed. But it only gets freed after the reference count drops to 0 and PyTorch frees the cache, and that's out of the lifetime of the PyExecutor instance. At this point, I would suggest not to explicitly destroy the streams and green contexts.
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/CachingHostAllocator.h#L316
```python
        self.responses = {}
        self.result_wait_queues = {}

        self.executor_lock = threading.Lock()
```
Can we explain what this lock is used for? Also executor_lock is very vague. Is this needed only for sm disaggregation? If so maybe call sm_disagg_lock?
Renamed and commented to explain lock usage.
```python
        self.ctx_request_cv = threading.Condition(self.executor_lock)
        self.gen_request_cv = threading.Condition(self.executor_lock)
```
Why are those named _request_? It doesn't seem to be for a particular request. Should this be renamed to sm_disagg_ctx_cv or similar?
Renamed.
```python
        scheduler_output = self.scheduler.schedule_request(
            self.active_requests, self.inflight_req_ids)

    def _schedule(self,
                  scheduler: Optional[RequestScheduler] = None,
```
Why do we make it an optional? Should we just set the default value to self.scheduler? Same with active_requests.
```python
def _schedule(self,
              scheduler: RequestScheduler = self.scheduler,
              active_requests=self.active_requests):
```
Python doesn't allow passing self.xxx as default args, so I have to default the args to None and set it to the actual default arg we want in the body.
Signed-off-by: Qiang Xu <qiangx@nvidia.com>
/bot run --disable-fail-fast

PR_Github #25243 [ run ] triggered by Bot.

PR_Github #25243 [ run ] completed with state
@QiJune this is a big MR and would like to have a second set of eyes on it. Could you please help review the changes to

@QiJune friendly ping.
Summary by CodeRabbit
Release Notes
Description
Background and overview
The SOTA request scheduling scheme - chunked prefill - piggybacks chunked context requests with generation requests to achieve stable TPOT and improved GPU utilization. However, the TPOT latency is bloated as the generation requests now often need to be processed together with many more context tokens.
Alternatively, disaggregated serving processes context and generation requests on different nodes to achieve low TPOT. However, the generation nodes can be underutilized, especially when using smaller batch sizes to meet the TPOT target. Besides, there are deployment scenarios where disaggregated serving is not suitable, e.g., lack of high speed interconnect.

This PR aims to achieve better throughput@latency by implementing a new feature called SM-level disaggregation. The feature achieves low TPOT latency by decoupling context and generation requests (as in disaggregated serving), but still runs the decoupled requests on the same GPU for better GPU utilization (as in chunked prefill). To achieve that, we partition the GPU (using Green Contexts) and allocate SMs to the context and generation phases, and the context and generation requests are asynchronously scheduled onto two different streams on the same GPU.
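For illustration, enabling the feature through the LLM API might look roughly like this (the SmDisaggConfig field names come from this PR; the exact constructor arguments and values are assumptions):

```python
from tensorrt_llm.llmapi import LLM, SmDisaggConfig

llm = LLM(
    model="/path/to/model",
    sm_disagg_config=SmDisaggConfig(
        context_sm_percent=0.6,       # fraction of SMs reserved for the context phase
        context_max_num_tokens=8192,  # <= 0 falls back to the engine-wide max_num_tokens
        context_max_batch_size=8,     # <= 0 falls back to the engine-wide max_batch_size
    ),
)
```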
This is the first PR that implements the core functionality of SM-level disaggregation. Follow-up PRs are planned to address doc and examples, persistent kernel perf issues, and parallelism support.
Design
Performance results
Limitation: Note that the context and generation workloads are relatively balanced in above cases (in terms of compute time). This feature (as well as any other disaggregation schemes) won't show much benefit if the workload is dominated by context or generation.
Test Coverage
Please suggest appropriate test cases to guard the feature.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
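For example, a full run with fail-fast disabled plus one extra post-merge stage could be requested with:

/bot run --disable-fail-fast --extra-stage "H100_PCIe-TensorRT-Post-Merge-1"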
kill

kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.