@@ -25,6 +25,7 @@ def __init__(
2525 add_time_pos_emb = False ,
2626 dim_emb = None ,
2727 time_seq_len = None ,
28+ embed_is_channel_first = False ,
2829 output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
2930 ):
3031 super ().__init__ ()
@@ -40,6 +41,8 @@ def __init__(
4041
4142 self .pos_emb = Parameter (randn (time_seq_len , dim_emb ) * 1e-2 )
4243
44+ self .embed_is_channel_first = embed_is_channel_first
45+
4346 def forward (
4447 self ,
4548 video , # (b c t h w)
@@ -89,9 +92,16 @@ def forward(
8992
9093 dims_to_unsqueeze = embed .ndim - pos_emb .ndim
9194
92- pos_emb = pos_emb .reshape (* pos_emb .shape [:2 ], * ((1 ,) * dims_to_unsqueeze ) , pos_emb .shape [- 1 ])
95+ one_dims = ((1 ,) * dims_to_unsqueeze )
96+
97+ if self .embed_is_channel_first :
98+ pos_emb = pos_emb .reshape (* pos_emb .shape , * one_dims )
99+ else :
100+ pos_emb = pos_emb .reshape (* pos_emb .shape [:2 ], * one_dims , pos_emb .shape [- 1 ])
101+
102+ pos_emb = pos_emb [:, :embed .shape [1 ]]
93103
94- embed = embed + pos_emb [:, : embed . shape [ 1 ]]
104+ embed = embed + pos_emb
95105
96106 outputs [self .output_pos_add_pos_emb ] = embed
97107
0 commit comments