diff --git a/WGAN.py b/WGAN.py index 08e1b0a..4f1add3 100644 --- a/WGAN.py +++ b/WGAN.py @@ -141,7 +141,9 @@ def train(self): D_real = self.D(x_) D_real_loss = -torch.mean(D_real) - G_ = self.G(z_) + with torch.no_grad(): + G_ = self.G(z_) + D_fake = self.D(G_) D_fake_loss = torch.mean(D_fake)