Skip to content

Commit b8b4298

Browse files
committed
Just move some comments
1 parent cdded76 commit b8b4298

File tree

1 file changed

+74
-71
lines changed

1 file changed

+74
-71
lines changed

dezero/functions.py

Lines changed: 74 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# =============================================================================
8-
# sin / cos / tanh / exp / log
8+
# Basic functions: sin / cos / tanh / exp / log
99
# =============================================================================
1010
class Sin(Function):
1111
def forward(self, x):
@@ -88,7 +88,7 @@ def log(x):
8888

8989

9090
# =============================================================================
91-
# Tensor operations: sum / repeat / reshape / sum_to / broadcast_to / get_item
91+
# Tensor operations: reshape / transpose / get_item / expand_dims / flatten
9292
# =============================================================================
9393
class Reshape(Function):
9494
def __init__(self, shape):
@@ -109,6 +109,65 @@ def reshape(x, shape):
109109
return Reshape(shape)(x)
110110

111111

112+
class Transpose(Function):
113+
def __init__(self, axes=None):
114+
self.axes = axes
115+
116+
def forward(self, x):
117+
y = x.transpose(self.axes)
118+
return y
119+
120+
def backward(self, gy):
121+
if self.axes is None:
122+
return transpose(gy)
123+
124+
axes_len = len(self.axes)
125+
inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes]))
126+
return transpose(gy, inv_axes)
127+
128+
129+
def transpose(x, axes=None):
130+
return Transpose(axes)(x)
131+
132+
133+
class GetItem(Function):
134+
def __init__(self, slices):
135+
self.slices = slices
136+
137+
def forward(self, x):
138+
y = x[self.slices]
139+
return y
140+
141+
def backward(self, gy):
142+
x, = self.inputs
143+
f = GetItemGrad(self.slices, x.shape)
144+
return f(gy)
145+
146+
147+
class GetItemGrad(Function):
148+
def __init__(self, slices, in_shape):
149+
self.slices = slices
150+
self.in_shape = in_shape
151+
152+
def forward(self, gy):
153+
xp = dezero.cuda.get_array_module(gy)
154+
gx = xp.zeros(self.in_shape, dtype=gy.dtype)
155+
156+
if xp is np:
157+
np.add.at(gx, self.slices, gy)
158+
else:
159+
xp.scatter_add(gx, self.slices, gy)
160+
return gx
161+
162+
def backward(self, ggx):
163+
return get_item(ggx, self.slices)
164+
165+
166+
def get_item(x, slices):
167+
f = GetItem(slices)
168+
return f(x)
169+
170+
112171
def expand_dims(x, axis):
113172
x = as_variable(x)
114173
shape = list(x.shape)
@@ -120,6 +179,10 @@ def flatten(x):
120179
"""Flattens the input. Does not affect the batch size."""
121180
return reshape(x, (x.shape[0], -1))
122181

182+
183+
# =============================================================================
184+
# sum / sum_to / broadcast_to / average / matmul / linear
185+
# =============================================================================
123186
class Sum(Function):
124187
def __init__(self, axis, keepdims):
125188
self.axis = axis
@@ -228,67 +291,8 @@ def linear_simple(x, W, b=None):
228291
return y
229292

230293

231-
class Transpose(Function):
232-
def __init__(self, axes=None):
233-
self.axes = axes
234-
235-
def forward(self, x):
236-
y = x.transpose(self.axes)
237-
return y
238-
239-
def backward(self, gy):
240-
if self.axes is None:
241-
return transpose(gy)
242-
243-
axes_len = len(self.axes)
244-
inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes]))
245-
return transpose(gy, inv_axes)
246-
247-
248-
def transpose(x, axes=None):
249-
return Transpose(axes)(x)
250-
251-
252-
class GetItem(Function):
253-
def __init__(self, slices):
254-
self.slices = slices
255-
256-
def forward(self, x):
257-
y = x[self.slices]
258-
return y
259-
260-
def backward(self, gy):
261-
x, = self.inputs
262-
f = GetItemGrad(self.slices, x.shape)
263-
return f(gy)
264-
265-
266-
class GetItemGrad(Function):
267-
def __init__(self, slices, in_shape):
268-
self.slices = slices
269-
self.in_shape = in_shape
270-
271-
def forward(self, gy):
272-
xp = dezero.cuda.get_array_module(gy)
273-
gx = xp.zeros(self.in_shape, dtype=gy.dtype)
274-
275-
if xp is np:
276-
np.add.at(gx, self.slices, gy)
277-
else:
278-
xp.scatter_add(gx, self.slices, gy)
279-
return gx
280-
281-
def backward(self, ggx):
282-
return get_item(ggx, self.slices)
283-
284-
285-
def get_item(x, slices):
286-
f = GetItem(slices)
287-
return f(x)
288-
289-
290294
# =============================================================================
291-
# activation function
295+
# activation function: sigmoid / relu / softmax / log_softmax / leaky_relu
292296
# =============================================================================
293297
def sigmoid_simple(x):
294298
x = as_variable(x)
@@ -398,8 +402,10 @@ def backward(self, gy):
398402

399403
def leaky_relu(x, slope=0.2):
400404
return LeakyReLU(slope)(x)
405+
406+
401407
# =============================================================================
402-
# loss function
408+
# loss function: mean_squared_error / softmax_cross_entropy / sigmoid_cross_entropy / binary_cross_entropy
403409
# =============================================================================
404410
def mean_squared_error_simple(x0, x1):
405411
x0, x1 = as_variable(x0), as_variable(x1)
@@ -487,7 +493,7 @@ def binary_cross_entropy(p, t):
487493

488494

489495
# =============================================================================
490-
# utility function
496+
# accuracy / dropout / batch_norm / embed_id
491497
# =============================================================================
492498
def accuracy(y, t):
493499
"""
@@ -501,13 +507,6 @@ def accuracy(y, t):
501507
return Variable(as_array(acc))
502508

503509

504-
# =============================================================================
505-
# embed_id / dropout / batch_norm
506-
# =============================================================================
507-
def embed_id(x, W):
508-
return W[x]
509-
510-
511510
def dropout(x, dropout_ratio=0.5):
512511
x = as_variable(x)
513512

@@ -593,6 +592,10 @@ def batch_nrom(x, gamma, beta, mean, var, decay=0.9, eps=2e-5):
593592
return BatchNorm(mean, var, decay, eps)(x, gamma, beta)
594593

595594

595+
def embed_id(x, W):
596+
return W[x]
597+
598+
596599
# =============================================================================
597600
# max / min / clip
598601
# =============================================================================

0 commit comments

Comments
 (0)