11import os
2+ import weakref
23import numpy as np
34import dezero .functions as F
45from dezero import cuda
@@ -18,11 +19,16 @@ def __setattr__(self, name, value):
1819 self ._params .add (name )
1920 super ().__setattr__ (name , value )
2021
21- def __call__ (self , * args , ** kwargs ):
22- return self .forward (* args , ** kwargs )
22+ def __call__ (self , * inputs ):
23+ outputs = self .forward (* inputs )
24+ if not isinstance (outputs , tuple ):
25+ outputs = (outputs ,)
26+ self .inputs = [weakref .ref (x ) for x in inputs ]
27+ self .outputs = [weakref .ref (y ) for y in outputs ]
28+ return outputs if len (outputs ) > 1 else outputs [0 ]
2329
24- def forward (self , * args , ** kwargs ):
25- return self . __call__ ( * args , ** kwargs )
30+ def forward (self , inputs ):
31+ raise NotImplementedError ( )
2632
2733 def params (self ):
2834 for name in self ._params :
@@ -101,7 +107,7 @@ def _init_W(self, xp=np):
101107 W_data = xp .random .randn (I , O ).astype (self .dtype ) * np .sqrt (1 / I )
102108 self .W .data = W_data
103109
104- def __call__ (self , x ):
110+ def forward (self , x ):
105111 if self .W .data is None :
106112 self .in_size = x .shape [1 ]
107113 xp = cuda .get_array_module (x )
@@ -150,7 +156,7 @@ def _init_W(self, xp=np):
150156 W_data = xp .random .randn (OC , C , KH , KW ).astype (self .dtype ) * scale
151157 self .W .data = W_data
152158
153- def __call__ (self , x ):
159+ def forward (self , x ):
154160 if self .W .data is None :
155161 self .in_channels = x .shape [1 ]
156162 xp = cuda .get_array_module (x )
@@ -199,7 +205,7 @@ def _init_W(self, xp=np):
199205 W_data = xp .random .randn (C , OC , KH , KW ).astype (self .dtype ) * scale
200206 self .W .data = W_data
201207
202- def __call__ (self , x ):
208+ def forward (self , x ):
203209 if self .W .data is None :
204210 self .in_channels = x .shape [1 ]
205211 xp = cuda .get_array_module (x )
@@ -231,7 +237,7 @@ def __init__(self, hidden_size, in_size=None):
231237 def reset_state (self ):
232238 self .h = None
233239
234- def __call__ (self , x ):
240+ def forward (self , x ):
235241 if self .h is None :
236242 h_new = F .tanh (self .x2h (x ))
237243 else :
@@ -259,7 +265,7 @@ def reset_state(self):
259265 self .h = None
260266 self .c = None
261267
262- def __call__ (self , x ):
268+ def forward (self , x ):
263269 if self .h is None :
264270 f = F .sigmoid (self .x2f (x ))
265271 i = F .sigmoid (self .x2i (x ))
0 commit comments