File tree Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -831,8 +831,21 @@ def restore_parts(path, model):
831831 state = torch .load (path )["state_dict" ]
832832 model_dict = model .state_dict ()
833833 valid_state_dict = {k : v for k , v in state .items () if k in model_dict }
834- model_dict .update (valid_state_dict )
835- model .load_state_dict (model_dict )
834+
835+ try :
836+ model_dict .update (valid_state_dict )
837+ model .load_state_dict (model_dict )
838+ except RuntimeError as e :
839+ # there should be invalid size of weight(s), so load them per parameter
840+ print (str (e ))
841+ model_dict = model .state_dict ()
842+ for k , v in valid_state_dict .items ():
843+ model_dict [k ] = v
844+ try :
845+ model .load_state_dict (model_dict )
846+ except RuntimeError as e :
847+ print (str (e ))
848+ warn ("{}: may contain invalid size of weight. skipping..." .format (k ))
836849
837850
838851if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments