File tree Expand file tree Collapse file tree 1 file changed +22
-5
lines changed
torchao/prototype/moe_training/examples Expand file tree Collapse file tree 1 file changed +22
-5
lines changed Original file line number Diff line number Diff line change 2727 ) from e
2828
2929
30+ from argparse import ArgumentParser
31+
32+ parser = ArgumentParser()
33+ parser.add_argument(
34+     "--scaling_type",
35+     type=str,
36+     default="fp8_rowwise",
37+     choices=["fp8_rowwise", "mxfp8"],
38+ )
39+ args = parser.parse_args()
40+
41+
30 42  # initialize model
31 43  device = torch.device("cuda")
   44+ torch.manual_seed(42)
32 45  model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
33-     dim = 256
   46+ dim = 1024
34 47  hidden_dim = dim * 4
35 48  model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
36 49  init_std = 0.02
@@ -47,11 +60,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
4760 return False
4861
4962
50-     # quantize the model, by default it is rowwise fp8
51-     config = MoETrainingConfig()
52-     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
   63+ if args.scaling_type == "fp8_rowwise":
   64+     config = MoETrainingConfig()
   65+     alignment_size = 16
   66+
   67+ elif args.scaling_type == "mxfp8":
   68+     config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
   69+     alignment_size = 32
53 70
54-     alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
   71+ quantize_(model, config=config, filter_fn=moe_module_filter_fn)
55 72  set_token_group_alignment_size_m(alignment_size)
5673
5774# training loop
You can’t perform that action at this time.
0 commit comments