@@ -897,18 +897,19 @@ def AttenProbs(inputs):
897897 tf .cast (
898898 py_utils .GetShape (inputs .query_vec )[1 ],
899899 dtype = py_utils .FPropDtype (p ))))
900- source_batch = py_utils .GetShape (concated_source_vecs )[ 0 ]
900+ source_batch , _ , source_dim = py_utils .GetShape (concated_source_vecs )
901901 target_batch = py_utils .GetShape (inputs .query_vec )[0 ]
902902 query_vec = inputs .query_vec * inputs .per_dim_scale
903903 # The n here refers to the "n" described in the comment above.
904- n = target_batch // source_batch
905- query_vec = tf .reshape (query_vec , [n , source_batch , - 1 ])
904+ # => [n, source_batch, source_dim] where n = target_batch // source_batch
905+ query_vec = tf .reshape (query_vec , [- 1 , source_batch , source_dim ])
906906 # => [source_batch, source_dim, n]
907907 query_vec = tf .transpose (query_vec , [1 , 2 , 0 ])
908- # => [n, source_batch, source_sequence_len]
908+ source_length = py_utils .GetShape (inputs .per_step_source_padding )[1 ]
909+ # => [n, source_batch, source_length]
909910 per_step_source_padding = tf .reshape (inputs .per_step_source_padding ,
910- [n , source_batch , - 1 ])
911- # => [source_batch, source_sequence_len , n]
911+ [- 1 , source_batch , source_length ])
912+ # => [source_batch, source_length , n]
912913 per_step_source_padding = tf .transpose (per_step_source_padding , [1 , 2 , 0 ])
913914 # Dot-product part.
914915 # Calls batch_mat_mul since dim > 2 for per-instance matmul.
@@ -982,9 +983,8 @@ def Atten(per_dim_scale, source_padding, source_segment_id,
982983 [py_utils .GetShape (query_vec )[1 ]])
983984 py_utils .assert_shape_match ([py_utils .GetShape (concated_source_vecs )[2 ]],
984985 [symbolic .ToStatic (p .source_dim )])
985- source_batch = py_utils .GetShape (concated_source_vecs )[ 1 ]
986+ time , source_batch = py_utils .GetShape (concated_source_vecs , 2 )
986987 target_batch = py_utils .GetShape (query_vec )[0 ]
987- n = target_batch // source_batch
988988 concated_source_vecs = tf .transpose (concated_source_vecs , [1 , 0 , 2 ])
989989 concated_source_vecs = tf .identity (
990990 concated_source_vecs , name = 'concated_source_vecs' )
@@ -1000,8 +1000,8 @@ def Atten(per_dim_scale, source_padding, source_segment_id,
10001000 query_segment_id = query_segment_id ))
10011001 returned_probs .set_shape (per_step_source_padding .shape )
10021002
1003- # => [n, source_batch, time].
1004- probs = tf .reshape (returned_probs , [n , source_batch , - 1 ])
1003+ # => [n, source_batch, time] where n = target_batch // source_batch
1004+ probs = tf .reshape (returned_probs , [- 1 , source_batch , time ])
10051005 # => [source_batch, n, time].
10061006 probs = tf .transpose (probs , [1 , 0 , 2 ])
10071007
0 commit comments