
Commit 5f2bc0c

committed
with assistance from claude (yes it did the einops equation building here), generalize to n-dimensions
1 parent 35bf273 commit 5f2bc0c

File tree

2 files changed: +192 −1 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.7',
+  version = '1.12.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,

vit_pytorch/vit_nd.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
from __future__ import annotations

import torch
from torch import nn
from torch.nn import Module

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def join(arr, delimiter = ' '):
    return delimiter.join(arr)

def ensure_tuple(t, length):
    if isinstance(t, (tuple, list)):
        assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
        return tuple(t)
    return (t,) * length
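
# a quick illustration (values chosen for illustration, not from this commit) of how
# ensure_tuple broadcasts a scalar hyperparameter across all dimensions:
#   ensure_tuple(4, 3)         -> (4, 4, 4)
#   ensure_tuple((2, 4, 4), 3) -> (2, 4, 4)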

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class ViTND(Module):
    def __init__(
        self,
        *,
        ndim: int,
        input_shape: int | tuple[int, ...],
        patch_size: int | tuple[int, ...],
        num_classes: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        pool: str = 'cls',
        channels: int = 3,
        dim_head: int = 64,
        dropout: float = 0.,
        emb_dropout: float = 0.
    ):
        super().__init__()

        assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.ndim = ndim
        self.pool = pool

        input_shape = ensure_tuple(input_shape, ndim)
        patch_size = ensure_tuple(patch_size, ndim)

        for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
            assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'

        num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
        num_patches = 1
        for n in num_patches_per_dim:
            num_patches *= n

        patch_dim = channels
        for p in patch_size:
            patch_dim *= p
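
        # worked example using the shapes in the __main__ demo below:
        # input_shape = (8, 16, 32, 64), patch_size = (2, 4, 4, 8), channels = 3
        #   num_patches = 4 * 4 * 8 * 8     = 1024 patch tokens
        #   patch_dim   = 3 * 2 * 4 * 4 * 8 = 768 values per patch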

        dim_names = 'fghijkl'[:ndim]

        input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
        patch_dims = [f'p{i}' for i in range(ndim)]

        input_pattern = f'b c {join(input_dims)}'
        output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
        rearrange_str = f'{input_pattern} -> {output_pattern}'

        rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
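
        # the generated einops pattern, written out for small ndim:
        #   ndim = 2: 'b c (f p0) (g p1) -> b (f g) (p0 p1 c)'            (standard image ViT patching)
        #   ndim = 3: 'b c (f p0) (g p1) (h p2) -> b (f g h) (p0 p1 p2 c)'
        # each spatial axis is split into a grid coordinate and a within-patch offset; the grid
        # coordinates are flattened into the token axis, and the offsets plus channels are
        # flattened into the per-token feature axis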

        self.to_patch_embedding = nn.Sequential(
            Rearrange(rearrange_str, **rearrange_kwargs),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
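
# shape walk-through for the demo below: the (2, 3, 8, 16, 32, 64) input is patched into
# (2, 1024, 768), projected to (2, 1024, 512), a cls token is prepended to give (2, 1025, 512)
# tokens for the transformer, and the cls token is read out to produce logits of shape (2, 1000)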

if __name__ == '__main__':

    model = ViTND(
        ndim = 4,
        input_shape = (8, 16, 32, 64),
        patch_size = (2, 4, 4, 8),
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        channels = 3,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)

    logits = model(occupancy_time)
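
For reference, the same class covers the familiar low-dimensional cases. Below is a minimal sketch (variable names and input shapes are illustrative, not taken from this commit, and it assumes the package layout introduced here) of a 2D image configuration and a 3D video configuration.

    import torch
    from vit_pytorch.vit_nd import ViTND

    # ndim = 2 reduces to standard image ViT patching
    image_vit = ViTND(
        ndim = 2,
        input_shape = (224, 224),
        patch_size = 16,            # scalar is broadcast to (16, 16) by ensure_tuple
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )
    images = torch.randn(2, 3, 224, 224)
    image_logits = image_vit(images)    # (2, 1000)

    # ndim = 3 handles video as (frames, height, width)
    video_vit = ViTND(
        ndim = 3,
        input_shape = (16, 224, 224),
        patch_size = (2, 16, 16),
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )
    video = torch.randn(2, 3, 16, 224, 224)
    video_logits = video_vit(video)     # (2, 1000)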
