import torch.nn as nn
from fastNLP.modules.utils import initial_parameter
class MLP(nn.Module):
-    def __init__(self, size_layer, num_class=2, activation='relu', initial_method=None):
+    def __init__(self, size_layer, activation='relu', initial_method=None):
        """Multilayer Perceptrons as a decoder

-        Args:
-            size_layer: list of int, define the size of MLP layers
-            num_class: int, num of class in output, should be 2 or the last layer's size
-            activation: str or function, the activation function for hidden layers
+        :param size_layer: list of int, define the size of MLP layers
+        :param activation: str or function, the activation function for hidden layers
+
+        .. note::
+            There is no activation function applying on output layer.
+
        """
        super(MLP, self).__init__()
        self.hiddens = nn.ModuleList()
@@ -19,13 +21,6 @@ def __init__(self, size_layer, num_class=2, activation='relu', initial_method=
            else:
                self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))

-        if num_class == 2:
-            self.out_active = nn.LogSigmoid()
-        elif num_class == size_layer[-1]:
-            self.out_active = nn.LogSoftmax(dim=1)
-        else:
-            raise ValueError("should set output num_class correctly: {}".format(num_class))
-
        actives = {
            'relu': nn.ReLU(),
            'tanh': nn.Tanh()
@@ -37,17 +32,18 @@ def __init__(self, size_layer, num_class=2, activation='relu', initial_method=
        else:
            raise ValueError("should set activation correctly: {}".format(activation))
        initial_parameter(self, initial_method)
+
    def forward(self, x):
        for layer in self.hiddens:
            x = self.hidden_active(layer(x))
-        x = self.out_active(self.output(x))
+        x = self.output(x)
        return x



if __name__ == '__main__':
    net1 = MLP([5, 10, 5])
-    net2 = MLP([5, 10, 5], 5)
+    net2 = MLP([5, 10, 5], 'tanh')
    for net in [net1, net2]:
        x = torch.randn(5, 5)
        y = net(x)
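With the output activation removed, MLP.forward now returns the raw output of the final Linear layer, so the caller decides how (or whether) to normalize it. Below is a minimal usage sketch, assuming the MLP class from the diff above is in scope; the layer sizes, batch size, and target labels are arbitrary illustration, not taken from the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical example: MLP is the class shown in the diff above.
model = MLP([5, 10, 3], activation='relu')
x = torch.randn(4, 5)               # batch of 4 samples with 5 features each
logits = model(x)                   # shape (4, 3); no activation on the output layer

# Apply log-softmax explicitly when log-probabilities are needed ...
log_probs = F.log_softmax(logits, dim=1)
# ... or hand the raw logits to a loss that normalizes internally.
targets = torch.tensor([0, 2, 1, 0])
loss = nn.CrossEntropyLoss()(logits, targets)

This also explains the change to the second example in __main__: with num_class gone, activation is now the second positional parameter of __init__, so the call passes 'tanh' instead of an output size.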