Skip to content

Commit 25ccfcf

Browse files
committed
make pytorch-fid package optional
1 parent 5d9691f commit 25ccfcf

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ Notes:
236236

237237
Thanks to <a href="https://github.com/GetsEclectic">GetsEclectic</a>, you can now calculate the FID score periodically! Again, made super simple with one extra argument, as shown below.
238238

239+
Firstly, install the `pytorch_fid` package
240+
241+
```bash
242+
$ pip install pytorch-fid
243+
```
244+
245+
Followed by
246+
239247
```bash
240248
$ stylegan2_pytorch --data ./data --calculate-fid-every 5000
241249
```

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
'torch',
3333
'torchvision',
3434
'pillow',
35-
'pytorch-fid',
3635
'vector-quantize-pytorch>=0.1.0'
3736
],
3837
classifiers=[

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
from stylegan2_pytorch.version import __version__
3131
from stylegan2_pytorch.diff_augment import DiffAugment
3232

33-
from pytorch_fid import fid_score
34-
3533
from vector_quantize_pytorch import VectorQuantize
3634
from linear_attention_transformer import ImageLinearAttention
3735

@@ -1102,6 +1100,7 @@ def tile(a, dim, n_tile):
11021100

11031101
@torch.no_grad()
11041102
def calculate_fid(self, num_batches):
1103+
from pytorch_fid import fid_score
11051104
torch.cuda.empty_cache()
11061105

11071106
real_path = str(self.results_dir / self.name / 'fid_real') + '/'

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.4'
1+
__version__ = '1.5.5'

0 commit comments

Comments
 (0)