As you mentioned at the start of the 2nd tutorial, it is a good idea to mix teacher forcing with the "feed previous" technique while decoding. I thought I could share some ideas on how to do that.
prob = 0.5  # probability of feeding the previous output; set as a placeholder or tf.constant
r = tf.random_uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)  # draw a random value in [0, 1)
feed_previous = r < prob  # scalar bool tensor, True with probability prob
In the loop_fn_transition function, you could add an outer condition like this.
# feed_previous is a tensor, so branch with tf.cond rather than a Python if
input = tf.cond(
    finished,
    padded_next_input,
    lambda: tf.cond(feed_previous, search_for_next_input, fetch_next_decoder_target))
The fetch_next_decoder_target function is supposed to fetch the next decoder target by indexing decoder_targets with time, i.e. decoder_targets[time]. Note that decoder_targets must first be transposed to "time major" format so that the first axis is the time step.
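To make the selection logic easier to check in isolation, here is a minimal, framework-free sketch in plain Python. It is not the tutorial's code: `PAD`, `to_time_major`, `sample_feed_previous` and `choose_next_input` are hypothetical names used only for illustration, token ids stand in for embedded inputs, and the target is indexed with `time` exactly as described above (depending on how your targets are aligned with the decoder steps, you may need `time - 1` instead).

```python
import random

PAD = 0  # hypothetical padding token id, an assumption for this sketch


def to_time_major(batch_major):
    """Transpose [batch, time] nested lists into [time, batch]."""
    return [list(step) for step in zip(*batch_major)]


def sample_feed_previous(prob, rng):
    """Return True with probability `prob`, using a uniform draw in [0, 1)."""
    return rng.random() < prob


def choose_next_input(time, finished, feed_previous,
                      previous_prediction, targets_time_major):
    """Pick the decoder input for the next step, one token id per batch element."""
    if finished:
        # All sequences are done: feed padding.
        return [PAD] * len(previous_prediction)
    if feed_previous:
        # "Feed previous": reuse the model's own last prediction.
        return list(previous_prediction)
    # Teacher forcing: read the ground-truth token from the time-major targets.
    return list(targets_time_major[time])


# Example: two sequences of length three, stored batch-major.
rng = random.Random(0)
targets = to_time_major([[1, 2, 3], [4, 5, 6]])
feed_previous = sample_feed_previous(0.5, rng)
next_input = choose_next_input(1, False, feed_previous, [9, 9], targets)
```

In the TensorFlow graph the same three-way choice has to be expressed with nested `tf.cond` calls, because `finished` and `feed_previous` are tensors rather than Python booleans.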
Hope this helps. I will try this and add a pull request if I find time.