
size mismatch while using CLAM_SB #306

@Himanshunitrr

Description

I am trying to use CLAM_SB directly, but when loading the checkpoint I get missing keys, unexpected keys, and size mismatches. How can I resolve this?

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.model_mil import MIL_fc, MIL_fc_mc
from models.model_clam import CLAM_SB, CLAM_MB
import pdb
import os
import pandas as pd
from utils.utils import *
from utils.core_utils import Accuracy_Logger
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

def initiate_model(args, ckpt_path, device='cuda'):
    print('Init Model')    
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes, "embed_dim": args.embed_dim}
    
    if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']:
        model_dict.update({"size_arg": args.model_size})
    
    if args.model_type =='clam_sb':
        model = CLAM_SB(**model_dict)
    elif args.model_type =='clam_mb':
        model = CLAM_MB(**model_dict)
    else: # args.model_type == 'mil'
        if args.n_classes > 2:
            model = MIL_fc_mc(**model_dict)
        else:
            model = MIL_fc(**model_dict)

    print_network(model)

    ckpt = torch.load(ckpt_path)
    ckpt_clean = {}
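    # Drop stored instance_loss_fn entries and strip '.module' from the remaining key names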
    for key in ckpt.keys():
        if 'instance_loss_fn' in key:
            continue
        ckpt_clean.update({key.replace('.module', ''):ckpt[key]})
    model.load_state_dict(ckpt_clean, strict=True)

    _ = model.to(device)
    _ = model.eval()
    return model

from argparse import Namespace

args = Namespace(
    k=10,
    models_exp_code="task_1_tumor_vs_normal_CLAM_50_s1",
    save_exp_code="task_1_tumor_vs_normal_CLAM_50_s1_cv",
    task="task_1_tumor_vs_normal",
    model_type="clam_sb",
    results_dir="results",
    data_root_dir="DATA_ROOT_DIR",
    drop_out=0.25,
    embed_dim=1024,
    n_classes=2,
    model_size="small"
)

model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")
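
For reference, the raw contents of the checkpoint can be dumped before attempting the strict load (a minimal diagnostic sketch using plain PyTorch; it assumes the file is an ordinary state_dict saved with torch.save(model.state_dict(), ...)):

ckpt = torch.load(
    "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt",
    map_location="cpu",
)
for key, value in ckpt.items():
    # print each saved key with its tensor shape for comparison against the model above
    print(key, tuple(value.shape) if hasattr(value, "shape") else type(value))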

Running this produces the following error:


Init Model
CLAM_SB(
  (attention_net): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_c): Linear(in_features=256, out_features=1, bias=True)
    )
  )
  (classifiers): Linear(in_features=512, out_features=2, bias=True)
  (instance_classifiers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): CrossEntropyLoss()
)
Total number of parameters: 790791
Total number of trainable parameters: 790791
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 1
----> 1 model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")

Cell In[1], line 42
     40             continue
     41         ckpt_clean.update({key.replace('.module', ''):ckpt[key]})
---> 42     model.load_state_dict(ckpt_clean, strict=True)
     44     _ = model.to(device)
     45     _ = model.eval()

File /data/hmaurya/temp_conda_envs/lib/python3.10/site-packages/torch/nn/modules/module.py:2581, in Module.load_state_dict(self, state_dict, strict, assign)
   2573         error_msgs.insert(
   2574             0,
   2575             "Missing key(s) in state_dict: {}. ".format(
   2576                 ", ".join(f'"{k}"' for k in missing_keys)
   2577             ),
   2578         )
   2580 if len(error_msgs) > 0:
-> 2581     raise RuntimeError(
   2582         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2583             self.__class__.__name__, "\n\t".join(error_msgs)
   2584         )
   2585     )
   2586 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for CLAM_SB:
	Missing key(s) in state_dict: "classifiers.weight", "classifiers.bias". 
	Unexpected key(s) in state_dict: "classifiers.0.weight", "classifiers.0.bias", "classifiers.1.weight", "classifiers.1.bias". 
	size mismatch for attention_net.3.attention_c.weight: copying a param with shape torch.Size([2, 256]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for attention_net.3.attention_c.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
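
For context, this mismatch pattern is what you would see if the checkpoint had been trained with CLAM_MB rather than CLAM_SB: CLAM_MB uses one gated-attention branch per class (so attention_c has out_features = n_classes = 2) and a ModuleList of per-class classifiers (which produces the classifiers.0.* and classifiers.1.* keys), while CLAM_SB expects a single attention score and a single classifiers Linear. If that is the case here, a minimal sketch of the fix (reusing the args and initiate_model defined above) would be:

# Hypothetical fix, assuming the checkpoint was actually trained as the
# multi-branch model: load it with the matching architecture instead.
args.model_type = "clam_mb"
model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")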
