Skip to content

Commit 15418da

Browse files
authored
Built in flatten (#486)
1 parent b78f4ba commit 15418da

File tree

2 files changed

+2
-7
lines changed

2 files changed

+2
-7
lines changed

segmentation_models_pytorch/base/heads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
 import torch.nn as nn
-from .modules import Flatten, Activation
+from .modules import Activation
 
 
 class SegmentationHead(nn.Sequential):
@@ -17,7 +17,7 @@ def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=
         if pooling not in ("max", "avg"):
             raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
         pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1)
-        flatten = Flatten()
+        flatten = nn.Flatten()
         dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
         linear = nn.Linear(in_channels, classes, bias=True)
         activation = Activation(activation)

segmentation_models_pytorch/base/modules.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,3 @@ def __init__(self, name, **params):
 
     def forward(self, x):
         return self.attention(x)
-
-
-class Flatten(nn.Module):
-    def forward(self, x):
-        return x.view(x.shape[0], -1)

0 commit comments

Comments
 (0)