Skip to content

Commit d7efc9d

Browse files
committed
train: load only valid weights
1 parent 8ea390b commit d7efc9d

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

train.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,21 @@ def restore_parts(path, model):
831831
state = torch.load(path)["state_dict"]
832832
model_dict = model.state_dict()
833833
valid_state_dict = {k: v for k, v in state.items() if k in model_dict}
834-
model_dict.update(valid_state_dict)
835-
model.load_state_dict(model_dict)
834+
835+
try:
836+
model_dict.update(valid_state_dict)
837+
model.load_state_dict(model_dict)
838+
except RuntimeError as e:
839+
# there should be invalid size of weight(s), so load them per parameter
840+
print(str(e))
841+
model_dict = model.state_dict()
842+
for k, v in valid_state_dict.items():
843+
model_dict[k] = v
844+
try:
845+
model.load_state_dict(model_dict)
846+
except RuntimeError as e:
847+
print(str(e))
848+
warn("{}: may contain invalid size of weight. skipping...".format(k))
836849

837850

838851
if __name__ == "__main__":

0 commit comments

Comments
 (0)