File tree Expand file tree Collapse file tree 2 files changed +26
-4
lines changed Expand file tree Collapse file tree 2 files changed +26
-4
lines changed Original file line number Diff line number Diff line change @@ -1122,6 +1122,31 @@ def reset_memory():
11221122 assert param .is_cuda
11231123 self .assertLess (memory_streaming , memory_baseline )
11241124
1125+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
1126+ def test_quantized_nested_module (self ):
1127+ class NestedModule (torch .nn .Module ):
1128+ def __init__ (self ):
1129+ super ().__init__ ()
1130+ self .linear = torch .nn .Linear (16 , 16 )
1131+
1132+ class TopLevelModule (torch .nn .Module ):
1133+ def __init__ (self ):
1134+ super ().__init__ ()
1135+ self .nested = NestedModule ()
1136+ self .linear1 = torch .nn .Linear (16 , 16 )
1137+
1138+ m = TopLevelModule ()
1139+ quant_config = FqnToConfig (
1140+ {
1141+ "nested.linear" : Int8WeightOnlyConfig (),
1142+ "linear1" : Int8WeightOnlyConfig (),
1143+ }
1144+ )
1145+ quantize_ (m , quant_config , filter_fn = None )
1146+
1147+ assert isinstance (m .nested .linear .weight , AffineQuantizedTensor )
1148+ assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
1149+
11251150
11261151if __name__ == "__main__" :
11271152 unittest .main ()
Original file line number Diff line number Diff line change @@ -484,11 +484,8 @@ def quantize_(
484484 or _module_param_matches_fqn_config (module , module_fqn , config )
485485 or ("_default" in config .fqn_to_config and _is_linear (module ))
486486 ):
487- module_name = (
488- module_fqn .rsplit ("." , 1 ) if "." in module_fqn else module_fqn
489- )
490487 # this replaces inplace, so no need to reassign
491- _fqn_to_config_handler (module , module_name , config )
488+ _fqn_to_config_handler (module , module_fqn , config )
492489 if device is not None :
493490 module .to (device = device )
494491 return
You can’t perform that action at this time.
0 commit comments