@@ -16,7 +16,7 @@
 import json
 import time
 from functools import lru_cache
-from typing import Optional, Union
+from typing import Iterable, Optional, Union
 
 import torch
 
@@ -39,6 +39,7 @@ def _is_auto_round_available():
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
 from auto_round.schemes import QuantizationScheme
 
+from neural_compressor.common.utils import Statistics
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_accelerator, logger
 
@@ -104,6 +105,14 @@ def __init__(
         guidance_scale: float = 7.5,
         num_inference_steps: int = 50,
         generator_seed: int = None,
+        # 0.9
+        target_bits: Optional[int] = None,
+        options: Union[str, list[str], tuple[str, ...]] = ("MXFP4", "MXFP8"),
+        shared_layers: Optional[Iterable[Iterable[str]]] = None,
+        ignore_scale_zp_bits: bool = False,
+        auto_scheme_method: str = "default",
+        auto_scheme_batch_size: Optional[int] = None,
+        auto_scheme_device_map: Optional[str] = None,
         **kwargs,
     ):
         """Init a AutQRoundQuantizer object.
@@ -238,6 +247,13 @@ def __init__(
         self.guidance_scale = guidance_scale
         self.num_inference_steps = num_inference_steps
         self.generator_seed = generator_seed
+        self.target_bits = target_bits
+        self.options = options
+        self.shared_layers = shared_layers
+        self.ignore_scale_zp_bits = ignore_scale_zp_bits
+        self.auto_scheme_method = auto_scheme_method
+        self.auto_scheme_batch_size = auto_scheme_batch_size
+        self.auto_scheme_device_map = auto_scheme_device_map
 
     def _is_w4afp8(self) -> bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
@@ -273,6 +289,19 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             model = model.orig_model
         if pipe is not None:
             model = pipe
+        if self.target_bits is not None:
+            from auto_round import AutoScheme
+
+            self.scheme = AutoScheme(
+                avg_bits=self.target_bits,
+                options=self.options,
+                shared_layers=self.shared_layers,
+                ignore_scale_zp_bits=self.ignore_scale_zp_bits,
+                method=self.auto_scheme_method,
+                batch_size=self.auto_scheme_batch_size,
+                device_map=self.auto_scheme_device_map,
+                low_gpu_mem_usage=self.low_gpu_mem_usage,
+            )
         rounder = AutoRound(
             model,
             layer_config=self.layer_config,
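
For orientation, a minimal driver sketch of the new auto-scheme knobs (the AutoRoundQuantizer class name and the quant_config argument are assumed from the surrounding module rather than shown in this hunk, and the bit budget is illustrative):

    quantizer = AutoRoundQuantizer(
        quant_config=quant_config,      # existing weight-only quant config (assumed argument)
        target_bits=6,                  # average-bits budget; any non-None value enables the AutoScheme branch above
        options=("MXFP4", "MXFP8"),     # candidate schemes AutoScheme may assign per layer
        auto_scheme_method="default",
    )
    model = quantizer.convert(model)    # convert() builds the AutoScheme shown above before running AutoRound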
@@ -338,6 +367,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True)
         model = rounder.model
         model.autoround_config = rounder.layer_config
+
+        dump_model_op_stats(rounder.layer_config)
+
         return model
 
 
@@ -452,3 +484,28 @@ def get_mllm_dataloader(
         quant_nontext_module=quant_nontext_module,
     )
     return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen, nsamples
+
+
+def dump_model_op_stats(layer_config):
+    """Dump quantizable ops stats of model to user."""
+    # TODO: collect more ops besides Linear
+    res = {}
+    res["Linear"] = {}
+    for name, info in layer_config.items():
+        if "data_type" in info:
+            data_type_str = info["data_type"].upper()
+            if "bits" in info and str(info["bits"]) not in info["data_type"]:
+                data_type_str += str(info["bits"])
+            res["Linear"][data_type_str] = res.get("Linear", {}).get(data_type_str, 0) + 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    dtype_list = list(res["Linear"].keys())
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
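
A quick, self-contained illustration of dump_model_op_stats (the layer names, data types, and counts below are made up for the example, not taken from a real run):

    layer_config = {
        "model.layers.0.self_attn.q_proj": {"data_type": "mx_fp", "bits": 4},
        "model.layers.0.mlp.down_proj": {"data_type": "mx_fp", "bits": 8},
        "model.layers.1.self_attn.q_proj": {"data_type": "mx_fp", "bits": 4},
    }
    dump_model_op_stats(layer_config)
    # Prints a "Mixed Precision Statistics" table via Statistics.print_stat(), roughly:
    # Op Type | Total | MX_FP4 | MX_FP8
    # Linear  |   3   |   2    |   1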