diff --git a/Capsule Network.ipynb b/Capsule Network.ipynb index f3cd682..2d8efc7 100644 --- a/Capsule Network.ipynb +++ b/Capsule Network.ipynb @@ -73,12 +73,12 @@ "outputs": [], "source": [ "class PrimaryCaps(nn.Module):\n", - " def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):\n", + " def __init__(self, num_conv2D_per_capsule=8, in_channels=256, out_channels=32, kernel_size=9):\n", " super(PrimaryCaps, self).__init__()\n", "\n", " self.capsules = nn.ModuleList([\n", " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) \n", - " for _ in range(num_capsules)])\n", + " for _ in range(num_conv2D_per_capsule)])\n", " \n", " def forward(self, x):\n", " u = [capsule(x) for capsule in self.capsules]\n",