diff --git a/tensorrt_llm/_torch/__init__.py b/tensorrt_llm/_torch/__init__.py
index 7d2de6d643c..7c2d021b1c4 100644
--- a/tensorrt_llm/_torch/__init__.py
+++ b/tensorrt_llm/_torch/__init__.py
@@ -1,4 +1,3 @@
 from .llm import LLM
-from .model_config import MoeLoadBalancerConfig
 
-__all__ = ["LLM", "MoeLoadBalancerConfig"]
+__all__ = ["LLM"]
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 9278409aee0..ca956dc53cf 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -17,7 +17,8 @@
 from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
-from tensorrt_llm.llmapi.llm_args import DeepSeekSparseAttentionConfig
+from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
+                                          MoeLoadBalancerConfig)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -26,44 +27,6 @@
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
 
-@dataclass
-class MoeLoadBalancerConfig:
-    num_slots: Optional[int] = None
-    initial_global_assignments: Optional[Dict[int,
-                                              List[int]]] = field(default=None,
-                                                                  repr=False)
-    layer_updates_per_iter: int = 0
-
-    ep_rank: Optional[int] = field(default=None, init=False)
-    ep_size: Optional[int] = field(default=None, init=False)
-
-    def setup(self, ep_rank: int, ep_size: int) -> None:
-        self.ep_rank = ep_rank
-        self.ep_size = ep_size
-        assert self.num_slots is not None
-
-    @property
-    def num_local_slots(self) -> int:
-        return self.num_slots // self.ep_size
-
-    @property
-    def slot_start(self) -> int:
-        return self.ep_rank * self.num_local_slots
-
-    @property
-    def slot_end(self) -> int:
-        return self.slot_start + self.num_local_slots
-
-    def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
-        if self.initial_global_assignments is not None:
-            assert layer_idx in self.initial_global_assignments
-            assert len(
-                self.initial_global_assignments[layer_idx]) == self.num_slots
-            return self.initial_global_assignments[layer_idx]
-        else:
-            return None
-
-
 @contextlib.contextmanager
 def config_file_lock(timeout: int = 10):
     """
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 1cc5b373341..16504e1eb81 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -261,6 +261,109 @@ def supports_backend(self, backend: str) -> bool:
         return backend == "pytorch"
 
 
+class MoeLoadBalancerConfig(StrictBaseModel):
+    """
+    Pydantic configuration model for the Mixture of Experts (MoE) load balancer.
+
+    This model holds configuration data (`num_slots`, etc.) as well as
+    runtime state (`_ep_rank`, `_ep_size`), which must be set via the
+    `setup()` method before use.
+    """
+
+    num_slots: Optional[int] = None
+    initial_global_assignments: Optional[Dict[int, List[int]]] = Field(
+        default=None,
+        repr=False  # Exclude this potentially large dict from the model repr
+    )
+    layer_updates_per_iter: int = 0
+    _ep_rank: Optional[int] = PrivateAttr(default=None)
+    _ep_size: Optional[int] = PrivateAttr(default=None)
+
+    # --- Methods ---
+
+    def setup(self, ep_rank: int, ep_size: int) -> None:
+        """
+        Initializes the runtime state of the configuration.
+        This must be called before accessing properties like `num_local_slots`.
+        """
+        self._ep_rank = ep_rank
+        self._ep_size = ep_size
+
+        # `num_slots` is required to partition slots across EP ranks.
+        if self.num_slots is None:
+            raise ValueError("`num_slots` cannot be None when calling setup().")
+
+        if self.num_slots % ep_size != 0:
+            raise ValueError(
+                f"`num_slots` ({self.num_slots}) must be divisible by `ep_size` ({ep_size})."
+            )
+
+    # --- Computed Properties ---
+    # These properties depend on the runtime state set by setup().
+
+    @property
+    def ep_rank(self) -> int:
+        """Public accessor for the private expert parallel rank."""
+        if self._ep_rank is None:
+            raise AttributeError("ep_rank is not set. Call setup() first.")
+        return self._ep_rank
+
+    @property
+    def ep_size(self) -> int:
+        """Public accessor for the private expert parallel size."""
+        if self._ep_size is None:
+            raise AttributeError("ep_size is not set. Call setup() first.")
+        return self._ep_size
+
+    @property
+    def num_local_slots(self) -> int:
+        """Calculates the number of slots local to this rank."""
+        if self.num_slots is None or self._ep_size is None:
+            raise ValueError(
+                "Cannot calculate `num_local_slots`. "
+                "`num_slots` must be set and setup() must be called.")
+        return self.num_slots // self._ep_size
+
+    @property
+    def slot_start(self) -> int:
+        """Calculates the starting global slot index for this rank."""
+        if self._ep_rank is None:
+            raise ValueError(
+                "Cannot calculate `slot_start`. Call setup() first.")
+        return self._ep_rank * self.num_local_slots
+
+    @property
+    def slot_end(self) -> int:
+        """Calculates the ending global slot index (exclusive) for this rank."""
+        return self.slot_start + self.num_local_slots
+
+    def get_layer_initial_global_assignments(
+            self, layer_idx: int) -> Optional[List[int]]:
+        """
+        Retrieves the initial global assignments for a specific layer.
+        """
+        if self.initial_global_assignments is None:
+            return None
+
+        if layer_idx not in self.initial_global_assignments:
+            raise KeyError(
+                f"layer_idx {layer_idx} not found in `initial_global_assignments`."
+            )
+
+        assignments = self.initial_global_assignments[layer_idx]
+
+        if self.num_slots is None:
+            raise ValueError(
+                "`num_slots` is not set, cannot verify assignment length.")
+
+        if len(assignments) != self.num_slots:
+            raise ValueError(
+                f"Assignment length ({len(assignments)}) for layer {layer_idx} "
+                f"does not match `num_slots` ({self.num_slots}).")
+
+        return assignments
+
+
 class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
@@ -2661,7 +2764,6 @@ def validate_checkpoint_format(self):
 
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
-        from .._torch import MoeLoadBalancerConfig
         if isinstance(self.moe_config.load_balancer, str):
             if not os.path.exists(self.moe_config.load_balancer):
                 raise FileNotFoundError(
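
Below is a minimal usage sketch of the relocated `MoeLoadBalancerConfig` Pydantic model, assuming the patch above is applied. The concrete values (`num_slots=8`, `ep_size=2`, a single layer index 0) are illustrative assumptions for the example, not defaults introduced by this change.

# Illustrative sketch only: exercises the public surface added to
# tensorrt_llm/llmapi/llm_args.py in this patch. Values are made up.
from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig

# Eight global expert slots; layer 0 gets an explicit initial placement whose
# length must equal num_slots, otherwise get_layer_initial_global_assignments
# raises ValueError.
cfg = MoeLoadBalancerConfig(
    num_slots=8,
    initial_global_assignments={0: [0, 1, 2, 3, 4, 5, 6, 7]},
    layer_updates_per_iter=1,
)

# Runtime state is injected separately; num_slots must be divisible by ep_size
# or setup() raises ValueError.
cfg.setup(ep_rank=1, ep_size=2)

assert cfg.num_local_slots == 4  # 8 slots // 2 EP ranks
assert cfg.slot_start == 4       # rank 1 owns global slots [4, 8)
assert cfg.slot_end == 8
assert cfg.get_layer_initial_global_assignments(0) == [0, 1, 2, 3, 4, 5, 6, 7]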