File tree Expand file tree Collapse file tree 3 files changed +18
-6
lines changed Expand file tree Collapse file tree 3 files changed +18
-6
lines changed Original file line number Diff line number Diff line change 1+ import sys
12from setuptools import setup , find_packages
23
4+ sys .path [0 :0 ] = ['stylegan2_pytorch' ]
5+ from version import __version__
6+
37setup (
48 name = 'stylegan2_pytorch' ,
59 packages = find_packages (),
812 'stylegan2_pytorch = stylegan2_pytorch.cli:main' ,
913 ],
1014 },
11- version = '1.2.4' ,
15+ version = __version__ ,
1216 license = 'GPLv3+' ,
1317 description = 'StyleGan2 in Pytorch' ,
1418 author = 'Phil Wang' ,
Original file line number Diff line number Diff line change 2727
2828import torchvision
2929from torchvision import transforms
30+ from stylegan2_pytorch .version import __version__
3031from stylegan2_pytorch .diff_augment import DiffAugment
3132
3233from pytorch_fid import fid_score
@@ -1164,7 +1165,10 @@ def clear(self):
11641165 self .init_folders ()
11651166
11661167 def save (self , num ):
1167- save_data = {'GAN' : self .GAN .state_dict ()}
1168+ save_data = {
1169+ 'GAN' : self .GAN .state_dict (),
1170+ 'version' : __version__
1171+ }
11681172
11691173 if self .GAN .fp16 :
11701174 save_data ['amp' ] = amp .state_dict ()
@@ -1188,11 +1192,14 @@ def load(self, num = -1):
11881192
11891193 load_data = torch .load (self .model_name (name ))
11901194
1191- # make backwards compatible
1192- if 'GAN' not in load_data :
1193- load_data = {'GAN' : load_data }
1195+ if 'version' in load_data :
1196+ print (f"loading from version { load_data ['version' ]} " )
11941197
1195- self .GAN .load_state_dict (load_data ['GAN' ])
1198+ try :
1199+ self .GAN .load_state_dict (load_data ['GAN' ])
1200+ except :
1201+ print ('unable to load save model. please try downgrading the package to the version specified by the saved model' )
1202+ exit ()
11961203
11971204 if self .GAN .fp16 and 'amp' in load_data :
11981205 amp .load_state_dict (load_data ['amp' ])
Original file line number Diff line number Diff line change 1+ __version__ = '1.2.5'
You can’t perform that action at this time.
0 commit comments