
Commit 093491b

Optimize AAN transformer and small fixes (#1482)
* Optimize AAN transformer and small fixes
* Make use of FFN layer in AAN an option
1 parent de6d396 commit 093491b
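
The second bullet makes the position-wise feed-forward (FFN) sub-layer inside the Average Attention Network (AAN) optional: it is now off by default and enabled with the new -aan_useffn flag. A minimal sketch of the effect at the module level, assuming the import path from the average_attn.py diff below and illustrative tensor sizes:

    import torch
    from onmt.modules.average_attn import AverageAttention

    # Default after this commit: no PositionwiseFeedForward inside the AAN block.
    aan = AverageAttention(model_dim=512, dropout=0.1)

    # aan_useffn=True restores the previous behaviour (FFN applied to the averages).
    aan_ffn = AverageAttention(model_dim=512, dropout=0.1, aan_useffn=True)

    x = torch.rand(2, 7, 512)    # (batch, target_len, model_dim), illustrative sizes
    out, _ = aan_ffn(x)          # gated output keeps the input shape
    print(out.shape)             # torch.Size([2, 7, 512])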

6 files changed: +42 additions, -37 deletions

onmt/decoders/transformer.py

Lines changed: 13 additions & 8 deletions
@@ -23,15 +23,17 @@ class TransformerDecoderLayer(nn.Module):
     """

     def __init__(self, d_model, heads, d_ff, dropout,
-                 self_attn_type="scaled-dot", max_relative_positions=0):
+                 self_attn_type="scaled-dot", max_relative_positions=0,
+                 aan_useffn=False):
         super(TransformerDecoderLayer, self).__init__()

         if self_attn_type == "scaled-dot":
             self.self_attn = MultiHeadedAttention(
                 heads, d_model, dropout=dropout,
                 max_relative_positions=max_relative_positions)
         elif self_attn_type == "average":
-            self.self_attn = AverageAttention(d_model, dropout=dropout)
+            self.self_attn = AverageAttention(d_model, dropout=dropout,
+                                              aan_useffn=aan_useffn)

         self.context_attn = MultiHeadedAttention(
             heads, d_model, dropout=dropout)
@@ -72,7 +74,7 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
             query, attn = self.self_attn(input_norm, input_norm, input_norm,
                                          mask=dec_mask,
                                          layer_cache=layer_cache,
-                                         type="self")
+                                         attn_type="self")
         elif isinstance(self.self_attn, AverageAttention):
             query, attn = self.self_attn(input_norm, mask=dec_mask,
                                          layer_cache=layer_cache, step=step)
@@ -83,7 +85,7 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
         mid, attn = self.context_attn(memory_bank, memory_bank, query_norm,
                                       mask=src_pad_mask,
                                       layer_cache=layer_cache,
-                                      type="context")
+                                      attn_type="context")
         output = self.feed_forward(self.drop(mid) + query)

         return output, attn
@@ -127,7 +129,7 @@ class TransformerDecoder(DecoderBase):

     def __init__(self, num_layers, d_model, heads, d_ff,
                  copy_attn, self_attn_type, dropout, embeddings,
-                 max_relative_positions):
+                 max_relative_positions, aan_useffn):
         super(TransformerDecoder, self).__init__()

         self.embeddings = embeddings
@@ -138,7 +140,8 @@ def __init__(self, num_layers, d_model, heads, d_ff,
         self.transformer_layers = nn.ModuleList(
             [TransformerDecoderLayer(d_model, heads, d_ff, dropout,
              self_attn_type=self_attn_type,
-             max_relative_positions=max_relative_positions)
+             max_relative_positions=max_relative_positions,
+             aan_useffn=aan_useffn)
              for i in range(num_layers)])

         # previously, there was a GlobalAttention module here for copy
@@ -159,7 +162,8 @@ def from_opt(cls, opt, embeddings):
             opt.self_attn_type,
             opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
             embeddings,
-            opt.max_relative_positions)
+            opt.max_relative_positions,
+            opt.aan_useffn)

     def init_state(self, src, memory_bank, enc_hidden):
         """Initialize decoder state."""
@@ -233,7 +237,8 @@ def _init_cache(self, memory_bank):
         for i, layer in enumerate(self.transformer_layers):
             layer_cache = {"memory_keys": None, "memory_values": None}
             if isinstance(layer.self_attn, AverageAttention):
-                layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth))
+                layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth),
+                                                    device=memory_bank.device)
             else:
                 layer_cache["self_keys"] = None
                 layer_cache["self_values"] = None
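
Besides threading aan_useffn from the options down to every decoder layer, the diff above allocates the AAN decoding cache directly on the memory bank's device, so incremental decoding no longer copies prev_g across devices at every step. A minimal sketch of the new constructor signature and the cache-allocation pattern, using hypothetical hyperparameters:

    import torch
    from onmt.decoders.transformer import TransformerDecoderLayer

    # Hypothetical sizes; keyword arguments as introduced in this commit.
    layer = TransformerDecoderLayer(d_model=512, heads=8, d_ff=2048, dropout=0.1,
                                    self_attn_type="average", aan_useffn=True)

    # Cache pattern from _init_cache: prev_g lives on the same device as
    # memory_bank, avoiding a per-step .to(device) during translation.
    memory_bank = torch.rand(6, 2, 512)     # (src_len, batch, d_model)
    batch_size, depth = memory_bank.size(1), memory_bank.size(-1)
    prev_g = torch.zeros((batch_size, 1, depth), device=memory_bank.device)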

onmt/encoders/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def forward(self, inputs, mask):
         """
         input_norm = self.layer_norm(inputs)
         context, _ = self.self_attn(input_norm, input_norm, input_norm,
-                                    mask=mask, type="self")
+                                    mask=mask, attn_type="self")
         out = self.dropout(context) + inputs
         return self.feed_forward(out)

onmt/modules/average_attn.py

Lines changed: 15 additions & 16 deletions
@@ -19,16 +19,16 @@ class AverageAttention(nn.Module):
        dropout (float): dropout parameter
     """

-    def __init__(self, model_dim, dropout=0.1):
+    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
         self.model_dim = model_dim
-
+        self.aan_useffn = aan_useffn
         super(AverageAttention, self).__init__()
-
-        self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
-                                                     dropout)
+        if aan_useffn:
+            self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
+                                                         dropout)
         self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

-    def cumulative_average_mask(self, batch_size, inputs_len):
+    def cumulative_average_mask(self, batch_size, inputs_len, device):
         """
         Builds the mask to compute the cumulative average as described in
         :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3
@@ -43,9 +43,10 @@ def cumulative_average_mask(self, batch_size, inputs_len):
           * A Tensor of shape ``(batch_size, input_len, input_len)``
        """

-        triangle = torch.tril(torch.ones(inputs_len, inputs_len))
-        weights = torch.ones(1, inputs_len) / torch.arange(
-            1, inputs_len + 1, dtype=torch.float)
+        triangle = torch.tril(torch.ones(inputs_len, inputs_len,
+                                         dtype=torch.float, device=device))
+        weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
+            / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
         mask = triangle * weights.transpose(0, 1)

         return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)
@@ -72,14 +73,13 @@ def cumulative_average(self, inputs, mask_or_step,

         if layer_cache is not None:
             step = mask_or_step
-            device = inputs.device
             average_attention = (inputs + step *
-                                 layer_cache["prev_g"].to(device)) / (step + 1)
+                                 layer_cache["prev_g"]) / (step + 1)
             layer_cache["prev_g"] = average_attention
             return average_attention
         else:
             mask = mask_or_step
-            return torch.matmul(mask, inputs)
+            return torch.matmul(mask.to(inputs.dtype), inputs)

     def forward(self, inputs, mask=None, layer_cache=None, step=None):
         """
@@ -96,13 +96,12 @@ def forward(self, inputs, mask=None, layer_cache=None, step=None):

         batch_size = inputs.size(0)
         inputs_len = inputs.size(1)
-
-        device = inputs.device
         average_outputs = self.cumulative_average(
             inputs, self.cumulative_average_mask(batch_size,
-                                                 inputs_len).to(device).float()
+                                                 inputs_len, inputs.device)
             if layer_cache is None else step, layer_cache=layer_cache)
-        average_outputs = self.average_layer(average_outputs)
+        if self.aan_useffn:
+            average_outputs = self.average_layer(average_outputs)
         gating_outputs = self.gating_layer(torch.cat((inputs,
                                            average_outputs), -1))
         input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
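
For reference, the mask built by cumulative_average_mask is a lower-triangular matrix whose row t is scaled by 1/(t+1), so multiplying it with the inputs yields the running average over positions 0..t; the change above simply builds it with the right dtype and device up front instead of casting and moving it afterwards. A small self-contained check with illustrative sizes:

    import torch

    # Rebuild the mask from cumulative_average_mask for a length-3 sequence (CPU here).
    inputs_len, device = 3, torch.device("cpu")
    triangle = torch.tril(torch.ones(inputs_len, inputs_len,
                                     dtype=torch.float, device=device))
    weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
        / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
    mask = triangle * weights.transpose(0, 1)
    # tensor([[1.0000, 0.0000, 0.0000],
    #         [0.5000, 0.5000, 0.0000],
    #         [0.3333, 0.3333, 0.3333]])

    # Row t of (mask @ inputs) is the mean of positions 0..t, i.e. the cumulative average.
    inputs = torch.rand(1, inputs_len, 4)             # (batch, length, dim), illustrative
    cum_avg = torch.matmul(mask.unsqueeze(0), inputs)
    assert torch.allclose(cum_avg[0, 2], inputs[0].mean(0), atol=1e-6)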

onmt/modules/multi_headed_attn.py

Lines changed: 10 additions & 11 deletions
@@ -75,7 +75,7 @@ def __init__(self, head_count, model_dim, dropout=0.1,
                 vocab_size, self.dim_per_head)

     def forward(self, key, value, query, mask=None,
-                layer_cache=None, type=None):
+                layer_cache=None, attn_type=None):
         """
         Compute the context vector and the attention vectors.

@@ -117,7 +117,6 @@ def forward(self, key, value, query, mask=None,
         head_count = self.head_count
         key_len = key.size(1)
         query_len = query.size(1)
-        device = key.device

         def shape(x):
             """Projection."""
@@ -131,23 +130,23 @@ def unshape(x):

         # 1) Project key, value, and query.
         if layer_cache is not None:
-            if type == "self":
+            if attn_type == "self":
                 query, key, value = self.linear_query(query),\
                                     self.linear_keys(query),\
                                     self.linear_values(query)
                 key = shape(key)
                 value = shape(value)
                 if layer_cache["self_keys"] is not None:
                     key = torch.cat(
-                        (layer_cache["self_keys"].to(device), key),
+                        (layer_cache["self_keys"], key),
                         dim=2)
                 if layer_cache["self_values"] is not None:
                     value = torch.cat(
-                        (layer_cache["self_values"].to(device), value),
+                        (layer_cache["self_values"], value),
                         dim=2)
                 layer_cache["self_keys"] = key
                 layer_cache["self_values"] = value
-            elif type == "context":
+            elif attn_type == "context":
                 query = self.linear_query(query)
                 if layer_cache["memory_keys"] is None:
                     key, value = self.linear_keys(key),\
@@ -166,18 +165,18 @@ def unshape(x):
             key = shape(key)
             value = shape(value)

-        if self.max_relative_positions > 0 and type == "self":
+        if self.max_relative_positions > 0 and attn_type == "self":
             key_len = key.size(2)
             # 1 or key_len x key_len
             relative_positions_matrix = generate_relative_positions_matrix(
                 key_len, self.max_relative_positions,
                 cache=True if layer_cache is not None else False)
             # 1 or key_len x key_len x dim_per_head
             relations_keys = self.relative_positions_embeddings(
-                relative_positions_matrix.to(device))
+                relative_positions_matrix.to(key.device))
             # 1 or key_len x key_len x dim_per_head
             relations_values = self.relative_positions_embeddings(
-                relative_positions_matrix.to(device))
+                relative_positions_matrix.to(key.device))

         query = shape(query)

@@ -189,7 +188,7 @@ def unshape(x):
         # batch x num_heads x query_len x key_len
         query_key = torch.matmul(query, key.transpose(2, 3))

-        if self.max_relative_positions > 0 and type == "self":
+        if self.max_relative_positions > 0 and attn_type == "self":
             scores = query_key + relative_matmul(query, relations_keys, True)
         else:
             scores = query_key
@@ -205,7 +204,7 @@ def unshape(x):

         context_original = torch.matmul(drop_attn, value)

-        if self.max_relative_positions > 0 and type == "self":
+        if self.max_relative_positions > 0 and attn_type == "self":
             context = unshape(context_original
                               + relative_matmul(drop_attn,
                                                 relations_values,
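
The keyword argument is renamed from type (which shadowed the Python built-in) to attn_type, and cached self-attention keys/values as well as the relative-position matrices now stay on, or are moved to, the device of the current key tensor instead of a separately tracked one. A minimal call sketch with the new keyword, assuming the constructor arguments shown above and illustrative sizes:

    import torch
    from onmt.modules.multi_headed_attn import MultiHeadedAttention

    attn = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.1)

    x = torch.rand(2, 7, 512)                   # (batch, length, model_dim)
    out, _ = attn(x, x, x, attn_type="self")    # was type="self" before this commit
    print(out.shape)                            # torch.Size([2, 7, 512])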

onmt/opts.py

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,8 @@ def model_opts(parser):
               help='Number of heads for transformer self-attention')
     group.add('--transformer_ff', '-transformer_ff', type=int, default=2048,
               help='Size of hidden transformer feed-forward')
+    group.add('--aan_useffn', '-aan_useffn', action="store_true",
+              help='Turn on the FFN layer in the AAN decoder')

     # Generator and loss options.
     group.add('--copy_attn', '-copy_attn', action="store_true",

onmt/tests/pull_request_chk.sh

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ echo > ${LOG_FILE} # Empty the log file.
 PROJECT_ROOT=`dirname "$0"`"/../../"
 DATA_DIR="$PROJECT_ROOT/data"
 TEST_DIR="$PROJECT_ROOT/onmt/tests"
-PYTHON="python"
+PYTHON="python3"

 clean_up()
 {
