55
66
77# =============================================================================
8- # sin / cos / tanh / exp / log
8+ # Basic functions: sin / cos / tanh / exp / log
99# =============================================================================
1010class Sin (Function ):
1111 def forward (self , x ):
@@ -88,7 +88,7 @@ def log(x):
8888
8989
9090# =============================================================================
91- # Tensor operations: sum / repeat / reshape / sum_to / broadcast_to / get_item
91+ # Tensor operations: reshape / transpose / get_item / expand_dims / flatten
9292# =============================================================================
9393class Reshape (Function ):
9494 def __init__ (self , shape ):
@@ -109,6 +109,65 @@ def reshape(x, shape):
109109 return Reshape (shape )(x )
110110
111111
class Transpose(Function):
    """Permute the axes of a tensor; `axes=None` reverses all axes."""

    def __init__(self, axes=None):
        self.axes = axes

    def forward(self, x):
        return x.transpose(self.axes)

    def backward(self, gy):
        if self.axes is None:
            # A full reversal is its own inverse, so just transpose again.
            return transpose(gy)
        # Build the inverse permutation: forward moved axis `src` to slot
        # `dst`, so the backward pass must move slot `dst` back to `src`.
        # `src % n` normalizes negative axis indices first.
        n = len(self.axes)
        inv = [0] * n
        for dst, src in enumerate(self.axes):
            inv[src % n] = dst
        return transpose(gy, tuple(inv))
127+
128+
def transpose(x, axes=None):
    """Differentiable axis permutation of `x` (see `Transpose`)."""
    f = Transpose(axes)
    return f(x)
131+
132+
class GetItem(Function):
    """Differentiable indexing/slicing (`x[slices]`)."""

    def __init__(self, slices):
        self.slices = slices

    def forward(self, x):
        return x[self.slices]

    def backward(self, gy):
        # Gradients must scatter back into the positions that were read,
        # inside a zero tensor of the original input's shape.
        (x,) = self.inputs
        return GetItemGrad(self.slices, x.shape)(gy)
145+
146+
class GetItemGrad(Function):
    """Backward of `GetItem`: scatter-add the upstream gradient into a
    zero tensor shaped like the original input."""

    def __init__(self, slices, in_shape):
        self.slices = slices
        self.in_shape = in_shape

    def forward(self, gy):
        xp = dezero.cuda.get_array_module(gy)
        gx = xp.zeros(self.in_shape, dtype=gy.dtype)
        # Accumulate rather than assign so duplicate indices sum their
        # gradients; NumPy and the GPU backend expose different APIs for it.
        scatter_add = np.add.at if xp is np else xp.scatter_add
        scatter_add(gx, self.slices, gy)
        return gx

    def backward(self, ggx):
        # Indexing is linear, so the derivative of the scatter is just
        # gathering the same slices again.
        return get_item(ggx, self.slices)
164+
165+
def get_item(x, slices):
    """Differentiable `x[slices]` (see `GetItem`)."""
    return GetItem(slices)(x)
169+
170+
112171def expand_dims (x , axis ):
113172 x = as_variable (x )
114173 shape = list (x .shape )
@@ -120,6 +179,10 @@ def flatten(x):
120179 """Flattens the input. Does not affect the batch size."""
121180 return reshape (x , (x .shape [0 ], - 1 ))
122181
182+
183+ # =============================================================================
184+ # sum / sum_to / broadcast_to / average / matmul / linear
185+ # =============================================================================
123186class Sum (Function ):
124187 def __init__ (self , axis , keepdims ):
125188 self .axis = axis
@@ -228,67 +291,8 @@ def linear_simple(x, W, b=None):
228291 return y
229292
230293
231- class Transpose (Function ):
232- def __init__ (self , axes = None ):
233- self .axes = axes
234-
235- def forward (self , x ):
236- y = x .transpose (self .axes )
237- return y
238-
239- def backward (self , gy ):
240- if self .axes is None :
241- return transpose (gy )
242-
243- axes_len = len (self .axes )
244- inv_axes = tuple (np .argsort ([ax % axes_len for ax in self .axes ]))
245- return transpose (gy , inv_axes )
246-
247-
248- def transpose (x , axes = None ):
249- return Transpose (axes )(x )
250-
251-
252- class GetItem (Function ):
253- def __init__ (self , slices ):
254- self .slices = slices
255-
256- def forward (self , x ):
257- y = x [self .slices ]
258- return y
259-
260- def backward (self , gy ):
261- x , = self .inputs
262- f = GetItemGrad (self .slices , x .shape )
263- return f (gy )
264-
265-
266- class GetItemGrad (Function ):
267- def __init__ (self , slices , in_shape ):
268- self .slices = slices
269- self .in_shape = in_shape
270-
271- def forward (self , gy ):
272- xp = dezero .cuda .get_array_module (gy )
273- gx = xp .zeros (self .in_shape , dtype = gy .dtype )
274-
275- if xp is np :
276- np .add .at (gx , self .slices , gy )
277- else :
278- xp .scatter_add (gx , self .slices , gy )
279- return gx
280-
281- def backward (self , ggx ):
282- return get_item (ggx , self .slices )
283-
284-
285- def get_item (x , slices ):
286- f = GetItem (slices )
287- return f (x )
288-
289-
290294# =============================================================================
291- # activation function
295+ # activation function: sigmoid / relu / softmax / log_softmax / leaky_relu
292296# =============================================================================
293297def sigmoid_simple (x ):
294298 x = as_variable (x )
@@ -398,8 +402,10 @@ def backward(self, gy):
398402
def leaky_relu(x, slope=0.2):
    """Apply LeakyReLU to `x` with the given negative-side `slope`."""
    f = LeakyReLU(slope)
    return f(x)
405+
406+
401407# =============================================================================
402- # loss function
408+ # loss function: mean_squared_error / softmax_cross_entropy / sigmoid_cross_entropy / binary_cross_entropy
403409# =============================================================================
404410def mean_squared_error_simple (x0 , x1 ):
405411 x0 , x1 = as_variable (x0 ), as_variable (x1 )
@@ -487,7 +493,7 @@ def binary_cross_entropy(p, t):
487493
488494
489495# =============================================================================
490- # utility function
496+ # accuracy / dropout / batch_norm / embed_id
491497# =============================================================================
492498def accuracy (y , t ):
493499 """
@@ -501,13 +507,6 @@ def accuracy(y, t):
501507 return Variable (as_array (acc ))
502508
503509
504- # =============================================================================
505- # embed_id / dropout / batch_norm
506- # =============================================================================
507- def embed_id (x , W ):
508- return W [x ]
509-
510-
511510def dropout (x , dropout_ratio = 0.5 ):
512511 x = as_variable (x )
513512
@@ -593,6 +592,10 @@ def batch_nrom(x, gamma, beta, mean, var, decay=0.9, eps=2e-5):
593592 return BatchNorm (mean , var , decay , eps )(x , gamma , beta )
594593
595594
def embed_id(x, W):
    """Look up rows of the embedding matrix `W` by the ids in `x`."""
    y = W[x]
    return y
597+
598+
596599# =============================================================================
597600# max / min / clip
598601# =============================================================================
0 commit comments