diff --git a/modeling/deeplab.py b/modeling/deeplab.py index 91907f83..57a27886 100644 --- a/modeling/deeplab.py +++ b/modeling/deeplab.py @@ -58,16 +58,16 @@ def get_1x_lr_params(self): def get_10x_lr_params(self): modules = [self.aspp, self.decoder] for i in range(len(modules)): - for m in modules[i].named_modules(): + for m_0, m_1, *m_len in modules[i].named_modules(): if self.freeze_bn: - if isinstance(m[1], nn.Conv2d): - for p in m[1].parameters(): + if isinstance(m_1, nn.Conv2d): + for p in m_1.parameters(): if p.requires_grad: yield p else: - if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ - or isinstance(m[1], nn.BatchNorm2d): - for p in m[1].parameters(): + if isinstance(m_1, nn.Conv2d) or isinstance(m_1, SynchronizedBatchNorm2d) \ + or isinstance(m_1, nn.BatchNorm2d): + for p in m_1.parameters(): if p.requires_grad: yield p