
Commit 7830e21

lingvo-bot authored and copybara-github committed
Avoid using inference-time arithmetic on inputs of reshapes, use -1 instead.
PiperOrigin-RevId: 591598526
1 parent 381f8eb commit 7830e21
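
The change follows a simple pattern: instead of deriving a reshape dimension with arithmetic on runtime shapes (n = target_batch // source_batch), the reshape passes -1 for the dimension to be inferred and uses dimensions that are already known. A minimal sketch of the before/after pattern, with illustrative tensor names rather than the ones in attention.py:

import tensorflow as tf

def split_batch_old(x, source_batch):
  # Old pattern: arithmetic on a runtime shape feeds the reshape target,
  # which bakes an extra division op into the inference graph.
  target_batch = tf.shape(x)[0]
  n = target_batch // source_batch
  return tf.reshape(x, [n, source_batch, -1])

def split_batch_new(x, source_batch, source_dim):
  # New pattern: let tf.reshape infer the leading factor with -1 and pass
  # the dimensions that are already known.
  return tf.reshape(x, [-1, source_batch, source_dim])

x = tf.ones([6, 4])  # target_batch = 6, source_dim = 4
print(split_batch_new(x, source_batch=2, source_dim=4).shape)  # (3, 2, 4)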

File tree

1 file changed: +10 -10 lines changed


lingvo/core/attention.py

Lines changed: 10 additions & 10 deletions
@@ -897,18 +897,19 @@ def AttenProbs(inputs):
           tf.cast(
               py_utils.GetShape(inputs.query_vec)[1],
               dtype=py_utils.FPropDtype(p))))
-      source_batch = py_utils.GetShape(concated_source_vecs)[0]
+      source_batch, _, source_dim = py_utils.GetShape(concated_source_vecs)
       target_batch = py_utils.GetShape(inputs.query_vec)[0]
       query_vec = inputs.query_vec * inputs.per_dim_scale
       # The n here refers to the "n" described in the comment above.
-      n = target_batch // source_batch
-      query_vec = tf.reshape(query_vec, [n, source_batch, -1])
+      # => [n, source_batch, source_dim] where n = target_batch // source_batch
+      query_vec = tf.reshape(query_vec, [-1, source_batch, source_dim])
       # => [source_batch, source_dim, n]
       query_vec = tf.transpose(query_vec, [1, 2, 0])
-      # => [n, source_batch, source_sequence_len]
+      source_length = py_utils.GetShape(inputs.per_step_source_padding)[1]
+      # => [n, source_batch, source_length]
       per_step_source_padding = tf.reshape(inputs.per_step_source_padding,
-                                           [n, source_batch, -1])
-      # => [source_batch, source_sequence_len, n]
+                                           [-1, source_batch, source_length])
+      # => [source_batch, source_length, n]
       per_step_source_padding = tf.transpose(per_step_source_padding, [1, 2, 0])
       # Dot-product part.
       # Calls batch_mat_mul since dim > 2 for per-instance matmul.
@@ -982,9 +983,8 @@ def Atten(per_dim_scale, source_padding, source_segment_id,
                                   [py_utils.GetShape(query_vec)[1]])
       py_utils.assert_shape_match([py_utils.GetShape(concated_source_vecs)[2]],
                                   [symbolic.ToStatic(p.source_dim)])
-      source_batch = py_utils.GetShape(concated_source_vecs)[1]
+      time, source_batch = py_utils.GetShape(concated_source_vecs, 2)
       target_batch = py_utils.GetShape(query_vec)[0]
-      n = target_batch // source_batch
       concated_source_vecs = tf.transpose(concated_source_vecs, [1, 0, 2])
       concated_source_vecs = tf.identity(
           concated_source_vecs, name='concated_source_vecs')
@@ -1000,8 +1000,8 @@ def Atten(per_dim_scale, source_padding, source_segment_id,
               query_segment_id=query_segment_id))
       returned_probs.set_shape(per_step_source_padding.shape)

-      # => [n, source_batch, time].
-      probs = tf.reshape(returned_probs, [n, source_batch, -1])
+      # => [n, source_batch, time] where n = target_batch // source_batch
+      probs = tf.reshape(returned_probs, [-1, source_batch, time])
       # => [source_batch, n, time].
       probs = tf.transpose(probs, [1, 0, 2])
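
For context, the updated Atten() relies on py_utils.GetShape returning per-dimension values that tf.reshape can consume directly. The stand-in below is an assumption about that helper's behavior (static dims as Python ints where known, dynamic tf.shape values otherwise), not lingvo's actual implementation; it mirrors the reshape in the updated code:

import tensorflow as tf

def get_shape(tensor, ndims=None):
  # Simplified stand-in (assumed behavior) for lingvo's py_utils.GetShape:
  # static dims as Python ints where known, dynamic values otherwise.
  static = tensor.shape.as_list()
  dynamic = tf.shape(tensor)
  dims = [d if d is not None else dynamic[i] for i, d in enumerate(static)]
  return dims if ndims is None else dims[:ndims]

# [time, source_batch, source_dim], mirroring concated_source_vecs.
vecs = tf.ones([7, 2, 4])
time, source_batch = get_shape(vecs, 2)

# [n * source_batch, time] probabilities, reshaped without computing n.
returned_probs = tf.ones([6, 7])
probs = tf.reshape(returned_probs, [-1, source_batch, time])  # => [n, source_batch, time]
print(probs.shape)  # (3, 2, 7)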
