
Commit 943b05e

[TRTLLM-9179][feat] add pp_partition to customize each rank's layer number (#9003)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
1 parent 3416efb commit 943b05e

File tree

5 files changed: +65 -19 lines changed

docs/source/developer-guide/api-change.md

Lines changed: 24 additions & 16 deletions
@@ -34,7 +34,7 @@ TensorRT LLM classifies APIs into two categories:
 All API schemas are:
 - Stored as YAML files in the codebase
 - Protected by unit tests in `tests/unittest/api_stability/`
-- Automatically validated to ensure consistency
+- Automatically validated to ensure consistency

 ## API Change Principles

@@ -44,22 +44,26 @@ All API schemas are:

 Argument names should describe what the argument represents, not how it is used internally.

-**Good**: `max_new_tokens` (clear meaning)
+**Good**: `max_new_tokens` (clear meaning)
+
 **Bad**: `num` (ambiguous)

 **Reflect Argument Type and Granularity**

 - For **boolean** knobs, prefix with verbs like `enable_` and so on.
+
   Examples: `enable_cache`, `enable_flash_attention`

-- For **numerical threshold** knobs, suffix with `_limit`, `_size`, `_count`, `_len_` or `_ratio`
+- For **numerical threshold** knobs, suffix with `_limit`, `_size`, `_count`, `_len_` or `_ratio`
+
   Examples: `max_seq_len`, `prefill_batch_size`

 **Avoid Redundant Prefixes**

 Example (in `MoeConfig`):

-**Good**: `backend`
+**Good**: `backend`
+
 **Bad**: `moe_backend` (redundant since it's already in `MoeConfig`)

 **Use Specific Names for Narrow Scenarios**
@@ -68,7 +72,8 @@ When adding knobs for specific use cases, make the name convey the restriction c

 Example (argument to the LLM class):

-**Good**: `rope_scaling_factor` → clearly indicates it's for RoPE
+**Good**: `rope_scaling_factor` → clearly indicates it's for RoPE
+
 **Bad**: `scaling_factor` → too generic and prone to misuse

 ### 2. Hierarchical Configuration
@@ -77,13 +82,16 @@ Organize complex or hierarchical arguments into **dedicated configuration datacl

 **Guidelines**

-- Use the `XxxConfig` suffix consistently
+- Use the `XxxConfig` suffix consistently
+
   Examples: `ModelConfig`, `ParallelConfig`, `MoeConfig`
-
-- **Reflect conceptual hierarchy**
+
+- **Reflect conceptual hierarchy**
+
   The dataclass name should represent a coherent functional unit, not an arbitrary grouping
-
-- **Avoid over-nesting**
+
+- **Avoid over-nesting**
+
   Use only one level of configuration hierarchy whenever possible (e.g., `LlmArgs → ParallelConfig`) to balance readability and modularity

 ### 3. Prefer `LlmArgs` Over Environment Variables
@@ -154,15 +162,15 @@ garbage_collection_gen0_threshold: int = Field(

 Add the field to the appropriate schema file:

-- **Non-committed arguments**: `tests/unittest/api_stability/references/llm_args.yaml`
+- **Non-committed arguments**: `tests/unittest/api_stability/references/llm.yaml`
   ```yaml
   garbage_collection_gen0_threshold:
     type: int
     default: 20000
     status: beta  # Must match the status in code
   ```

-- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm_args.yaml`
+- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm.yaml`
   ```yaml
   garbage_collection_gen0_threshold:
     type: int
@@ -196,16 +204,16 @@ For non-committed APIs, use the `@set_api_status` decorator:
 ```python
 @set_api_status("beta")
 def generate_with_streaming(
-    self,
-    prompts: List[str],
+    self,
+    prompts: List[str],
     **kwargs
 ) -> Iterator[GenerationOutput]:
     """Generate text with streaming output.
-
+
     Args:
         prompts: Input prompts for generation
         **kwargs: Additional generation parameters
-
+
     Returns:
         Iterator of generation outputs
     """

tensorrt_llm/llmapi/llm_args.py

Lines changed: 9 additions & 0 deletions
@@ -326,6 +326,7 @@ class _ParallelConfig(StrictBaseModel):
     moe_tp_size: int = -1
     moe_ep_size: int = -1
     cp_config: dict = Field(default_factory=dict)
+    pp_partition: Optional[List[int]] = Field(default=None)
     enable_attention_dp: bool = False
     enable_lm_head_tp_in_adp: bool = False

@@ -372,6 +373,7 @@ def to_mapping(self) -> Mapping:
            gpus_per_node=self.gpus_per_node,
            tp_size=self.tp_size,
            pp_size=self.pp_size,
+            pp_partition=self.pp_partition,
            cp_size=self.cp_size,
            cp_config=self.cp_config,
            enable_attention_dp=self.enable_attention_dp,
@@ -1587,6 +1589,12 @@ class BaseLlmArgs(StrictBaseModel):
        description="Enable LM head TP in attention dp.",
        status="prototype")

+    pp_partition: Optional[List[int]] = Field(
+        default=None,
+        description=
+        "Pipeline parallel partition, a list of each rank's layer number.",
+        status="prototype")
+
    cp_config: Optional[dict] = Field(default_factory=dict,
                                      description="Context parallel config.",
                                      status="prototype")
@@ -1843,6 +1851,7 @@ def validate_parallel_config(self):
            moe_ep_size=self.moe_expert_parallel_size,
            enable_attention_dp=self.enable_attention_dp,
            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
+            pp_partition=self.pp_partition,
            cp_config=self.cp_config)
        return self
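
For orientation, here is a minimal usage sketch (not part of this commit) of the new prototype argument as it would be passed through the LLM API, mirroring the integration test further below; the checkpoint path and partition values are illustrative assumptions.

```python
# Illustrative only: the checkpoint path and partition values are assumptions,
# not taken from this commit.
from tensorrt_llm import LLM

llm = LLM(
    "/path/to/checkpoint",      # hypothetical model path
    pipeline_parallel_size=4,
    # One entry per PP rank; the entries must sum to the model's layer count.
    pp_partition=[8, 8, 8, 6],
)
```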

tensorrt_llm/mapping.py

Lines changed: 20 additions & 3 deletions
@@ -49,6 +49,7 @@ def __init__(
        cp_config=None,
        tp_size=1,
        pp_size=1,
+        pp_partition=None,
        moe_cluster_size=-1,  # -1 means no moe
        moe_tp_size=-1,  # -1 means no moe
        moe_ep_size=-1,  # -1 means no moe
@@ -126,6 +127,7 @@ def __init__(
        self.cp_size = cp_size
        self.cp_config = cp_config if cp_config is not None else {}
        self.pp_size = pp_size
+        self.pp_partition = pp_partition
        self.moe_tp_size = moe_tp_size
        self.moe_ep_size = moe_ep_size
        self.moe_cluster_size = moe_cluster_size
@@ -156,6 +158,7 @@ def __eq__(self, other):
                and self.tp_size == other.tp_size
                and self.moe_cluster_size == other.moe_cluster_size
                and self.pp_size == other.pp_size
+                and self.pp_partition == other.pp_partition
                and self.moe_tp_size == other.moe_tp_size
                and self.moe_ep_size == other.moe_ep_size
                and self.attn_tp_size == other.attn_tp_size
@@ -177,6 +180,7 @@ def __hash__(self):
            self.attn_cp_size,
            # note: we do not allow updating cp_config after initialization
            tuple(sorted(self.cp_config.items())),
+            tuple(self.pp_partition) if self.pp_partition is not None else (),
        ))

    @property
@@ -299,9 +303,20 @@ def has_moe_ep(self):
        return self.moe_ep_size > 1

    def pp_layers(self, num_layers: int) -> List[int]:
-        # If num_layers % pp_size = n != 0, first n ranks get one extra layer
-        return torch.tensor_split(torch.arange(num_layers),
-                                  self.pp_size)[self.pp_rank].tolist()
+        if self.pp_partition is not None:
+            if len(self.pp_partition) != self.pp_size:
+                raise ValueError(
+                    f"{len(self.pp_partition)=} does not match {self.pp_size=}."
+                )
+            if sum(self.pp_partition) != num_layers:
+                raise ValueError(
+                    f"{sum(self.pp_partition)=} does not match {num_layers=}.")
+            return torch.arange(num_layers).split(
+                self.pp_partition)[self.pp_rank].tolist()
+        else:
+            # If num_layers % pp_size = n != 0, first n ranks get one extra layer
+            return torch.tensor_split(torch.arange(num_layers),
+                                      self.pp_size)[self.pp_rank].tolist()

    def ep_experts(self, num_experts: int) -> List[int]:
        assert self.cp_size == 1
@@ -446,6 +461,7 @@ def __init__(
        cp_config=None,
        tp_size=1,
        pp_size=1,
+        pp_partition=None,
        moe_cluster_size=-1,  # -1 means no moe
        moe_tp_size=-1,  # -1 means no moe
        moe_ep_size=-1,  # -1 means no moe
@@ -460,6 +476,7 @@ def __init__(
        cp_config=cp_config,
        tp_size=tp_size,
        pp_size=pp_size,
+        pp_partition=pp_partition,
        moe_cluster_size=moe_cluster_size,
        moe_tp_size=moe_tp_size,
        moe_ep_size=moe_ep_size,
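
To make the new branch in `pp_layers` concrete, here is a small standalone sketch (not part of the commit) that reproduces the two splitting strategies with plain torch calls. Unlike the real method, which returns only the current rank's slice via `self.pp_rank`, the sketch returns every rank's slice so the full partition is visible.

```python
import torch

def split_layers(num_layers, pp_size, pp_partition=None):
    """Mimic Mapping.pp_layers: return the layer indices owned by each PP rank."""
    layers = torch.arange(num_layers)
    if pp_partition is not None:
        # Custom partition: entry i is the number of layers assigned to rank i.
        assert len(pp_partition) == pp_size and sum(pp_partition) == num_layers
        chunks = layers.split(pp_partition)
    else:
        # Default: near-even split; the first num_layers % pp_size ranks get one extra layer.
        chunks = torch.tensor_split(layers, pp_size)
    return [chunk.tolist() for chunk in chunks]

print(split_layers(30, 4))                # ranks own 8, 8, 7, 7 layers
print(split_layers(30, 4, [8, 8, 8, 6]))  # last rank owns only 6 layers
```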

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 8 additions & 0 deletions
@@ -1290,6 +1290,13 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                            torch_compile):
        if torch_compile and pp_size > 1:
            pytest.skip("PP with torch.compile is not supported yet.")
+
+        if pp_size > 1 and mtp_nextn > 0:
+            num_hidden_layers = 30
+            pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
+            pp_partition[-1] = num_hidden_layers - sum(pp_partition[:-1])
+        else:
+            pp_partition = None
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
@@ -1307,6 +1314,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
        with LLM(self.MODEL_PATH,
                 tensor_parallel_size=tp_size,
                 pipeline_parallel_size=pp_size,
+                 pp_partition=pp_partition,
                 moe_expert_parallel_size=ep_size,
                 kv_cache_config=kv_cache_config,
                 **pytorch_config,
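
For intuition, the partition arithmetic the test adds works out as follows; the pp_size values below are illustrative (only the formula is from the diff). Every rank but the last takes one extra layer and the last PP rank keeps the remainder, presumably to leave headroom there when MTP is enabled.

```python
# Worked example of the partition the added test code builds.
num_hidden_layers = 30
for pp_size in (2, 4):
    pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
    pp_partition[-1] = num_hidden_layers - sum(pp_partition[:-1])
    print(pp_size, pp_partition)  # 2 -> [16, 14], 4 -> [8, 8, 8, 6]
```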

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,10 @@ methods:
      annotation: Optional[dict]
      default: null
      status: prototype
+    pp_partition:
+      annotation: Optional[List[int]]
+      default: null
+      status: prototype
    # Stats
    iter_stats_max_iterations:
      annotation: Optional[int]
