Open
Labels
bug (Something isn't working) · custom gym env (Issue related to Custom Gym Env) · help wanted (Help from contributors is needed) · trading warning (Trading with RL is usually a bad idea)
Description
🐛 Bug
When the action space is not one-dimensional, the shapes do not match when actions are clipped in RecurrentPPO.collect_rollouts and RecurrentActorCriticPolicy.predict.
PPO does not have this problem, because it reshapes the actions in ActorCriticPolicy.forward and ActorCriticPolicy.predict.
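To illustrate the mismatch with plain NumPy (shapes taken from the reproduction below: 4 envs, a (2, 2) Box action space; the final reshape mirrors what ActorCriticPolicy.forward does for PPO, it is not actual sb3_contrib code):

```python
import numpy as np

n_envs = 4
low = np.full((2, 2), -1.0, dtype=np.float32)
high = np.full((2, 2), 1.0, dtype=np.float32)

# RecurrentPPO's rollout keeps actions flattened to (n_envs, prod(shape)) = (4, 4),
# so clipping against the (2, 2) bounds cannot broadcast:
flat_actions = np.zeros((n_envs, 4), dtype=np.float32)
try:
    np.clip(flat_actions, low, high)
except ValueError as e:
    print(e)  # operands could not be broadcast together ...

# Restoring the action space shape first makes the clip broadcast per env:
reshaped = flat_actions.reshape((n_envs, *low.shape))  # (4, 2, 2)
clipped = np.clip(reshaped, low, high)
print(clipped.shape)  # (4, 2, 2)
```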
Code example

```python
import gymnasium as gym
import numpy
import sb3_contrib
import stable_baselines3 as sb3


class TestEnv(gym.Env):
    def __init__(self):
        self.observation = numpy.empty((10,), dtype=numpy.float32)
        self.observation_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(10,),
            dtype=numpy.float32,
        )
        self.action_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(2, 2),
            dtype=numpy.float32,
        )

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation, {}

    def step(self, action):
        return self.observation, 1, False, False, {}


if __name__ == "__main__":
    vec_env = sb3.common.env_util.make_vec_env(TestEnv, n_envs=4)
    model = sb3.PPO("MlpPolicy", vec_env)
    model.learn(10)
    model = sb3_contrib.RecurrentPPO("MlpLstmPolicy", vec_env)
    model.learn(10)
```

Relevant log output / Error message
```
Traceback (most recent call last):
  File "/Users/yangjun/Workspace/alpha_trader/unit_test.py", line 36, in <module>
    model.learn(10)
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 450, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 324, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 250, in collect_rollouts
    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 2341, in clip
    return _wrapfunc(a, 'clip', a_min, a_max, out=out, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 57, in _wrapfunc
    return bound(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/yangjun/Workspace/alpha_trader/venv/lib/python3.12/site-packages/numpy/_core/_methods.py", line 117, in _clip
    return um.clip(a, min, max, out=out, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: operands could not be broadcast together with shapes (4,4) (2,2) (2,2)
```

System Info
No response
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I have checked my env using the env checker
- I've used the markdown code blocks for both code and stack traces.