Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rsl_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from itertools import chain
from tensordict import TensorDict

from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
from rsl_rl.modules.rnd import RandomNetworkDistillation
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import string_to_callable
Expand All @@ -20,12 +20,12 @@
class PPO:
"""Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""

policy: ActorCritic | ActorCriticRecurrent
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
"""The actor critic module."""

def __init__(
self,
policy: ActorCritic | ActorCriticRecurrent,
policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
num_learning_epochs: int = 5,
num_mini_batches: int = 4,
clip_param: float = 0.2,
Expand Down
2 changes: 2 additions & 0 deletions rsl_rl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Definitions for neural-network components for RL-agents."""

from .actor_critic import ActorCritic
from .actor_critic_perceptive import ActorCriticPerceptive
from .actor_critic_recurrent import ActorCriticRecurrent
from .rnd import RandomNetworkDistillation, resolve_rnd_config
from .student_teacher import StudentTeacher
Expand All @@ -14,6 +15,7 @@

__all__ = [
"ActorCritic",
"ActorCriticPerceptive",
"ActorCriticRecurrent",
"RandomNetworkDistillation",
"StudentTeacher",
Expand Down
5 changes: 2 additions & 3 deletions rsl_rl/modules/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
num_critic_obs += obs[obs_group].shape[-1]

self.state_dependent_std = state_dependent_std

# Actor
self.state_dependent_std = state_dependent_std
if self.state_dependent_std:
self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
else:
Expand Down Expand Up @@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor:
def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)

def _update_distribution(self, obs: TensorDict) -> None:
def _update_distribution(self, obs: torch.Tensor) -> None:
if self.state_dependent_std:
# Compute mean and standard deviation
mean_and_std = self.actor(obs)
Expand Down
269 changes: 269 additions & 0 deletions rsl_rl/modules/actor_critic_perceptive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal
from typing import Any

from rsl_rl.networks import CNN, MLP, EmpiricalNormalization

from .actor_critic import ActorCritic


class ActorCriticPerceptive(ActorCritic):
    """Actor-critic that supports mixed 1D (proprioceptive) and 2D (image-like) observations.

    Observation groups with shape ``(B, C, H, W)`` are each encoded by their own CNN; the
    flattened encodings are concatenated to the 1D observation vector before the actor and
    critic MLPs. Empirical observation normalization, when enabled, is applied to the 1D
    observations only — the CNN inputs are passed through unnormalized.
    """

    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        num_actions: int,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        actor_cnn_cfg: dict[str, dict] | dict | None = None,
        critic_cnn_cfg: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        """Initialize the perceptive actor-critic.

        Args:
            obs: Example observation TensorDict used to infer per-group input dimensions.
            obs_groups: Maps "policy" and "critic" to the observation group names they consume.
            num_actions: Dimension of the action space.
            actor_obs_normalization: Whether to empirically normalize 1D actor observations.
            critic_obs_normalization: Whether to empirically normalize 1D critic observations.
            actor_hidden_dims: Hidden-layer sizes of the actor MLP.
            critic_hidden_dims: Hidden-layer sizes of the critic MLP.
            actor_cnn_cfg: CNN configuration(s) for 2D actor observations. Either one config
                dict shared by all 2D groups, or a dict of configs keyed per group.
            critic_cnn_cfg: CNN configuration(s) for 2D critic observations (same convention).
            activation: Activation function name for the MLPs.
            init_noise_std: Initial action-noise standard deviation.
            noise_std_type: "scalar" for a raw std parameter or "log" for log-parameterized std.
            state_dependent_std: If True, the actor network outputs the std alongside the mean.

        Raises:
            ValueError: If an observation group is neither 2D (B, C) nor 4D (B, C, H, W), if a
                CNN output is not flattened, or if ``noise_std_type`` is unknown.
        """
        if kwargs:
            print(
                "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs])
            )
        # Deliberately skip ActorCritic.__init__ (it asserts 1D-only observations) and
        # rebuild the full module here; nn.Module must still be initialized.
        nn.Module.__init__(self)

        # Get the observation dimensions. Groups are partitioned by tensor rank:
        # rank-2 (B, C) goes to the MLP input, rank-4 (B, C, H, W) gets a CNN encoder.
        self.obs_groups = obs_groups
        num_actor_obs_1d = 0
        self.actor_obs_groups_1d = []
        actor_in_dims_2d = []
        actor_in_channels_2d = []
        self.actor_obs_groups_2d = []
        for obs_group in obs_groups["policy"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.actor_obs_groups_2d.append(obs_group)
                actor_in_dims_2d.append(obs[obs_group].shape[2:4])
                actor_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.actor_obs_groups_1d.append(obs_group)
                num_actor_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
        num_critic_obs_1d = 0
        self.critic_obs_groups_1d = []
        critic_in_dims_2d = []
        critic_in_channels_2d = []
        self.critic_obs_groups_2d = []
        for obs_group in obs_groups["critic"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.critic_obs_groups_2d.append(obs_group)
                critic_in_dims_2d.append(obs[obs_group].shape[2:4])
                critic_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.critic_obs_groups_1d.append(obs_group)
                num_critic_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")

        # Actor CNN
        if self.actor_obs_groups_2d:
            assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."

            # Check if multiple 2D actor observations are provided
            if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()):
                assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D actor observations."
                )
            elif len(self.actor_obs_groups_2d) > 1:
                # A single shared config: replicate it for every 2D group.
                print(
                    "Only one CNN configuration for multiple 2D actor observations given, using the same configuration "
                    "for all groups."
                )
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d)))
            else:
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg]))

            # Create CNNs for each 2D actor observation
            self.actor_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.actor_obs_groups_2d):
                self.actor_cnns[obs_group] = CNN(
                    input_dim=actor_in_dims_2d[idx],
                    input_channels=actor_in_channels_2d[idx],
                    **actor_cnn_cfg[obs_group],
                )
                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
                # Get the output dimension of the CNN. The MLP expects a flat vector, so the
                # CNN must end in a flatten (signaled by output_channels being None).
                if self.actor_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.actor_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.")
        else:
            self.actor_cnns = None
            encoding_dim = 0

        # Actor MLP
        self.state_dependent_std = state_dependent_std
        if self.state_dependent_std:
            # Output shape [2, num_actions]: row 0 is the mean, row 1 the (pre-activation) std.
            self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
        else:
            self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation)
        print(f"Actor MLP: {self.actor}")

        # Actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # Critic CNN
        if self.critic_obs_groups_2d:
            assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations."

            # check if multiple 2D critic observations are provided
            if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()):
                assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D critic observations."
                )
            elif len(self.critic_obs_groups_2d) > 1:
                # A single shared config: replicate it for every 2D group.
                print(
                    "Only one CNN configuration for multiple 2D critic observations given, using the same configuration"
                    " for all groups."
                )
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d)))
            else:
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg]))

            # Create CNNs for each 2D critic observation
            self.critic_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.critic_obs_groups_2d):
                self.critic_cnns[obs_group] = CNN(
                    input_dim=critic_in_dims_2d[idx],
                    input_channels=critic_in_channels_2d[idx],
                    **critic_cnn_cfg[obs_group],
                )
                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
                # Get the output dimension of the CNN (must be flattened, as for the actor).
                if self.critic_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.critic_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.")
        else:
            self.critic_cnns = None
            encoding_dim = 0

        # Critic MLP
        self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation)
        print(f"Critic MLP: {self.critic}")

        # Critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        # Action noise
        self.noise_std_type = noise_std_type
        if self.state_dependent_std:
            # Initialize the std head (second output row) of the final linear layer so the
            # initial predicted std equals init_noise_std regardless of the input.
            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
            if self.noise_std_type == "scalar":
                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
            elif self.noise_std_type == "log":
                torch.nn.init.constant_(
                    self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
                )
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        else:
            if self.noise_std_type == "scalar":
                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            elif self.noise_std_type == "log":
                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution
        # Note: Populated in update_distribution
        self.distribution = None

        # Disable args validation for speedup
        Normal.set_default_validate_args(False)

    def _encode_actor_obs(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Append the flattened actor CNN encodings to the 1D observations (no-op without CNNs)."""
        if self.actor_cnns is not None:
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            mlp_obs = torch.cat([mlp_obs, torch.cat(cnn_enc_list, dim=-1)], dim=-1)
        return mlp_obs

    def _encode_critic_obs(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Append the flattened critic CNN encodings to the 1D observations (no-op without CNNs)."""
        if self.critic_cnns is not None:
            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
            mlp_obs = torch.cat([mlp_obs, torch.cat(cnn_enc_list, dim=-1)], dim=-1)
        return mlp_obs

    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
        """Encode the 2D observations and delegate to the base-class distribution update."""
        super()._update_distribution(self._encode_actor_obs(mlp_obs, cnn_obs))

    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        """Sample an action from the current policy distribution."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self._update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs: TensorDict) -> torch.Tensor:
        """Return the deterministic (mean) action for inference."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        mlp_obs = self._encode_actor_obs(mlp_obs, cnn_obs)

        if self.state_dependent_std:
            # Bug fix: forward the encoded observations (previously the raw TensorDict was
            # passed to the actor, bypassing normalization and the CNN encodings).
            # Row 0 of the actor output is the action mean.
            return self.actor(mlp_obs)[..., 0, :]
        else:
            return self.actor(mlp_obs)

    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        """Return the critic's value estimate for the given observations."""
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)
        mlp_obs = self._encode_critic_obs(mlp_obs, cnn_obs)
        return self.critic(mlp_obs)

    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Split the observations into a concatenated 1D tensor and a dict of 2D tensors."""
        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.actor_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        # NOTE(review): assumes at least one 1D policy group exists — torch.cat([]) would fail.
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Split the observations into a concatenated 1D tensor and a dict of 2D tensors."""
        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.critic_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        # NOTE(review): assumes at least one 1D critic group exists — torch.cat([]) would fail.
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def update_normalization(self, obs: TensorDict) -> None:
        """Update the empirical normalizers from a batch of observations (1D groups only)."""
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
5 changes: 2 additions & 3 deletions rsl_rl/modules/actor_critic_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
num_critic_obs += obs[obs_group].shape[-1]

self.state_dependent_std = state_dependent_std

# Actor
self.state_dependent_std = state_dependent_std
self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
if self.state_dependent_std:
self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
Expand Down Expand Up @@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
def forward(self) -> NoReturn:
raise NotImplementedError

def _update_distribution(self, obs: TensorDict) -> None:
def _update_distribution(self, obs: torch.Tensor) -> None:
if self.state_dependent_std:
# Compute mean and standard deviation
mean_and_std = self.actor(obs)
Expand Down
2 changes: 2 additions & 0 deletions rsl_rl/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

"""Definitions for components of modules."""

from .cnn import CNN
from .memory import HiddenState, Memory
from .mlp import MLP
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization

__all__ = [
"CNN",
"MLP",
"EmpiricalDiscountedVariationNormalization",
"EmpiricalNormalization",
Expand Down
Loading