
Commit 06c6e62

Remove explicit forward calls in RecurrentPPO (#320)

* Fix: do not call model.forward() directly!
* Update changelog
* Fix type hint in distributions
* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

1 parent d67de80

File tree: 3 files changed (+4, -2 lines)

docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@ New Features:

 Bug Fixes:
 ^^^^^^^^^^
+- Do not call ``forward()`` method directly in ``RecurrentPPO``

 Deprecations:
 ^^^^^^^^^^^^^
```

sb3_contrib/common/maskable/distributions.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -110,9 +110,10 @@ class MaskableCategoricalDistribution(MaskableDistribution):
     :param action_dim: Number of discrete actions
     """

+    distribution: MaskableCategorical
+
     def __init__(self, action_dim: int):
         super().__init__()
-        self.distribution: MaskableCategorical | None = None
         self.action_dim = action_dim

     def proba_distribution_net(self, latent_dim: int) -> nn.Module:
```
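The diff above replaces an instance attribute typed as `Optional` with a bare class-level annotation. A minimal sketch (plain Python, with hypothetical stand-in classes, not the real sb3-contrib types) of what this pattern buys: the attribute has a precise, non-`None` type for checkers, and simply does not exist until it is populated.

```python
class Categorical:
    """Hypothetical stand-in for a concrete distribution class."""

    def __init__(self, logits):
        self.logits = logits


class CategoricalDistribution:
    # Class-level annotation only: no ``None`` in the type, and no
    # attribute on the instance until it is actually assigned.
    distribution: Categorical

    def __init__(self, action_dim: int):
        self.action_dim = action_dim

    def proba_distribution(self, logits):
        # First assignment creates the attribute with the annotated type.
        self.distribution = Categorical(logits)
        return self


d = CategoricalDistribution(4)
assert not hasattr(d, "distribution")  # unset until populated
d.proba_distribution([0.1, 0.9])
assert isinstance(d.distribution, Categorical)
```

With the old `self.distribution: MaskableCategorical | None = None` style, every downstream use had to narrow away `None`; the class-level annotation avoids that at the cost of an `AttributeError` if the attribute is read before being set.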

sb3_contrib/ppo_recurrent/ppo_recurrent.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -239,7 +239,7 @@ def collect_rollouts(
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
-                actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
+                actions, values, log_probs, lstm_states = self.policy(obs_tensor, lstm_states, episode_starts)

                 actions = actions.cpu().numpy()
```
