From bf183746b5bd64d8db63a4854457b5a8b6f64cb8 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Fri, 24 Oct 2025 18:47:47 +0200
Subject: [PATCH] fix the outdated end2end training examples of moe+torchtitan,
 make quant recipe flexible

---
 .../moe_training/examples/simple_moe_layer.py | 51 ++++++++++++++-----
 1 file changed, 39 insertions(+), 12 deletions(-)

diff --git a/torchao/prototype/moe_training/examples/simple_moe_layer.py b/torchao/prototype/moe_training/examples/simple_moe_layer.py
index 244d786c80..063a12dfcf 100644
--- a/torchao/prototype/moe_training/examples/simple_moe_layer.py
+++ b/torchao/prototype/moe_training/examples/simple_moe_layer.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -5,27 +11,41 @@
 # this feature requires CUDA and SM89+
 assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
 # this example uses torchtitan llama4 MoE, see
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError as e:
     raise ImportError(
         "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
     ) from e
 
+from argparse import ArgumentParser
+
+parser = ArgumentParser()
+parser.add_argument(
+    "--scaling_type",
+    type=str,
+    default="fp8_rowwise",
+    choices=["fp8_rowwise", "mxfp8"],
+)
+args = parser.parse_args()
+
+
 # initialize model
 device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
-    num_experts=8,
-    dim=256,
-)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+torch.manual_seed(42)
+model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+dim = 1024
+hidden_dim = dim * 4
+model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)
 
@@ -40,14 +60,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     return False
 
 
-# quantize the model
-config = MoETrainingConfig()
+if args.scaling_type == "fp8_rowwise":
+    config = MoETrainingConfig()
+    alignment_size = 16
+
+elif args.scaling_type == "mxfp8":
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    alignment_size = 32
+
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+set_token_group_alignment_size_m(alignment_size)
 
 # training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 for step in range(10):
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
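
Usage sketch (assuming the updated example is run directly as a script; the flag name, its choices, and the fp8_rowwise default come from the argparse block added in the patch):

    python torchao/prototype/moe_training/examples/simple_moe_layer.py --scaling_type fp8_rowwise
    python torchao/prototype/moe_training/examples/simple_moe_layer.py --scaling_type mxfp8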