Skip to content

Commit 311a4ba

Browse files
committed
Modified implementation to support the CIFAR10 dataset
1 parent d399766 commit 311a4ba

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

decoder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,22 @@ class Decoder(nn.Module):
2727
This Decoder network is used in training and prediction (testing).
2828
"""
2929

30-
def __init__(self, num_classes, output_unit_size, cuda_enabled):
30+
def __init__(self, num_classes, output_unit_size, input_width,
31+
input_height, num_conv_in_channel, cuda_enabled):
3132
"""
3233
The decoder network consists of 3 fully connected layers, with
33-
512, 1024, 784 neurons each.
34+
512, 1024, 784 (or 3072 for CIFAR10) neurons each.
3435
"""
3536
super(Decoder, self).__init__()
3637

3738
self.cuda_enabled = cuda_enabled
3839

3940
fc1_output_size = 512
4041
fc2_output_size = 1024
42+
self.fc3_output_size = input_width * input_height * num_conv_in_channel
4143
self.fc1 = nn.Linear(num_classes * output_unit_size, fc1_output_size) # input dim 10 * 16.
4244
self.fc2 = nn.Linear(fc1_output_size, fc2_output_size)
43-
self.fc3 = nn.Linear(fc2_output_size, 784)
45+
self.fc3 = nn.Linear(fc2_output_size, self.fc3_output_size)
4446
# Activation functions
4547
self.relu = nn.ReLU(inplace=True)
4648
self.sigmoid = nn.Sigmoid()
@@ -49,14 +51,14 @@ def forward(self, x, target):
4951
"""
5052
We send the outputs of the `DigitCaps` layer, which is a
5153
[batch_size, 10, 16] size tensor into the Decoder network, and
52-
reconstruct a [batch_size, 784] size tensor representing the image.
54+
reconstruct a [batch_size, fc3_output_size] size tensor representing the image.
5355
5456
Args:
5557
x: [batch_size, 10, 16] The output of the digit capsule.
5658
target: [batch_size, 10] One-hot MNIST dataset labels.
5759
5860
Returns:
59-
reconstruction: [batch_size, 784] Tensor of reconstructed images.
61+
reconstruction: [batch_size, fc3_output_size] Tensor of reconstructed images.
6062
"""
6163
batch_size = target.size(0)
6264

@@ -77,8 +79,8 @@ def forward(self, x, target):
7779
# Forward pass of the network
7880
fc1_out = self.relu(self.fc1(vector_j))
7981
fc2_out = self.relu(self.fc2(fc1_out)) # shape: [batch_size, 1024]
80-
reconstruction = self.sigmoid(self.fc3(fc2_out)) # shape: [batch_size, 784]
82+
reconstruction = self.sigmoid(self.fc3(fc2_out)) # shape: [batch_size, fc3_output_size]
8183

82-
assert reconstruction.size() == torch.Size([batch_size, 784])
84+
assert reconstruction.size() == torch.Size([batch_size, self.fc3_output_size])
8385

8486
return reconstruction

main.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,13 @@ def test(model, data_loader, num_train_batches, epoch, writer):
177177
# Get the reconstructed images of the last batch
178178
if args.use_reconstruction_loss:
179179
reconstruction = model.decoder(output, target)
180-
image_width = 28 # MNIST digit image width
181-
image_height = 28 # MNIST digit image height
182-
image_channel = 1 # MNIST digit image channel
180+
# Input image size and number of channels.
181+
# By default (MNIST), the image width and height are 28x28, with 1 channel for black/white.
182+
image_width = args.input_width
183+
image_height = args.input_height
184+
image_channel = args.num_conv_in_channel
183185
recon_img = reconstruction.view(-1, image_channel, image_width, image_height)
184-
assert recon_img.size() == torch.Size([batch_size, 1, 28, 28])
186+
assert recon_img.size() == torch.Size([batch_size, image_channel, image_width, image_height])
185187

186188
# Save the image into file system
187189
utils.save_image(recon_img, 'results/recons_image_test_{}_{}.png'.format(epoch, global_step))
@@ -264,6 +266,11 @@ def main():
264266
help='use an additional reconstruction loss. default=True')
265267
parser.add_argument('--regularization-scale', type=float, default=0.0005,
266268
help='regularization coefficient for reconstruction loss. default=0.0005')
269+
parser.add_argument('--dataset', help='the name of dataset (mnist, cifar10)', default='mnist')
270+
parser.add_argument('--input-width', type=int,
271+
default=28, help='input image width to the convolution. default=28 for MNIST')
272+
parser.add_argument('--input-height', type=int,
273+
default=28, help='input image height to the convolution. default=28 for MNIST')
267274

268275
args = parser.parse_args()
269276

@@ -278,7 +285,7 @@ def main():
278285
torch.cuda.manual_seed(args.seed)
279286

280287
# Load data
281-
train_loader, test_loader = utils.load_mnist(args)
288+
train_loader, test_loader = utils.load_data(args)
282289

283290
# Build Capsule Network
284291
print('===> Building model')
@@ -291,6 +298,8 @@ def main():
291298
num_routing=args.num_routing,
292299
use_reconstruction_loss=args.use_reconstruction_loss,
293300
regularization_scale=args.regularization_scale,
301+
input_width=args.input_width,
302+
input_height=args.input_height,
294303
cuda_enabled=args.cuda)
295304

296305
if args.cuda:
@@ -307,12 +316,14 @@ def main():
307316
for name, param in model.named_parameters():
308317
print('{}: {}'.format(name, list(param.size())))
309318

310-
# CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnet.
319+
# CapsNet has:
320+
# - 8.2M parameters and 6.8M parameters without the reconstruction subnet on MNIST.
321+
# - 11.8M parameters and 8.0M parameters without the reconstruction subnet on CIFAR10.
311322
num_params = sum([param.nelement() for param in model.parameters()])
312323

313324
# The coupling coefficients c_ij are not included in the parameter list,
314-
# we need to add them manually, which is 1152 * 10 = 11520.
315-
print('\nTotal number of parameters: {}\n'.format(num_params + 11520))
325+
# we need to add them manually: 1152 * 10 = 11520 (on MNIST) or 2048 * 10 = 20480 (on CIFAR10).
326+
print('\nTotal number of parameters: {}\n'.format(num_params + (11520 if args.dataset == 'mnist' else 20480)))
316327

317328
# Optimizer
318329
optimizer = optim.Adam(model.parameters(), lr=args.lr)

model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class Net(nn.Module):
2121
A simple CapsNet with 3 layers
2222
"""
2323

24-
def __init__(self, num_conv_in_channel, num_conv_out_channel, num_primary_unit, primary_unit_size,
25-
num_classes, output_unit_size, num_routing,
26-
use_reconstruction_loss, regularization_scale, cuda_enabled):
24+
def __init__(self, num_conv_in_channel, num_conv_out_channel, num_primary_unit,
25+
primary_unit_size, num_classes, output_unit_size, num_routing,
26+
use_reconstruction_loss, regularization_scale, input_width, input_height,
27+
cuda_enabled):
2728
"""
2829
In the constructor we instantiate one ConvLayer module and two CapsuleLayer modules
2930
and assign them as member variables.
@@ -34,9 +35,12 @@ def __init__(self, num_conv_in_channel, num_conv_out_channel, num_primary_unit,
3435

3536
# Configurations used for image reconstruction.
3637
self.use_reconstruction_loss = use_reconstruction_loss
37-
self.image_width = 28 # MNIST digit image width
38-
self.image_height = 28 # MNIST digit image height
39-
self.image_channel = 1 # MNIST digit image channel
38+
# Input image size and number of channels.
39+
# By default, for MNIST, the image width and height is 28x28
40+
# and 1 channel for black/white.
41+
self.image_width = input_width
42+
self.image_height = input_height
43+
self.image_channel = num_conv_in_channel
4044

4145
# Also known as lambda reconstruction. Default value is 0.0005.
4246
# We use sum of squared errors (SSE) similar to paper.
@@ -69,7 +73,8 @@ def __init__(self, num_conv_in_channel, num_conv_out_channel, num_primary_unit,
6973

7074
# Reconstruction network
7175
if use_reconstruction_loss:
72-
self.decoder = Decoder(num_classes, output_unit_size, cuda_enabled)
76+
self.decoder = Decoder(num_classes, output_unit_size, input_width,
77+
input_height, num_conv_in_channel, cuda_enabled)
7378

7479
def forward(self, x):
7580
"""

0 commit comments

Comments
 (0)