@@ -1128,6 +1128,56 @@ def reset_memory():
11281128 assert param .is_cuda
11291129 self .assertLess (memory_streaming , memory_baseline )
11301130
1131+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
1132+ def test_fqn_config_quantized_nested_module (self ):
1133+ class NestedModule (torch .nn .Module ):
1134+ def __init__ (self ):
1135+ super ().__init__ ()
1136+ self .linear = torch .nn .Linear (16 , 16 )
1137+
1138+ class TopLevelModule (torch .nn .Module ):
1139+ def __init__ (self ):
1140+ super ().__init__ ()
1141+ self .nested = NestedModule ()
1142+ self .linear1 = torch .nn .Linear (16 , 16 )
1143+
1144+ m = TopLevelModule ()
1145+ quant_config = FqnToConfig (
1146+ {
1147+ "nested.linear" : Int8WeightOnlyConfig (),
1148+ "linear1" : Int8WeightOnlyConfig (),
1149+ }
1150+ )
1151+ quantize_ (m , quant_config , filter_fn = None )
1152+
1153+ assert isinstance (m .nested .linear .weight , AffineQuantizedTensor )
1154+ assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
1155+
1156+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
1157+ def test_fqn_config_quantized_nested_module_param (self ):
1158+ class NestedModule (torch .nn .Module ):
1159+ def __init__ (self ):
1160+ super ().__init__ ()
1161+ self .linear = torch .nn .Linear (16 , 16 )
1162+
1163+ class TopLevelModule (torch .nn .Module ):
1164+ def __init__ (self ):
1165+ super ().__init__ ()
1166+ self .nested = NestedModule ()
1167+ self .linear1 = torch .nn .Linear (16 , 16 )
1168+
1169+ m = TopLevelModule ()
1170+ quant_config = FqnToConfig (
1171+ {
1172+ "nested.linear.weight" : Int8WeightOnlyConfig (),
1173+ "linear1.weight" : Int8WeightOnlyConfig (),
1174+ }
1175+ )
1176+ quantize_ (m , quant_config , filter_fn = None )
1177+
1178+ assert isinstance (m .nested .linear .weight , AffineQuantizedTensor )
1179+ assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
1180+
11311181
# Entry point: run this test module directly via the unittest CLI runner.
if __name__ == "__main__":
    unittest.main()
0 commit comments