diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e1c6471b17..4d427b9cf0 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -1122,6 +1122,56 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fqn_config_quantized_nested_module(self):
+        class NestedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(16, 16)
+
+        class TopLevelModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = NestedModule()
+                self.linear1 = torch.nn.Linear(16, 16)
+
+        m = TopLevelModule()
+        quant_config = FqnToConfig(
+            {
+                "nested.linear": Int8WeightOnlyConfig(),
+                "linear1": Int8WeightOnlyConfig(),
+            }
+        )
+        quantize_(m, quant_config, filter_fn=None)
+
+        assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_fqn_config_quantized_nested_module_param(self):
+        class NestedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(16, 16)
+
+        class TopLevelModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = NestedModule()
+                self.linear1 = torch.nn.Linear(16, 16)
+
+        m = TopLevelModule()
+        quant_config = FqnToConfig(
+            {
+                "nested.linear.weight": Int8WeightOnlyConfig(),
+                "linear1.weight": Int8WeightOnlyConfig(),
+            }
+        )
+        quantize_(m, quant_config, filter_fn=None)
+
+        assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 4e12e031e2..375aa084af 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -484,11 +484,8 @@ def quantize_(
             or _module_param_matches_fqn_config(module, module_fqn, config)
             or ("_default" in config.fqn_to_config and _is_linear(module))
         ):
-            module_name = (
-                module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
-            )
             # this replaces inplace, so no need to reassign
-            _fqn_to_config_handler(module, module_name, config)
+            _fqn_to_config_handler(module, module_fqn, config)
             if device is not None:
                 module.to(device=device)
             return
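
Context for the quant_api.py change above: _fqn_to_config_handler now receives the full module_fqn rather than the result of the removed rsplit expression, so FqnToConfig keys that name a nested module (e.g. "nested.linear") or one of its parameters (e.g. "nested.linear.weight") resolve correctly. Below is a minimal usage sketch mirroring the new tests; the import path torchao.quantization is an assumption here, so adjust it to wherever FqnToConfig, Int8WeightOnlyConfig, and quantize_ are exported in your checkout.

    import torch
    # Assumed export location for these names; see the imports in test_quant_api.py.
    from torchao.quantization import FqnToConfig, Int8WeightOnlyConfig, quantize_

    class NestedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

    class TopLevelModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.nested = NestedModule()
            self.linear1 = torch.nn.Linear(16, 16)

    m = TopLevelModule()
    # Keys are fully qualified names relative to the root module; both module FQNs
    # ("nested.linear") and parameter FQNs ("nested.linear.weight") are accepted.
    config = FqnToConfig({"nested.linear": Int8WeightOnlyConfig()})
    quantize_(m, config, filter_fn=None)  # filter_fn=None, as in the new tests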