3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/__init__.py
@@ -1,4 +1,3 @@
 from .llm import LLM
-from .model_config import MoeLoadBalancerConfig
 
-__all__ = ["LLM", "MoeLoadBalancerConfig"]
+__all__ = ["LLM"]
41 changes: 2 additions & 39 deletions tensorrt_llm/_torch/model_config.py
@@ -17,7 +17,8 @@
 from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
-from tensorrt_llm.llmapi.llm_args import DeepSeekSparseAttentionConfig
+from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
+                                          MoeLoadBalancerConfig)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -26,44 +27,6 @@
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
 
-@dataclass
-class MoeLoadBalancerConfig:
-    num_slots: Optional[int] = None
-    initial_global_assignments: Optional[Dict[int,
-                                              List[int]]] = field(default=None,
-                                                                  repr=False)
-    layer_updates_per_iter: int = 0
-
-    ep_rank: Optional[int] = field(default=None, init=False)
-    ep_size: Optional[int] = field(default=None, init=False)
-
-    def setup(self, ep_rank: int, ep_size: int) -> None:
-        self.ep_rank = ep_rank
-        self.ep_size = ep_size
-        assert self.num_slots is not None
-
-    @property
-    def num_local_slots(self) -> int:
-        return self.num_slots // self.ep_size
-
-    @property
-    def slot_start(self) -> int:
-        return self.ep_rank * self.num_local_slots
-
-    @property
-    def slot_end(self) -> int:
-        return self.slot_start + self.num_local_slots
-
-    def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
-        if self.initial_global_assignments is not None:
-            assert layer_idx in self.initial_global_assignments
-            assert len(
-                self.initial_global_assignments[layer_idx]) == self.num_slots
-            return self.initial_global_assignments[layer_idx]
-        else:
-            return None
-
-
 @contextlib.contextmanager
 def config_file_lock(timeout: int = 10):
     """
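With this change, `MoeLoadBalancerConfig` is no longer defined in `tensorrt_llm._torch.model_config` (nor re-exported from `tensorrt_llm._torch`); consumers now import it from `tensorrt_llm.llmapi.llm_args`, as the updated import block above does. A minimal sketch of the new import path (the constructor values below are illustrative, not taken from the repository):

# Old import paths removed by this PR:
#   from tensorrt_llm._torch import MoeLoadBalancerConfig
#   from tensorrt_llm._torch.model_config import MoeLoadBalancerConfig
# New location:
from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig

lb_config = MoeLoadBalancerConfig(num_slots=256, layer_updates_per_iter=1)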
104 changes: 103 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -261,6 +261,109 @@ def supports_backend(self, backend: str) -> bool:
         return backend == "pytorch"
 
 
+class MoeLoadBalancerConfig(StrictBaseModel):
+    """
+    Pydantic configuration model for the Mixture of Experts (MoE) load balancer.
+
+    This model holds configuration data (`num_slots`, etc.) as well as
+    runtime state (`_ep_rank`, `_ep_size`) which must be set via the
+    `setup()` method before use.
+    """
+
+    num_slots: Optional[int] = None
+    initial_global_assignments: Optional[Dict[int, List[int]]] = Field(
+        default=None,
+        repr=False  # Exclude this large dict from model representation
+    )
+    layer_updates_per_iter: int = 0
+    _ep_rank: Optional[int] = PrivateAttr(default=None)
+    _ep_size: Optional[int] = PrivateAttr(default=None)
+
+    # --- Methods ---
+
+    def setup(self, ep_rank: int, ep_size: int) -> None:
+        """
+        Initializes the runtime state of the configuration.
+        This must be called before accessing properties like `num_local_slots`.
+        """
+        self._ep_rank = ep_rank
+        self._ep_size = ep_size
+
+        # This assertion was in the original and is critical.
+        if self.num_slots is None:
+            raise ValueError("`num_slots` cannot be None when calling setup().")
+
+        if self.num_slots % ep_size != 0:
+            raise ValueError(
+                f"`num_slots` ({self.num_slots}) must be divisible by `ep_size` ({ep_size})."
+            )
+
+    # --- Computed Properties ---
+    # These properties depend on the runtime state set by setup()
+
+    @property
+    def ep_rank(self) -> int:
+        """Public accessor for the private expert parallel rank."""
+        if self._ep_rank is None:
+            raise AttributeError("ep_rank is not set. Call setup() first.")
+        return self._ep_rank
+
+    @property
+    def ep_size(self) -> int:
+        """Public accessor for the private expert parallel size."""
+        if self._ep_size is None:
+            raise AttributeError("ep_size is not set. Call setup() first.")
+        return self._ep_size
+
+    @property
+    def num_local_slots(self) -> int:
+        """Calculates the number of slots local to this rank."""
+        if self.num_slots is None or self._ep_size is None:
+            raise ValueError(
+                "Cannot calculate `num_local_slots`. "
+                "`num_slots` must be set and setup() must be called.")
+        return self.num_slots // self._ep_size
+
+    @property
+    def slot_start(self) -> int:
+        """Calculates the starting global slot index for this rank."""
+        if self._ep_rank is None:
+            raise ValueError(
+                "Cannot calculate `slot_start`. Call setup() first.")
+        return self._ep_rank * self.num_local_slots
+
+    @property
+    def slot_end(self) -> int:
+        """Calculates the ending global slot index (exclusive) for this rank."""
+        return self.slot_start + self.num_local_slots
+
+    def get_layer_initial_global_assignments(
+            self, layer_idx: int) -> Optional[List[int]]:
+        """
+        Retrieves the initial global assignments for a specific layer.
+        """
+        if self.initial_global_assignments is None:
+            return None
+
+        if layer_idx not in self.initial_global_assignments:
+            raise KeyError(
+                f"layer_idx {layer_idx} not found in `initial_global_assignments`."
+            )
+
+        assignments = self.initial_global_assignments[layer_idx]
+
+        if self.num_slots is None:
+            raise ValueError(
+                "`num_slots` is not set, cannot verify assignment length.")
+
+        if len(assignments) != self.num_slots:
+            raise ValueError(
+                f"Assignment length ({len(assignments)}) for layer {layer_idx} "
+                f"does not match `num_slots` ({self.num_slots}).")
+
+        return assignments
+
+
 class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
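Because the new model keeps its runtime state behind `PrivateAttr` fields, `setup()` must be called with the expert-parallel rank and size before any of the slot properties are read; otherwise they raise. A minimal usage sketch of the class added above (the numbers are illustrative):

from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig

cfg = MoeLoadBalancerConfig(num_slots=8, layer_updates_per_iter=1)
cfg.setup(ep_rank=1, ep_size=4)   # runtime state; must precede the properties below

assert cfg.num_local_slots == 2   # 8 slots // 4 ranks
assert cfg.slot_start == 2        # rank 1 owns global slots [2, 4)
assert cfg.slot_end == 4

# Without initial_global_assignments, the per-layer lookup returns None.
assert cfg.get_layer_initial_global_assignments(layer_idx=0) is None

Compared with the removed dataclass, failed preconditions now surface as explicit `ValueError`/`AttributeError`/`KeyError` exceptions rather than bare `assert` statements, so they still fire under `python -O`.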
@@ -2661,7 +2764,6 @@ def validate_checkpoint_format(self):

     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
-        from .._torch import MoeLoadBalancerConfig
         if isinstance(self.moe_config.load_balancer, str):
             if not os.path.exists(self.moe_config.load_balancer):
                 raise FileNotFoundError(
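With the class now defined in `llm_args.py`, `validate_load_balancer` can reference `MoeLoadBalancerConfig` directly, and the deferred `from .._torch import ...` (presumably there to avoid a circular import) is dropped. The visible lines of the validator only show the existence check for a string path; as a hedged illustration of the overall flow rather than the repository's verified code, a file path might be turned into the config roughly like this (the YAML format and the helper name are assumptions):

import os

import yaml  # assumption: the load-balancer config file is YAML

from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig


def _load_balancer_config_from_file(path: str) -> MoeLoadBalancerConfig:
    """Hypothetical helper: parse a config file into the Pydantic model."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"MoE load balancer config file not found: {path}")
    with open(path) as f:
        data = yaml.safe_load(f) or {}
    # Field names/types are validated by Pydantic; StrictBaseModel presumably
    # rejects unknown keys as well.
    return MoeLoadBalancerConfig(**data)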