Commit d1b003d

[TRTLLM-9212][chore] move MoeLoadBalancerConfig to llm_args.py (#9002)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
1 parent 943b05e · commit d1b003d

File tree (3 files changed: +106 -42 lines changed)

tensorrt_llm/_torch/__init__.py
tensorrt_llm/_torch/model_config.py
tensorrt_llm/llmapi/llm_args.py

tensorrt_llm/_torch/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,3 @@
 from .llm import LLM
-from .model_config import MoeLoadBalancerConfig
 
-__all__ = ["LLM", "MoeLoadBalancerConfig"]
+__all__ = ["LLM"]

tensorrt_llm/_torch/model_config.py

Lines changed: 2 additions & 39 deletions
@@ -17,7 +17,8 @@
 from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
-from tensorrt_llm.llmapi.llm_args import DeepSeekSparseAttentionConfig
+from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
+                                          MoeLoadBalancerConfig)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -26,44 +27,6 @@
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
 
-@dataclass
-class MoeLoadBalancerConfig:
-    num_slots: Optional[int] = None
-    initial_global_assignments: Optional[Dict[int,
-                                              List[int]]] = field(default=None,
-                                                                  repr=False)
-    layer_updates_per_iter: int = 0
-
-    ep_rank: Optional[int] = field(default=None, init=False)
-    ep_size: Optional[int] = field(default=None, init=False)
-
-    def setup(self, ep_rank: int, ep_size: int) -> None:
-        self.ep_rank = ep_rank
-        self.ep_size = ep_size
-        assert self.num_slots is not None
-
-    @property
-    def num_local_slots(self) -> int:
-        return self.num_slots // self.ep_size
-
-    @property
-    def slot_start(self) -> int:
-        return self.ep_rank * self.num_local_slots
-
-    @property
-    def slot_end(self) -> int:
-        return self.slot_start + self.num_local_slots
-
-    def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
-        if self.initial_global_assignments is not None:
-            assert layer_idx in self.initial_global_assignments
-            assert len(
-                self.initial_global_assignments[layer_idx]) == self.num_slots
-            return self.initial_global_assignments[layer_idx]
-        else:
-            return None
-
-
 @contextlib.contextmanager
 def config_file_lock(timeout: int = 10):
     """

tensorrt_llm/llmapi/llm_args.py

Lines changed: 103 additions & 1 deletion
@@ -261,6 +261,109 @@ def supports_backend(self, backend: str) -> bool:
         return backend == "pytorch"
 
 
+class MoeLoadBalancerConfig(StrictBaseModel):
+    """
+    Pydantic configuration model for the Mixture of Experts (MoE) load balancer.
+
+    This model holds configuration data (`num_slots`, etc.) as well as
+    runtime state (`_ep_rank`, `_ep_size`) which must be set via the
+    `setup()` method before use.
+    """
+
+    num_slots: Optional[int] = None
+    initial_global_assignments: Optional[Dict[int, List[int]]] = Field(
+        default=None,
+        repr=False  # Exclude this large dict from model representation
+    )
+    layer_updates_per_iter: int = 0
+    _ep_rank: Optional[int] = PrivateAttr(default=None)
+    _ep_size: Optional[int] = PrivateAttr(default=None)
+
+    # --- Methods ---
+
+    def setup(self, ep_rank: int, ep_size: int) -> None:
+        """
+        Initializes the runtime state of the configuration.
+        This must be called before accessing properties like `num_local_slots`.
+        """
+        self._ep_rank = ep_rank
+        self._ep_size = ep_size
+
+        # This assertion was in the original and is critical.
+        if self.num_slots is None:
+            raise ValueError("`num_slots` cannot be None when calling setup().")
+
+        if self.num_slots % ep_size != 0:
+            raise ValueError(
+                f"`num_slots` ({self.num_slots}) must be divisible by `ep_size` ({ep_size})."
+            )
+
+    # --- Computed Properties ---
+    # These properties depend on the runtime state set by setup()
+
+    @property
+    def ep_rank(self) -> int:
+        """Public accessor for the private expert parallel rank."""
+        if self._ep_rank is None:
+            raise AttributeError("ep_rank is not set. Call setup() first.")
+        return self._ep_rank
+
+    @property
+    def ep_size(self) -> int:
+        """Public accessor for the private expert parallel size."""
+        if self._ep_size is None:
+            raise AttributeError("ep_size is not set. Call setup() first.")
+        return self._ep_size
+
+    @property
+    def num_local_slots(self) -> int:
+        """Calculates the number of slots local to this rank."""
+        if self.num_slots is None or self._ep_size is None:
+            raise ValueError(
+                "Cannot calculate `num_local_slots`. "
+                "`num_slots` must be set and setup() must be called.")
+        return self.num_slots // self._ep_size
+
+    @property
+    def slot_start(self) -> int:
+        """Calculates the starting global slot index for this rank."""
+        if self._ep_rank is None:
+            raise ValueError(
+                "Cannot calculate `slot_start`. Call setup() first.")
+        return self._ep_rank * self.num_local_slots
+
+    @property
+    def slot_end(self) -> int:
+        """Calculates the ending global slot index (exclusive) for this rank."""
+        return self.slot_start + self.num_local_slots
+
+    def get_layer_initial_global_assignments(
+            self, layer_idx: int) -> Optional[List[int]]:
+        """
+        Retrieves the initial global assignments for a specific layer.
+        """
+        if self.initial_global_assignments is None:
+            return None
+
+        if layer_idx not in self.initial_global_assignments:
+            raise KeyError(
+                f"layer_idx {layer_idx} not found in `initial_global_assignments`."
+            )
+
+        assignments = self.initial_global_assignments[layer_idx]
+
+        if self.num_slots is None:
+            raise ValueError(
+                "`num_slots` is not set, cannot verify assignment length.")
+
+        if len(assignments) != self.num_slots:
+            raise ValueError(
+                f"Assignment length ({len(assignments)}) for layer {layer_idx} "
+                f"does not match `num_slots` ({self.num_slots}).")
+
+        return assignments
+
+
 class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
@@ -2673,7 +2776,6 @@ def validate_checkpoint_format(self):
 
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
-        from .._torch import MoeLoadBalancerConfig
         if isinstance(self.moe_config.load_balancer, str):
             if not os.path.exists(self.moe_config.load_balancer):
                 raise FileNotFoundError(
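
For reference, a minimal usage sketch of the relocated Pydantic model, exercising only the fields and methods added above (the slot and rank numbers are illustrative, not taken from this commit):

from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig

# 8 global expert slots distributed over 4 expert-parallel ranks.
cfg = MoeLoadBalancerConfig(num_slots=8, layer_updates_per_iter=1)

# Runtime state must be injected before the computed properties are usable.
cfg.setup(ep_rank=1, ep_size=4)

assert cfg.num_local_slots == 2  # 8 slots // 4 ranks
assert cfg.slot_start == 2       # rank 1 owns slots [2, 4)
assert cfg.slot_end == 4

# No per-layer assignment table was supplied, so this returns None.
assert cfg.get_layer_initial_global_assignments(0) is None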
