examples/dcgan.py: 18 changes (8 additions, 10 deletions)
@@ -30,9 +30,9 @@
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
 import torchvision.utils as vutils
-from opacus import PrivacyEngine
 from tqdm import tqdm
 
+from opacus import PrivacyEngine
 
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument("--data-root", required=True, help="path to dataset")
@@ -298,21 +298,19 @@ def forward(self, input):
 fake = netG(noise)
 label_fake = torch.full((batch_size,), FAKE_LABEL, device=device)
 output = netD(fake.detach())
+D_G_z1 = output.mean().item()
 errD_fake = criterion(output, label_fake)
-errD_fake.backward()
-optimizerD.step()
-optimizerD.zero_grad()
 
 # train with real
 label_true = torch.full((batch_size,), REAL_LABEL, device=device)
 output = netD(real_data)
-errD_real = criterion(output, label_true)
-errD_real.backward()
-optimizerD.step()
 D_x = output.mean().item()
+errD_real = criterion(output, label_true)
 
-D_G_z1 = output.mean().item()
-errD = errD_real + errD_fake
+# Note that we clip the gradient for not only real but also fake data.
+errD = errD_fake + errD_real
+errD.backward()
+optimizerD.step()
 
 ############################
 # (2) Update G network: maximize log(D(G(z)))
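The comment added in this hunk carries the key idea: the fake and real discriminator losses are summed and backpropagated once, and the optimizer steps once, instead of stepping separately on each branch, so the per-sample gradient clipping covers the contributions from both batches. Below is a minimal, self-contained sketch of that pattern, assuming the Opacus 1.x `make_private` API (the example itself may configure `PrivacyEngine` differently); the toy discriminator, generator, and data loader are illustrative stand-ins, not the example's `netD`, `netG`, or dataset.

```python
# Sketch of the discriminator update pattern adopted above: no backward()/step()
# per branch; the fake and real losses are summed and backpropagated once, so the
# clipped, noised update covers both contributions.
# Assumption: Opacus 1.x make_private API; toy modules and data, not the example's code.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngine

REAL_LABEL, FAKE_LABEL = 1.0, 0.0
nz = 4  # toy latent size

netD = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
netG = nn.Sequential(nn.Linear(nz, 8))
criterion = nn.BCELoss()
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.1)

loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
privacy_engine = PrivacyEngine()
netD, optimizerD, loader = privacy_engine.make_private(
    module=netD,
    optimizer=optimizerD,
    data_loader=loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

for (real_data,) in loader:
    if real_data.size(0) == 0:  # Poisson sampling can yield empty batches
        continue
    batch_size = real_data.size(0)
    optimizerD.zero_grad()

    # train with fake -- compute the loss only, no backward/step yet
    fake = netG(torch.randn(batch_size, nz))
    output = netD(fake.detach()).view(-1)
    D_G_z1 = output.mean().item()
    errD_fake = criterion(output, torch.full((batch_size,), FAKE_LABEL))

    # train with real -- still no backward/step
    output = netD(real_data).view(-1)
    D_x = output.mean().item()
    errD_real = criterion(output, torch.full((batch_size,), REAL_LABEL))

    # one backward + one step: clipping sees fake and real gradients together
    errD = errD_fake + errD_real
    errD.backward()
    optimizerD.step()
    print(f"errD: {errD.item():.3f}, D(x): {D_x:.3f}, D(G(z)): {D_G_z1:.3f}")
```

With the combined form there is exactly one clipped, noised update of the discriminator per batch, which is what the added comment asks for.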
@@ -324,7 +322,7 @@ def forward(self, input):
 output_g = netD(fake)
 errG = criterion(output_g, label_g)
 errG.backward()
-D_G_z2 = output.mean().item()
+D_G_z2 = output_g.mean().item()
 optimizerG.step()
 data_bar.set_description(
 f"epoch: {epoch}, Loss_D: {errD.item()} "