
It seems that the importance sampling code part is wrong. #22

@yhy258

pytorch-trpo/main.py, lines 108 to 119 in e200eb8:

```python
fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

def get_loss(volatile=False):
    if volatile:
        with torch.no_grad():
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
    else:
        action_means, action_log_stds, action_stds = policy_net(Variable(states))

    log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
    action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
    return action_loss.mean()
```

The `fixed_log_prob` assignment and the body of `get_loss` compute exactly the same quantity. Since the two run back to back, before any parameter update, the two values (`fixed_log_prob` and `log_prob`) are identical, so the importance ratio `exp(log_prob - fixed_log_prob)` evaluates to 1.
Is there a reason you wrote the code like this?
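To make the observation concrete, here is a minimal standalone sketch. It uses toy tensors and a hand-rolled Gaussian log-density standing in for `normal_log_density` (these are hypothetical, not the repo's actual code or data). At the point where `get_loss` is first evaluated, the ratio is exactly 1 for every sample, even though the loss still carries a gradient because `fixed_log_prob` is detached:

```python
import math
import torch

# Toy stand-ins for the repo's tensors (hypothetical values, single action dim).
log_std = torch.zeros(1, requires_grad=True)      # policy parameter (std = 1)
actions = torch.tensor([[0.5], [-0.3]])
advantages = torch.tensor([1.0, 2.0])

def gaussian_log_prob(log_std):
    # Log-density of a zero-mean Gaussian policy; plays the role of normal_log_density.
    var = torch.exp(log_std) ** 2
    return (-(actions ** 2) / (2 * var) - log_std - 0.5 * math.log(2 * math.pi)).sum(1)

# Snapshot of the "old" policy's log-probabilities, detached like fixed_log_prob.
fixed_log_prob = gaussian_log_prob(log_std).detach()

# First evaluation, before any parameter update: the ratio is exactly 1 ...
ratio = torch.exp(gaussian_log_prob(log_std) - fixed_log_prob)
print(ratio)            # tensor([1., 1.], grad_fn=...)

# ... yet the surrogate loss still has a nonzero gradient, because fixed_log_prob
# is a constant while log_prob is not.
loss = (-advantages * ratio).mean()
loss.backward()
print(log_std.grad)     # nonzero
```

Presumably the ratio only departs from 1 if `get_loss` is re-evaluated after the parameters change; whether that ever happens here is exactly what this issue is asking.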
