diff --git a/vit_pytorch/mae.py b/vit_pytorch/mae.py
index 7b076b53..95783e9f 100644
--- a/vit_pytorch/mae.py
+++ b/vit_pytorch/mae.py
@@ -23,6 +23,7 @@ def __init__(
         # extract some hyperparameters and functions from encoder (vision transformer to be trained)
 
         self.encoder = encoder
+        # Note: this 'num_patches' counts the actual number of patches plus 1 for the cls_token
         num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
         self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
         pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
@@ -32,6 +33,7 @@ def __init__(
         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
         self.mask_token = nn.Parameter(torch.randn(decoder_dim))
         self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
+        # This embedding matrix also accounts for the ViT's cls_token
         self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
         self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
 
@@ -41,11 +43,13 @@ def forward(self, img):
         # get patches
 
         patches = self.to_patch(img)
+        # Note: this 'num_patches' is the actual number of patches
         batch, num_patches, *_ = patches.shape
 
         # patch to encoder tokens and add positions
 
         tokens = self.patch_to_emb(patches)
+        # pos_embedding[:, 0] is for the ViT's cls_token, so we start from index 1 here
         tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
 
         # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
@@ -74,7 +78,9 @@ def forward(self, img):
         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
 
         mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
-        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
+        # Like the encoder position embedding, index 0 is for the cls_token, so we shift by 1 here
+        # mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
+        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices + 1)
 
         # concat the masked tokens to the decoder tokens and attend with decoder
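
A minimal sketch (not part of the patch) of the indexing argument behind the "+ 1" shift: decoder_pos_emb is sized from the encoder's pos_embedding, whose row 0 belongs to the cls_token, while masked_indices are 0-based over the patches only, so patch i should look up row i + 1. The sizes and the way masked_indices is generated below are made up for illustration.

# Minimal sketch with illustrative values, not copied from mae.py.
import torch
import torch.nn as nn

num_patches_with_cls = 65   # e.g. 64 image patches + 1 cls_token slot, matching encoder.pos_embedding
decoder_dim = 8
batch, num_masked = 2, 16

# decoder_pos_emb mirrors the encoder's pos_embedding layout: row 0 is the cls_token slot,
# rows 1..64 correspond to the actual patches.
decoder_pos_emb = nn.Embedding(num_patches_with_cls, decoder_dim)

# masked_indices are 0-based over the patches only (0..63), the same indexing the encoder
# uses when it adds pos_embedding[:, 1:(num_patches + 1)] to the patch tokens.
rand_indices = torch.rand(batch, num_patches_with_cls - 1).argsort(dim = -1)
masked_indices = rand_indices[:, :num_masked]

# Without the shift, patch 0 would reuse the cls_token row; shifting by 1 keeps
# patch i aligned with embedding row i + 1, consistent with the encoder.
mask_pos = decoder_pos_emb(masked_indices + 1)
print(mask_pos.shape)   # torch.Size([2, 16, 8])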