@@ -28,6 +28,20 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x
 
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # MBConv
 
 class SqueezeExcitation(nn.Module):
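The FeedForward added above is shape-preserving: nn.Linear acts only on the last dimension, so the module can be dropped into the channel-last window layout used later in this diff. A minimal smoke test, assuming torch is installed and the FeedForward class from this hunk is in scope (the shapes are illustrative, not the repo's defaults):

import torch

ff = FeedForward(dim = 64, mult = 4, dropout = 0.1)
x = torch.randn(2, 8, 8, 7, 7, 64)  # (b, x, y, w1, w2, d) window layout
assert ff(x).shape == x.shape       # expands to dim * mult internally, then projects back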
@@ -244,10 +258,12 @@ def __init__(
                     ),
                     Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
 
                     Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                 )
 
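This hunk completes each stage into two full transformer blocks: window, attend, feed forward, un-window, first block-wise and then grid-wise, with every sub-layer pre-normed and residual. Below is a minimal sketch of that attention-then-feedforward ordering, with the attention stage stubbed out by nn.Identity so it runs standalone; layer_dim = 64, w = 7, and the 56x56 feature map are illustrative values, and PreNormResidual / FeedForward are assumed to be in scope from this module (with PreNormResidual's norm a LayerNorm over the last dimension, as its forward above implies):

import torch
from torch import nn
from einops.layers.torch import Rearrange

layer_dim, w = 64, 7
block = nn.Sequential(
    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),
    PreNormResidual(layer_dim, nn.Identity()),                 # stand-in for windowed attention
    PreNormResidual(layer_dim, FeedForward(dim = layer_dim)),  # the sub-layer this commit adds
    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
)

x = torch.randn(1, layer_dim, 56, 56)  # 56 = 8 windows of size w = 7
assert block(x).shape == x.shape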