@@ -177,11 +177,13 @@ def test(model, data_loader, num_train_batches, epoch, writer):
     # Get the reconstructed images of the last batch
     if args.use_reconstruction_loss:
         reconstruction = model.decoder(output, target)
-        image_width = 28  # MNIST digit image width
-        image_height = 28  # MNIST digit image height
-        image_channel = 1  # MNIST digit image channel
+        # Input image size and number of channels.
+        # By default, for MNIST, the image is 28x28 with 1 channel (black/white).
+        image_width = args.input_width
+        image_height = args.input_height
+        image_channel = args.num_conv_in_channel
         recon_img = reconstruction.view(-1, image_channel, image_width, image_height)
-        assert recon_img.size() == torch.Size([batch_size, 1, 28, 28])
+        assert recon_img.size() == torch.Size([batch_size, image_channel, image_width, image_height])

         # Save the image into file system
         utils.save_image(recon_img, 'results/recons_image_test_{}_{}.png'.format(epoch, global_step))
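With the hard-coded MNIST sizes removed, the same reshape serves both datasets. Under a hypothetical CIFAR10 run (--input-width 32 --input-height 32, 3 input channels), the code path above would produce:

    # Hypothetical CIFAR10 shapes, assuming the flags above are set accordingly:
    # reconstruction has shape [batch_size, 3 * 32 * 32]
    recon_img = reconstruction.view(-1, 3, 32, 32)  # -> [batch_size, 3, 32, 32]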
@@ -264,6 +266,11 @@ def main():
                         help='use an additional reconstruction loss. default=True')
     parser.add_argument('--regularization-scale', type=float, default=0.0005,
                         help='regularization coefficient for reconstruction loss. default=0.0005')
+    parser.add_argument('--dataset', help='name of the dataset (mnist, cifar10)', default='mnist')
+    parser.add_argument('--input-width', type=int,
+                        default=28, help='input image width to the convolution. default=28 for MNIST')
+    parser.add_argument('--input-height', type=int,
+                        default=28, help='input image height to the convolution. default=28 for MNIST')

     args = parser.parse_args()

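A usage sketch for the new flags. Note the code above also reads args.num_conv_in_channel, so a matching --num-conv-in-channel argument is assumed to exist elsewhere; it is not added in this hunk:

    # MNIST (defaults)
    python main.py --dataset mnist

    # CIFAR10: 32x32 RGB images
    python main.py --dataset cifar10 --input-width 32 --input-height 32 --num-conv-in-channel 3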
@@ -278,7 +285,7 @@ def main():
         torch.cuda.manual_seed(args.seed)

     # Load data
-    train_loader, test_loader = utils.load_mnist(args)
+    train_loader, test_loader = utils.load_data(args)

     # Build Capsule Network
     print('===> Building model')
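The body of utils.load_data is outside this diff. A minimal sketch of how it might dispatch on args.dataset (the transforms, data path, and batch-size attributes are assumptions):

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    def load_data(args):
        # Hypothetical dispatcher: pick the torchvision dataset class by name.
        transform = transforms.ToTensor()
        if args.dataset == 'mnist':
            dataset_cls = datasets.MNIST
        elif args.dataset == 'cifar10':
            dataset_cls = datasets.CIFAR10
        else:
            raise ValueError('Unsupported dataset: {}'.format(args.dataset))
        train_set = dataset_cls('./data', train=True, download=True, transform=transform)
        test_set = dataset_cls('./data', train=False, download=True, transform=transform)
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)
        return train_loader, test_loader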
@@ -291,6 +298,8 @@ def main():
                 num_routing=args.num_routing,
                 use_reconstruction_loss=args.use_reconstruction_loss,
                 regularization_scale=args.regularization_scale,
+                input_width=args.input_width,
+                input_height=args.input_height,
                 cuda_enabled=args.cuda)

     if args.cuda:
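One plausible reason the network needs these dimensions: the reconstruction decoder's final layer has to match the flattened input image. A sketch under that assumption (the real Decoder is not shown in this diff; the 512/1024 layer sizes follow the CapsNet paper's fully connected decoder):

    import torch.nn as nn

    class Decoder(nn.Module):
        # Hypothetical decoder whose output size is derived from the new arguments.
        def __init__(self, num_classes=10, capsule_dim=16,
                     input_width=28, input_height=28, num_channels=1):
            super(Decoder, self).__init__()
            out_features = input_width * input_height * num_channels
            self.fc = nn.Sequential(
                nn.Linear(num_classes * capsule_dim, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, out_features),
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.fc(x)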
@@ -307,12 +316,14 @@ def main():
     for name, param in model.named_parameters():
         print('{}: {}'.format(name, list(param.size())))

-    # CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnet.
+    # CapsNet has:
+    # - 8.2M parameters and 6.8M parameters without the reconstruction subnet on MNIST.
+    # - 11.8M parameters and 8.0M parameters without the reconstruction subnet on CIFAR10.
     num_params = sum([param.nelement() for param in model.parameters()])

     # The coupling coefficients c_ij are not included in the parameter list,
-    # we need to add them manually, which is 1152 * 10 = 11520.
-    print('\nTotal number of parameters: {}\n'.format(num_params + 11520))
+    # so we need to add them manually: 1152 * 10 = 11520 (MNIST) or 2048 * 10 = 20480 (CIFAR10).
+    print('\nTotal number of parameters: {}\n'.format(num_params + (11520 if args.dataset == 'mnist' else 20480)))

     # Optimizer
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
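The 11520 and 20480 counts follow from the primary-capsule grid size, assuming the standard CapsNet layout (two 9x9 convolutions, the second with stride 2, and 32 primary-capsule maps); the sketch below verifies the arithmetic:

    # Coupling coefficients c_ij = (number of primary capsules) * (number of classes).
    def num_coupling_coefficients(input_size, num_classes=10, capsule_maps=32):
        after_conv1 = input_size - 9 + 1    # 9x9 kernel, stride 1, no padding
        grid = (after_conv1 - 9) // 2 + 1   # 9x9 kernel, stride 2, no padding
        return capsule_maps * grid * grid * num_classes

    assert num_coupling_coefficients(28) == 11520  # MNIST: 32 * 6 * 6 = 1152 capsules
    assert num_coupling_coefficients(32) == 20480  # CIFAR10: 32 * 8 * 8 = 2048 capsules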