```python
def Fvp(v):
    ...
    flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
    # damped Fisher-vector product
    return flat_grad_grad_kl + v * damping

stepdir = conjugate_gradients(Fvp, -loss_grad, 10)
shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)  # 0.5 * s^T H s
lm = torch.sqrt(shs / max_kl)
fullstep = stepdir / lm[0]
```
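To see numerically what the quoted scaling produces, here is a minimal pure-Python sketch on a made-up 3×3 problem. The matrix `H`, the vector `loss_grad`, and the values of `damping` and `max_kl` are hypothetical stand-ins, and `conjugate_gradients` here is a textbook CG solver written for this sketch, not the repository's function. It applies the same `shs` / `lm` / `fullstep` steps and compares the result with β·s, where β = sqrt(2δ / sᵀAs) is the step length given in the TRPO paper (A is assumed to be the damped matrix that `Fvp` applies).

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(m, v):
    return [dot(row, v) for row in m]


def conjugate_gradients(avp, b, nsteps, tol=1e-10):
    """Approximately solve A x = b using only the product v -> A v."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x, with x starting at 0
    p = list(b)
    rdotr = dot(r, r)
    for _ in range(nsteps):
        if rdotr < tol:
            break
        ap = avp(p)
        alpha = rdotr / dot(p, ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        new_rdotr = dot(r, r)
        p = [ri + (new_rdotr / rdotr) * pi for ri, pi in zip(r, p)]
        rdotr = new_rdotr
    return x


# Hypothetical toy problem: a symmetric positive definite matrix
# stands in for the Hessian of the mean KL (the Fisher matrix).
H = [[4.0, 1.0, 0.0],
     [1.0, 3.0, 0.5],
     [0.0, 0.5, 2.0]]
loss_grad = [1.0, -2.0, 0.5]
damping = 0.1
max_kl = 0.01


def fvp(v):
    # (H + damping * I) v, mirroring `flat_grad_grad_kl + v * damping`
    return [hv + damping * vi for hv, vi in zip(matvec(H, v), v)]


# Same sequence of operations as the quoted lines
stepdir = conjugate_gradients(fvp, [-g for g in loss_grad], 10)
shs = 0.5 * dot(stepdir, fvp(stepdir))
lm = math.sqrt(shs / max_kl)
fullstep = [s / lm for s in stepdir]

# Step length from the TRPO paper, with A = H + damping * I
beta = math.sqrt(2 * max_kl / dot(stepdir, fvp(stepdir)))

# Largest deviation of fullstep from beta * stepdir
print(max(abs(f - beta * s) for f, s in zip(fullstep, stepdir)))
# Quadratic KL estimate of the scaled step, to compare with max_kl
print(0.5 * dot(fullstep, fvp(fullstep)), max_kl)
```

Up to floating-point rounding, `fullstep` coincides with `beta * stepdir`, and the quadratic KL estimate ½·fullstepᵀ(H + damping·I)·fullstep equals `max_kl`.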
According to the TRPO formula,
So
but your code is different from that. Why?
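For reference, the quoted lines can be written algebraically, using notation assumed here rather than taken from the code: s = `stepdir`, δ = `max_kl`, and A = H + damping·I for the matrix that `Fvp` applies.

$$
\text{shs} = \tfrac{1}{2}\, s^\top A s, \qquad
\text{lm} = \sqrt{\frac{\text{shs}}{\delta}}, \qquad
\text{fullstep} = \frac{s}{\text{lm}} = \sqrt{\frac{2\delta}{s^\top A s}}\; s .
$$

The factor multiplying s is the step length β = √(2δ / sᵀAs) that the TRPO paper obtains from the condition ½(βs)ᵀA(βs) = δ.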