Skip to content

Commit c6b4b76

Browse files
committed
save version of package along with model to allow users to downgrade the package if there is an incompatibility
1 parent 309d095 commit c6b4b76

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import sys
12
from setuptools import setup, find_packages
23

4+
sys.path[0:0] = ['stylegan2_pytorch']
5+
from version import __version__
6+
37
setup(
48
name = 'stylegan2_pytorch',
59
packages = find_packages(),
@@ -8,7 +12,7 @@
812
'stylegan2_pytorch = stylegan2_pytorch.cli:main',
913
],
1014
},
11-
version = '1.2.4',
15+
version = __version__,
1216
license='GPLv3+',
1317
description = 'StyleGan2 in Pytorch',
1418
author = 'Phil Wang',

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import torchvision
2929
from torchvision import transforms
30+
from stylegan2_pytorch.version import __version__
3031
from stylegan2_pytorch.diff_augment import DiffAugment
3132

3233
from pytorch_fid import fid_score
@@ -1164,7 +1165,10 @@ def clear(self):
11641165
self.init_folders()
11651166

11661167
def save(self, num):
1167-
save_data = {'GAN': self.GAN.state_dict()}
1168+
save_data = {
1169+
'GAN': self.GAN.state_dict(),
1170+
'version': __version__
1171+
}
11681172

11691173
if self.GAN.fp16:
11701174
save_data['amp'] = amp.state_dict()
@@ -1188,11 +1192,14 @@ def load(self, num = -1):
11881192

11891193
load_data = torch.load(self.model_name(name))
11901194

1191-
# make backwards compatible
1192-
if 'GAN' not in load_data:
1193-
load_data = {'GAN': load_data}
1195+
if 'version' in load_data:
1196+
print(f"loading from version {load_data['version']}")
11941197

1195-
self.GAN.load_state_dict(load_data['GAN'])
1198+
try:
1199+
self.GAN.load_state_dict(load_data['GAN'])
1200+
except:
1201+
print('unable to load save model. please try downgrading the package to the version specified by the saved model')
1202+
exit()
11961203

11971204
if self.GAN.fp16 and 'amp' in load_data:
11981205
amp.load_state_dict(load_data['amp'])

stylegan2_pytorch/version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = '1.2.5'

0 commit comments

Comments
 (0)