
RecurrentPPO forward and predict do not reshape actions before clipping them #317

@immortal-boy

Description


🐛 Bug

When the action space is not flat (e.g. a `Box` with `shape=(2, 2)`), the action shape does not match the bounds when clipping in `RecurrentPPO.collect_rollouts` and `RecurrentActorCriticPolicy.predict`.
PPO does not have this problem because it reshapes actions to the action space shape in `ActorCriticPolicy.forward` and `ActorCriticPolicy.predict`.
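A minimal NumPy sketch of the reshape-then-clip step, assuming a fix would mirror the reshape PPO's `ActorCriticPolicy` applies (`actions.reshape((-1, *action_space.shape))`). `clip_to_box` is a hypothetical helper for illustration, not an sb3_contrib function:

```python
import numpy as np


def clip_to_box(actions: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reshape flat policy outputs to the Box shape, then clip."""
    # (n_envs, prod(shape)) -> (n_envs, *shape), so the bounds broadcast per env
    actions = actions.reshape((-1, *low.shape))
    return np.clip(actions, low, high)


# Box(low=-1, high=1, shape=(2, 2)) with 4 envs: the policy samples 4 flat values per env
low = -np.ones((2, 2), dtype=np.float32)
high = np.ones((2, 2), dtype=np.float32)
flat_actions = np.linspace(-2, 2, 16, dtype=np.float32).reshape(4, 4)

clipped = clip_to_box(flat_actions, low, high)
print(clipped.shape)  # (4, 2, 2)
```

Without the reshape, `np.clip(flat_actions, low, high)` fails in exactly the way shown in the log below.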

Code example

```python
import gymnasium as gym
import numpy
import sb3_contrib
import stable_baselines3 as sb3
from stable_baselines3.common.env_util import make_vec_env


class TestEnv(gym.Env):
    def __init__(self):
        # Zero-filled observation (numpy.empty may contain NaN or out-of-range garbage)
        self.observation = numpy.zeros((10,), dtype=numpy.float32)
        self.observation_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(10,),
            dtype=numpy.float32,
        )

        # Non-flat Box action space: this is what triggers the bug in RecurrentPPO
        self.action_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(2, 2),
            dtype=numpy.float32,
        )

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation, {}

    def step(self, action):
        return self.observation, 1, False, False, {}


if __name__ == "__main__":
    vec_env = make_vec_env(TestEnv, n_envs=4)
    # PPO reshapes actions to action_space.shape before clipping: works
    model = sb3.PPO("MlpPolicy", vec_env)
    model.learn(10)
    # RecurrentPPO clips flat (n_envs, 4) actions against (2, 2) bounds: fails
    model = sb3_contrib.RecurrentPPO("MlpLstmPolicy", vec_env)
    model.learn(10)
```

Relevant log output / Error message

```
Traceback (most recent call last):
  File "/Users/yangjun/Workspace/alpha_trader/unit_test.py", line 36, in <module>
    model.learn(10)
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 450, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 324, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 250, in collect_rollouts
    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 2341, in clip
    return _wrapfunc(a, 'clip', a_min, a_max, out=out, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 57, in _wrapfunc
    return bound(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/_methods.py", line 117, in _clip
    return um.clip(a, min, max, out=out, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: operands could not be broadcast together with shapes (4,4) (2,2) (2,2)
```
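The shapes in the error follow from NumPy's broadcasting rule: trailing dimensions must be equal or 1. The flat actions have shape (4, 4) (4 envs × 4 action values), so their last axis (4) cannot pair with the last axis of the (2, 2) bounds. A short check using `np.broadcast_shapes` (`broadcastable` is just an illustrative helper):

```python
import numpy as np


def broadcastable(*shapes) -> bool:
    """Return True if NumPy can broadcast all the given shapes together."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


# Shapes from the traceback: flat actions vs. Box low/high
print(broadcastable((4, 4), (2, 2), (2, 2)))     # False
# After reshaping actions to (n_envs, *action_space.shape)
print(broadcastable((4, 2, 2), (2, 2), (2, 2)))  # True
```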

System Info

No response


Labels: bug, custom gym env, help wanted, trading warning
