Commit bc2517e

make quant recipe flexible
1 parent 1b28354 commit bc2517e

1 file changed: +22 −5 lines changed


torchao/prototype/moe_training/examples/simple_moe_layer.py

Lines changed: 22 additions & 5 deletions
@@ -27,10 +27,23 @@
     ) from e
 
 
+from argparse import ArgumentParser
+
+parser = ArgumentParser()
+parser.add_argument(
+    "--scaling_type",
+    type=str,
+    default="fp8_rowwise",
+    choices=["fp8_rowwise", "mxfp8"],
+)
+args = parser.parse_args()
+
+
 # initialize model
 device = torch.device("cuda")
+torch.manual_seed(42)
 model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
-dim = 256
+dim = 1024
 hidden_dim = dim * 4
 model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
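
With this hunk in place, the example can be run with either quantization recipe via the new flag. The invocations below are a sketch, using only the script path and flag shown in this commit:

    # default recipe: rowwise fp8
    python torchao/prototype/moe_training/examples/simple_moe_layer.py

    # mxfp8 recipe
    python torchao/prototype/moe_training/examples/simple_moe_layer.py --scaling_type mxfp8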
@@ -47,11 +60,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     return False
 
 
-# quantize the model, by default it is rowwise fp8
-config = MoETrainingConfig()
-quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+if args.scaling_type == "fp8_rowwise":
+    config = MoETrainingConfig()
+    alignment_size = 16
+
+elif args.scaling_type == "mxfp8":
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    alignment_size = 32
 
-alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
+quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 set_token_group_alignment_size_m(alignment_size)
 
 # training loop
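
For context, the two alignment values track the scaling granularity of each recipe: mxfp8 shares one scale per 32-element block, so token group sizes fed to the grouped GEMM are padded to multiples of 32, while the rowwise fp8 path here uses 16. A minimal sketch of the padding arithmetic (round_up is a hypothetical helper for illustration, not a torchao API):

    def round_up(n: int, multiple: int) -> int:
        # Pad a token-group size up to the next multiple of the alignment.
        return ((n + multiple - 1) // multiple) * multiple

    assert round_up(100, 16) == 112  # fp8 rowwise alignment
    assert round_up(100, 32) == 128  # mxfp8: one scale per 32-element block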
