@@ -261,6 +261,109 @@ def supports_backend(self, backend: str) -> bool:
         return backend == "pytorch"
 
 
+class MoeLoadBalancerConfig(StrictBaseModel):
+    """
+    Pydantic configuration model for the Mixture of Experts (MoE) load balancer.
+
+    This model holds configuration data (`num_slots`, etc.) as well as
+    runtime state (`_ep_rank`, `_ep_size`) which must be set via the
+    `setup()` method before use.
+    """
+
+    num_slots: Optional[int] = None
+    initial_global_assignments: Optional[Dict[int, List[int]]] = Field(
+        default=None,
+        repr=False  # Exclude this large dict from the model representation
+    )
+    layer_updates_per_iter: int = 0
+    _ep_rank: Optional[int] = PrivateAttr(default=None)
+    _ep_size: Optional[int] = PrivateAttr(default=None)
+
+    # --- Methods ---
+
+    def setup(self, ep_rank: int, ep_size: int) -> None:
+        """
+        Initializes the runtime state of the configuration.
+        This must be called before accessing properties like `num_local_slots`.
+        """
+        self._ep_rank = ep_rank
+        self._ep_size = ep_size
+
+        # num_slots must be known here so the per-rank slot layout can be derived.
+        if self.num_slots is None:
+            raise ValueError("`num_slots` cannot be None when calling setup().")
+
+        if self.num_slots % ep_size != 0:
+            raise ValueError(
+                f"`num_slots` ({self.num_slots}) must be divisible by `ep_size` ({ep_size})."
+            )
+
+    # --- Computed Properties ---
+    # These properties depend on the runtime state set by setup()
+
+    @property
+    def ep_rank(self) -> int:
+        """Public accessor for the private expert parallel rank."""
+        if self._ep_rank is None:
+            raise AttributeError("ep_rank is not set. Call setup() first.")
+        return self._ep_rank
+
+    @property
+    def ep_size(self) -> int:
+        """Public accessor for the private expert parallel size."""
+        if self._ep_size is None:
+            raise AttributeError("ep_size is not set. Call setup() first.")
+        return self._ep_size
+
+    @property
+    def num_local_slots(self) -> int:
+        """Calculates the number of slots local to this rank."""
+        if self.num_slots is None or self._ep_size is None:
+            raise ValueError(
+                "Cannot calculate `num_local_slots`. "
+                "`num_slots` must be set and setup() must be called.")
+        return self.num_slots // self._ep_size
+
+    @property
+    def slot_start(self) -> int:
+        """Calculates the starting global slot index for this rank."""
+        if self._ep_rank is None:
+            raise ValueError(
+                "Cannot calculate `slot_start`. Call setup() first.")
+        return self._ep_rank * self.num_local_slots
+
+    @property
+    def slot_end(self) -> int:
+        """Calculates the ending global slot index (exclusive) for this rank."""
+        return self.slot_start + self.num_local_slots
+
+    def get_layer_initial_global_assignments(
+            self, layer_idx: int) -> Optional[List[int]]:
+        """
+        Retrieves the initial global assignments for a specific layer.
+        """
+        if self.initial_global_assignments is None:
+            return None
+
+        if layer_idx not in self.initial_global_assignments:
+            raise KeyError(
+                f"layer_idx {layer_idx} not found in `initial_global_assignments`."
+            )
+
+        assignments = self.initial_global_assignments[layer_idx]
+
+        if self.num_slots is None:
+            raise ValueError(
+                "`num_slots` is not set, cannot verify assignment length.")
+
+        if len(assignments) != self.num_slots:
+            raise ValueError(
+                f"Assignment length ({len(assignments)}) for layer {layer_idx} "
+                f"does not match `num_slots` ({self.num_slots}).")
+
+        return assignments
+
+
 class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
@@ -2673,7 +2776,6 @@ def validate_checkpoint_format(self):
 
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
-        from .._torch import MoeLoadBalancerConfig
         if isinstance(self.moe_config.load_balancer, str):
             if not os.path.exists(self.moe_config.load_balancer):
                 raise FileNotFoundError(
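
For context, a minimal usage sketch of the new config's lifecycle (not part of the patch; the values below are invented for illustration):

# Illustrative sketch only, not part of the diff: 8 expert slots spread across 4 ranks.
cfg = MoeLoadBalancerConfig(
    num_slots=8,
    layer_updates_per_iter=1,
    initial_global_assignments={0: [0, 1, 2, 3, 0, 1, 2, 3]},  # layer 0 -> expert id per slot
)
cfg.setup(ep_rank=1, ep_size=4)     # runtime state must be set before the slot properties work
assert cfg.num_local_slots == 2     # 8 slots // 4 ranks
assert cfg.slot_start == 2          # rank 1 owns global slots [2, 4)
assert cfg.slot_end == 4
assert cfg.get_layer_initial_global_assignments(0) == [0, 1, 2, 3, 0, 1, 2, 3]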