I am trying to use CLAM_SB directly, but load_state_dict fails with missing keys, unexpected keys, and size mismatches. How can I resolve this?
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.model_mil import MIL_fc, MIL_fc_mc
from models.model_clam import CLAM_SB, CLAM_MB
import pdb
import os
import pandas as pd
from utils.utils import *
from utils.core_utils import Accuracy_Logger
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
def initiate_model(args, ckpt_path, device='cuda'):
    print('Init Model')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes, "embed_dim": args.embed_dim}

    if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']:
        model_dict.update({"size_arg": args.model_size})

    if args.model_type == 'clam_sb':
        model = CLAM_SB(**model_dict)
    elif args.model_type == 'clam_mb':
        model = CLAM_MB(**model_dict)
    else:  # args.model_type == 'mil'
        if args.n_classes > 2:
            model = MIL_fc_mc(**model_dict)
        else:
            model = MIL_fc(**model_dict)

    print_network(model)

    # Drop the loss-function buffer and strip any DataParallel '.module'
    # prefix, then load the cleaned state dict strictly.
    ckpt = torch.load(ckpt_path)
    ckpt_clean = {}
    for key in ckpt.keys():
        if 'instance_loss_fn' in key:
            continue
        ckpt_clean.update({key.replace('.module', ''): ckpt[key]})
    model.load_state_dict(ckpt_clean, strict=True)

    _ = model.to(device)
    _ = model.eval()
    return model
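Before calling load_state_dict with strict=True, it can help to dump the checkpoint's keys and tensor shapes so a layout mismatch is visible up front. A minimal sketch, assuming the checkpoint is a plain state dict as saved by the CLAM training code (inspect_checkpoint is a hypothetical helper, not part of CLAM):

import torch

def inspect_checkpoint(ckpt_path):
    # Load on CPU so inspection works without a GPU.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    for key, value in ckpt.items():
        # Tensors print their shape; anything else prints its type.
        shape = tuple(value.shape) if hasattr(value, 'shape') else type(value).__name__
        print(key, shape)

inspect_checkpoint("/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")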
from argparse import Namespace

args = Namespace(
    k=10,
    models_exp_code="task_1_tumor_vs_normal_CLAM_50_s1",
    save_exp_code="task_1_tumor_vs_normal_CLAM_50_s1_cv",
    task="task_1_tumor_vs_normal",
    model_type="clam_sb",
    results_dir="results",
    data_root_dir="DATA_ROOT_DIR",
    drop_out=0.25,
    embed_dim=1024,
    n_classes=2,
    model_size="small"
)
model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")
The full output and traceback:
Init Model
CLAM_SB(
  (attention_net): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_c): Linear(in_features=256, out_features=1, bias=True)
    )
  )
  (classifiers): Linear(in_features=512, out_features=2, bias=True)
  (instance_classifiers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): CrossEntropyLoss()
)
Total number of parameters: 790791
Total number of trainable parameters: 790791
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 1
----> 1 model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")

Cell In[1], line 42
     40             continue
     41         ckpt_clean.update({key.replace('.module', ''): ckpt[key]})
---> 42     model.load_state_dict(ckpt_clean, strict=True)
     44     _ = model.to(device)
     45     _ = model.eval()

File /data/hmaurya/temp_conda_envs/lib/python3.10/site-packages/torch/nn/modules/module.py:2581, in Module.load_state_dict(self, state_dict, strict, assign)
   2573     error_msgs.insert(
   2574         0,
   2575         "Missing key(s) in state_dict: {}. ".format(
   2576             ", ".join(f'"{k}"' for k in missing_keys)
   2577         ),
   2578     )
   2580 if len(error_msgs) > 0:
-> 2581     raise RuntimeError(
   2582         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2583             self.__class__.__name__, "\n\t".join(error_msgs)
   2584         )
   2585     )
   2586 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for CLAM_SB:
	Missing key(s) in state_dict: "classifiers.weight", "classifiers.bias".
	Unexpected key(s) in state_dict: "classifiers.0.weight", "classifiers.0.bias", "classifiers.1.weight", "classifiers.1.bias".
	size mismatch for attention_net.3.attention_c.weight: copying a param with shape torch.Size([2, 256]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for attention_net.3.attention_c.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
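Reading the error messages: the checkpoint stores per-class classifier heads ("classifiers.0.*", "classifiers.1.*") and an attention output layer with one branch per class (torch.Size([2, 256])), whereas the instantiated CLAM_SB has a single classifier and a single attention branch. That layout matches CLAM_MB (multi-branch), so the camelyon_40x_cv checkpoint was most likely trained with model_type="clam_mb". A minimal sketch of the likely fix, assuming the remaining args already match the training configuration:

# Assumption: the checkpoint was trained as CLAM_MB, as suggested by the
# per-class "classifiers.{i}" keys and the two-branch attention head.
args.model_type = "clam_mb"
model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")

If the keys still do not line up after switching model types, comparing the inspect_checkpoint output above against model.state_dict().keys() should show exactly which module names or shapes differ.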