diff --git a/LICENSE b/LICENSE
index cad43213..6bfefed3 100644
--- a/LICENSE
+++ b/LICENSE
@@ -2,6 +2,8 @@ MIT License
 
 Copyright (c) 2020 Phil Wang
 
+Copyright (c) 2020 Stan Kriventsov
+
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index ae18e0e6..41d5e05b 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -1,12 +1,14 @@
 import torch
 import torch.nn.functional as F
-from einops import rearrange
+
 from torch import nn
+from einops import rearrange
 
 class Residual(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
+
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) + x
 
@@ -15,6 +17,7 @@ def __init__(self, dim, fn):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.fn = fn
+
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 
@@ -26,21 +29,23 @@ def __init__(self, dim, hidden_dim):
             nn.GELU(),
             nn.Linear(hidden_dim, dim)
         )
+
     def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8):
+    def __init__(self, dim, heads=8):
         super().__init__()
         self.heads = heads
         self.scale = dim ** -0.5
 
-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
         self.to_out = nn.Linear(dim, dim)
+
     def forward(self, x, mask = None):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
+        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
 
         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
 
@@ -67,14 +72,15 @@ def __init__(self, dim, depth, heads, mlp_dim):
                 Residual(PreNorm(dim, Attention(dim, heads = heads))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
             ]))
-    def forward(self, x, mask = None):
+
+    def forward(self, x, mask=None):
         for attn, ff in self.layers:
-            x = attn(x, mask = mask)
+            x = attn(x, mask=mask)
             x = ff(x)
         return x
 
 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
         super().__init__()
         assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
         num_patches = (image_size // patch_size) ** 2
@@ -95,7 +101,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             nn.Linear(mlp_dim, num_classes)
         )
 
-    def forward(self, img, mask = None):
+    def forward(self, img, mask=None):
         p = self.patch_size
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
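
Not part of the patch: a minimal usage sketch of the ViT class as it reads after this diff. The hyperparameter values below are illustrative; only the keyword-only argument names, the divisibility assert, and the (batch, channels, height, width) input layout are taken from the code above.

    import torch
    from vit_pytorch.vit_pytorch import ViT

    # Illustrative values; image_size must be divisible by patch_size
    # (enforced by the assert in ViT.__init__).
    model = ViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=8,
        mlp_dim=2048
    )

    img = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
    preds = model(img)                 # class logits of shape (1, num_classes)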