Skip to content

Commit 0c081fd

Browse files
authored
Fix module name extraction logic in quant_api.py (#3298)
* remove module_name splitting
* update
1 parent 6259e98 commit 0c081fd

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

test/quantization/test_quant_api.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,56 @@ def reset_memory():
11281128
assert param.is_cuda
11291129
self.assertLess(memory_streaming, memory_baseline)
11301130

1131+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
1132+
def test_fqn_config_quantized_nested_module(self):
1133+
class NestedModule(torch.nn.Module):
1134+
def __init__(self):
1135+
super().__init__()
1136+
self.linear = torch.nn.Linear(16, 16)
1137+
1138+
class TopLevelModule(torch.nn.Module):
1139+
def __init__(self):
1140+
super().__init__()
1141+
self.nested = NestedModule()
1142+
self.linear1 = torch.nn.Linear(16, 16)
1143+
1144+
m = TopLevelModule()
1145+
quant_config = FqnToConfig(
1146+
{
1147+
"nested.linear": Int8WeightOnlyConfig(),
1148+
"linear1": Int8WeightOnlyConfig(),
1149+
}
1150+
)
1151+
quantize_(m, quant_config, filter_fn=None)
1152+
1153+
assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
1154+
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
1155+
1156+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
1157+
def test_fqn_config_quantized_nested_module_param(self):
1158+
class NestedModule(torch.nn.Module):
1159+
def __init__(self):
1160+
super().__init__()
1161+
self.linear = torch.nn.Linear(16, 16)
1162+
1163+
class TopLevelModule(torch.nn.Module):
1164+
def __init__(self):
1165+
super().__init__()
1166+
self.nested = NestedModule()
1167+
self.linear1 = torch.nn.Linear(16, 16)
1168+
1169+
m = TopLevelModule()
1170+
quant_config = FqnToConfig(
1171+
{
1172+
"nested.linear.weight": Int8WeightOnlyConfig(),
1173+
"linear1.weight": Int8WeightOnlyConfig(),
1174+
}
1175+
)
1176+
quantize_(m, quant_config, filter_fn=None)
1177+
1178+
assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
1179+
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
1180+
11311181

11321182
if __name__ == "__main__":
11331183
unittest.main()

torchao/quantization/quant_api.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,11 +484,8 @@ def quantize_(
484484
or _module_param_matches_fqn_config(module, module_fqn, config)
485485
or ("_default" in config.fqn_to_config and _is_linear(module))
486486
):
487-
module_name = (
488-
module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
489-
)
490487
# this replaces inplace, so no need to reassign
491-
_fqn_to_config_handler(module, module_name, config)
488+
_fqn_to_config_handler(module, module_fqn, config)
492489
if device is not None:
493490
module.to(device=device)
494491
return

0 commit comments

Comments (0)