Skip to content

Commit 297e7d0

Browse files
committed
handle channel first for accept video wrapper
1 parent 29ac8e1 commit 297e7d0

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.11.4',
9+
version = '1.11.5',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

vit_pytorch/accept_video_wrapper.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
add_time_pos_emb = False,
2626
dim_emb = None,
2727
time_seq_len = None,
28+
embed_is_channel_first = False,
2829
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
2930
):
3031
super().__init__()
@@ -40,6 +41,8 @@ def __init__(
4041

4142
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
4243

44+
self.embed_is_channel_first = embed_is_channel_first
45+
4346
def forward(
4447
self,
4548
video, # (b c t h w)
@@ -89,9 +92,16 @@ def forward(
8992

9093
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
9194

92-
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
95+
one_dims = ((1,) * dims_to_unsqueeze)
96+
97+
if self.embed_is_channel_first:
98+
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
99+
else:
100+
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
101+
102+
pos_emb = pos_emb[:, :embed.shape[1]]
93103

94-
embed = embed + pos_emb[:, :embed.shape[1]]
104+
embed = embed + pos_emb
95105

96106
outputs[self.output_pos_add_pos_emb] = embed
97107

0 commit comments

Comments
 (0)