But the outputs from the mixture density layer are treated as scale variables by tfp.distributions.MultivariateNormalDiag. Its documentation notes that:

`covariance = scale @ scale.T`
Thus, it seems we should have been squaring the cov_matrix before passing it into the multivariate normal sampling procedure. This could explain why we end up having to scale down the sigma variable so much in real-world applications.
A TODO here is to get a definitive answer and run some tests to see what's going on.
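As a starting point for that test, here is a minimal sketch checking how `MultivariateNormalDiag` interprets `scale_diag`. The `sigma` values below are hypothetical stand-ins for the mixture density layer's outputs, not values from the actual model; the point is only that the distribution's reported covariance is the square of whatever is passed as `scale_diag`.

```python
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical stand-in for the mixture density layer's sigma output.
sigma = np.array([2.0, 3.0], dtype=np.float32)

dist = tfd.MultivariateNormalDiag(loc=[0.0, 0.0], scale_diag=sigma)

# MultivariateNormalDiag treats scale_diag as standard deviations,
# so the effective covariance diagonal is sigma**2, not sigma.
print(dist.covariance().numpy())
# Expected:
# [[4. 0.]
#  [0. 9.]]
# i.e. covariance = diag(scale_diag) @ diag(scale_diag).T
```

If the layer's outputs were intended to be (co)variances rather than standard deviations, this would confirm that the sampled distribution is wider than intended whenever those outputs exceed 1, consistent with needing to scale sigma down in practice.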