
Releases: pytorch/rl

TorchRL 0.10.0: async LLM inference

16 Sep 13:48

TorchRL 0.10.0 Release Notes

What's New in 0.10.0

TorchRL 0.10.0 introduces significant advancements in Large Language Model (LLM) support, new algorithms, enhanced environment integrations, and numerous performance improvements and bug fixes.

Major Features

LLM Support and RLHF

  • vLLM Integration Revamp: Complete overhaul of vLLM support with improved batching and performance (#3158) @vmoens
  • GRPO (Group Relative Policy Optimization): New algorithm implementation with both sync and async variants (#2970, #2997, #3006) @vmoens
  • Expert Iteration and SFT: Implementation of expert iteration algorithms and supervised fine-tuning (#3017) @vmoens
  • PPOTrainer: New high-level trainer class for PPO training (#3117) @vmoens
  • LLM Tooling: Comprehensive tooling support for LLM environments and transformations (#2966) @vmoens
  • Remote LLM Wrappers: Support for remote LLM inference with improved batching (#3116) @vmoens
  • Common LLM Generation Interface: Unified kwargs for generation across vLLM and Transformers (#3107) @vmoens
  • LLM Transforms:
    • AddThinkingPrompt transform for reasoning prompts (#3027) @vmoens
    • MCPToolTransform for tool integration (#2993) @vmoens
    • PythonInterpreter transform for code execution (#2988) @vmoens
    • LLMMaskedCategorical for masked categorical distributions (#3041) @vmoens
  • Content Management: ContentBase system for structured content handling (#2985) @vmoens
  • History Tracking: New history system for conversation management (#2965) @vmoens

New Algorithms and Training

  • Async SAC: Asynchronous implementation of Soft Actor-Critic (#2946) @vmoens
  • Discrete Offline CQL: SOTA implementation for discrete action spaces (#3098) @Ibinarriaga
  • Multi-node Ray Support: Enhanced distributed training for GRPO (#3040) @albertbou92

Environment Support

  • NPU Support: Added NPU device support for SyncDataCollector (#3155) @lowdy1
  • IsaacLab Wrapper: Integration with IsaacLab simulation framework (#2937) @vmoens
  • Complete PettingZoo State Support: Enhanced multi-agent environment support (#2953) @JGuzzi
  • Minari Integration: Support for loading datasets from local Minari cache (#3068) @Ibinarriaga

Storage and Replay Buffers

  • Compressed Storage GPU: GPU acceleration for compressed replay buffers (#3062) @aorenstein68
  • Packing: New data packing functionality for efficient storage (#3060) @vmoens
  • Ray Replay Buffer: Enhanced distributed replay buffer support (#2949) @vmoens
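The idea behind compressed replay-buffer storage can be sketched in a few lines of plain Python. This is an illustrative toy, not TorchRL's CompressedStorage API (which operates on tensors and, in this release, can run on GPU): entries are compressed on write and decompressed lazily on read.

```python
import pickle
import zlib

class CompressedListStorage:
    """Toy storage that compresses entries on write and decompresses on read.

    Illustrative sketch of the concept only; TorchRL's compressed storage
    works on tensors, not pickled Python objects.
    """

    def __init__(self):
        self._data = []

    def add(self, item):
        # Serialize and compress before storing to trade CPU for memory.
        self._data.append(zlib.compress(pickle.dumps(item)))

    def __getitem__(self, index):
        # Decompress lazily, only when the entry is actually accessed.
        return pickle.loads(zlib.decompress(self._data[index]))

    def __len__(self):
        return len(self._data)

storage = CompressedListStorage()
storage.add([0.0] * 1000)  # a highly compressible "transition"
restored = storage[0]
```

Compression pays off most for large, redundant observations (e.g. image stacks), at the cost of extra CPU work per read.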

🔧 Improvements and Enhancements

Performance Optimizations

  • Bounded Specs Memory: Single copy optimization for bounded specifications (#2977) @vmoens
  • Log-prob Computation: Avoid unnecessary log-prob calculations when retrieving distributions (#3081) @vmoens
  • LLM Wrapper Queuing: Performance fixes in LLM wrapper queuing (#3125) @vmoens
  • vmap Deactivation: Selective vmap deactivation in objectives for better performance (#2957) @vmoens

API Improvements

  • Public SAC Methods: Exposed public methods for SAC algorithm (#3085) @vmoens
  • Composite Entropy: Fixed entropy computation for nested keys (#3101) @juandelos
  • Multi-head Entropy: Per-head entropy coefficients for PPO (#2972) @Felixs
  • ClippedPPOLoss: Support for composite value networks (#3031) @louisfaury
  • LineariseRewards: Support for negative weights (#3064) @YoannPoupart
  • GAE Typing: Improved typing with optional value networks (#3029) @louisfaury
  • Explained Variance: Optional explained variance logging (#3010) @OswaldZink
  • Frame Control: Worker-level control over frames_per_batch (#3020) @alexghh

Developer Experience

  • Colored Logger: Enhanced logging with colored output (#2967) @vmoens
  • Better Error Handling: Improved error catching in env.rollout and rb.add (#3102) @vmoens
  • Warning Management: Better warning control for various components (#3099, #3115) @vmoens
  • Faster Tests: Optimized test suite performance (#3162) @vmoens

Bug Fixes

Core Functionality

Environment and Wrapper Fixes

  • TransformedEnv: Fixed in-place modification of specs (#3076) @vmoens
  • Parallel Environments: Fixed partial and nested done states (#2959) @vmoens
  • Gym Actions: Fixed single action passing when action key is not "action" (#2942) @vmoens
  • Brax Memory: Fixed memory leak in Brax environments (#3052) @vmoens
  • Atari Patching: Fixed patching for NonTensorData observations (#3091) @marcosGR

Collector and Replay Buffer Fixes

  • LLMCollector: Fixed trajectory collection when multiple trajectories complete (#3018) @albertbou92
  • Postprocessing: Consistent postprocessing when using replay buffers in collectors (#3144) @vmoens
  • Weight Updates: Fixed original weights retrieval in collectors (#2951) @vmoens
  • Transform Handling: Fixed transform application and metadata preservation (#3047, #3050) @vmoens

Compatibility and Infrastructure

  • PyTorch 2.1.1: Fixed compatibility issues (#3157) @vmoens
  • NPU Attribute: Fixed missing NPU attribute (#3159) @vmoens
  • CUDA Graph: Fixed update_policy_weights_ with CUDA graphs (#3003) @vmoens
  • Stream Capturing: Robust CUDA stream capturing calls (#2950) @vmoens

Documentation and Tutorials

  • DQN with RNN Tutorial: Upgraded tutorial with latest best practices (#3152) @vmoens
  • LLM API Documentation: Comprehensive documentation for LLM environments and transforms (#2991) @vmoens
  • Multi-head Entropy: Better documentation for multi-head entropy usage (#3109) @vmoens
  • LSTM Module: Fixed import examples in documentation (#3138) @arvindcr4
  • A2C Documentation: Updated AcceptedKeys documentation (#2987) @simeet-n
  • History API: Added missing docstrings for History functionality (#3083) @vmoens
  • Multi-agent PPO: Fixed tutorial issues (#2940) @matteobettini
  • WeightUpdater: Updated documentation after renaming (#3007) @albertbou92

Infrastructure and CI

  • Pre-commit Updates: Updated formatting and linting tools (#3108) @vmoens
  • Benchmark CI: Fixed benchmark runs and added missing dependencies (#3092, #3163) @vmoens
  • Windows CI: Fixed Windows continuous integration (#3028) @vmoens
  • Old Dependencies: Fixed CI for older dependency versions (#3165) @vmoens
  • C++ Linting: Fixed C++ code linting issues (#3129) @vmoens
  • Build System: Improved pyproject.toml usage and versioning (#3089, #3166) @vmoens

🏆 Contributors

Special thanks to all contributors who made this release possible:

  • @albertbou92 (Albert Bou) - GRPO multi-node support and LLM improvements
  • @Ibinarriaga - CQL offline algorithm and Minari integration
  • @aorenstein68 (Adrian Orenstein) - Compressed storage GPU support
  • @louisfaury (Louis Faury) - Categorical spec and PPO improvements
  • @LucaCarminati (Luca Carminati) - Binary tensor fixes
  • @JGuzzi (Jérôme Guzzi) - PettingZoo state support
  • @lowdy1 - NPU device support
  • @Felixs (Felix Sittenauer) - Multi-head entropy coefficients
  • @YoannPoupart (Yoann Poupart) - LineariseRewards improvements
  • @OswaldZink (Oswald Zink) - Explained variance logging
  • @alexghh (Alexandre Ghelfi) - Frame control improvements
  • @marcosGR (Marcos Galletero Romero) - Atari patching fixes
  • @matteobettini (Matteo Bettini) - Tutorial fixes
  • @simeet-n (Simeet Nayan) - Documentation improvements
  • @arvindcr4 - Documentation fixes
  • @felixy12 (Felix Yu) - State dict reference fixes
  • @SendhilPanchadsaram (Sendhil Panchadsaram) - Documentation typo fixes
  • @abhishekunique (Abhishek) - WandB logger and value estimation improvements
  • @骑马小猫 - DQN module typo fix
  • @ZainRizvi (Zain Rizvi) - CI improvements and meta-pytorch migration
  • @mikayla-gawarecki (Mikayla Gawarecki) - Usage tracking and ConditionalPolicySwitch

🔗 Compatibility

  • PyTorch: Compatible with PyTorch 2.1.1+ -- recommended >=2.8.0,<2.9.0 for full compatibility
  • TensorDict: Updated to work with TensorDict 0.10+
  • Python: Supports Python 3.9+

📦 Installation

pip install torchrl==0.10.0

For the latest features:

pip install git+https://github.com/pytorch/rl.git@release/0.10.0

v0.9.2: Bug fixes and perf improvements

17 Jul 17:11

TorchRL 0.9.2 Release Notes

This release focuses on bug fixes, performance improvements, and code quality enhancements.

🚀 New Features

  • LineariseRewards: Now supports negative weights for more flexible reward shaping (#3064)

🐛 Bug Fixes

  • Fixed policy reference handling in state dictionaries (#3043)
  • Improved unbatched data handling in LLM wrappers (#3070)
  • Fixed cross-entropy log-probability computation for batched inputs (#3080)
  • Fixed Binary clone() operations (#3077)
  • Fixed in-place spec modifications in TransformedEnv (#3076)

⚡ Performance Improvements

  • Optimized distribution sampling by avoiding unnecessary log-probability computations (#3081)

🔧 Code Quality

  • Standardized coefficient naming in A2C and PPO algorithms (#3079)

📦 Installation

pip install torchrl==0.9.2

Thanks to all contributors: @felixy12, @Xmaster6y, @louisfaury and @LCarmi

v0.9.1: fix for history-based vLLM and Transformers wrappers

11 Jul 15:48

Fixes a critical issue with vLLMWrapper and TransformersWrapper, where a stack of History objects was stacked a second time, resulting in incorrect behavior.

TorchRL 0.9.0 Release Notes

10 Jul 15:28

We are excited to announce the release of TorchRL 0.9.0! This release introduces a comprehensive LLM API for language model fine-tuning, extensive torch.compile compatibility across all algorithms, and numerous performance improvements.

🚀 Major Features

🤖 LLM API - Complete Framework for Language Model Fine-tuning

TorchRL now includes a comprehensive LLM API for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!

Unified LLM Wrappers

  • TransformersWrapper: Seamless integration with Hugging Face models
  • vLLMWrapper: High-performance inference with vLLM engines
  • Consistent API: Both wrappers provide unified input/output interfaces using TensorClass objects
  • Multiple input modes: Support for history, text, and tokenized inputs
  • Configurable outputs: Text, tokens, masks, and log probabilities

Advanced Conversation Management

  • History class: Advanced bidirectional conversation management with automatic chat template detection
  • Multi-model support: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.)
  • Assistant token masking: Identify which tokens were generated by the assistant for RL applications
  • Tool calling support: Handle function calls and tool responses in conversations
  • Batch operations: Efficient tensor operations for processing multiple conversations

🛠️ Tool Integration

  • PythonInterpreter transform: Built-in Python code execution capabilities
  • MCPToolTransform: General tool calling support
  • Extensible architecture: Easy to add custom tool transforms
  • Safe execution: Controlled environment for tool execution
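A tool transform of the PythonInterpreter kind boils down to executing model-emitted code and feeding the captured output back into the conversation. The bare-bones sketch below shows only that core loop; TorchRL's actual transform adds sandboxing and writes results into the environment's tensordict.

```python
import contextlib
import io

def run_python_tool(code: str) -> str:
    """Execute a code snippet and capture its stdout, tool-transform style.

    Minimal sketch of the concept; a real tool transform must sandbox
    execution and handle errors rather than exec() user code directly.
    """
    buffer = io.StringIO()
    namespace = {}  # fresh namespace per call, no state leaks between turns
    with contextlib.redirect_stdout(buffer):
        exec(code, namespace)
    return buffer.getvalue()

result = run_python_tool("print(sum(range(10)))")
```

The returned string would then be appended to the conversation history as a tool message for the next generation step.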

🎯 Specialized Objectives

  • GRPOLoss: Group Relative Policy Optimization loss function optimized for language models
  • SFTLoss: Supervised fine-tuning loss with assistant token masking support
  • MCAdvantage: Monte-Carlo advantage estimation for LLM training
  • KL divergence rewards: Built-in KL penalty computation
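The core computation behind GRPO is simple enough to show inline: several responses are sampled for the same prompt, and each response's advantage is its reward standardized against the group's statistics. The sketch below illustrates that idea only, not GRPOLoss's implementation.

```python
from statistics import mean, stdev

def grpo_advantages(rewards, eps=1e-8):
    """Group-relative advantages for one prompt's sampled responses.

    Each reward is standardized within its group: responses that beat
    the group average get positive advantages, the rest negative.
    Illustrative sketch, not TorchRL's GRPOLoss.
    """
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four sampled answers to one prompt, scored by a reward model:
advs = grpo_advantages([1.0, 0.0, 0.0, 1.0])
```

Because the baseline comes from the group itself, no learned value network is needed, which is what makes the approach attractive for LLM fine-tuning.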

⚡ High-Performance Collectors

  • LLMCollector: Async data collection with distributed training support
  • RayLLMCollector: Multi-node distributed collection using Ray
  • Weight synchronization: Automatic model weight updates across distributed setups
  • Trajectory management: Efficient handling of variable-length conversations

🔄 Flexible Environments

  • ChatEnv: Transform-based architecture for conversation management
  • Transform-based rewards: Modular reward computation and data loading
  • Dataset integration: Built-in support for loading prompts from datasets
  • Thinking prompts: Chain-of-thought reasoning support

📚 Complete Implementation Example

A full GRPO implementation is provided in sota-implementations/grpo/ with:

  • Multi-GPU support with efficient device management
  • Mixed precision training
  • Gradient accumulation
  • Automatic checkpointing
  • Comprehensive logging with Weights & Biases
  • Hydra configuration system
  • Asynchronous training support with Ray

🆕 New Features

LLM API Components

  • LLMMaskedCategorical (#3041) - Categorical distribution with masking for LLM token selection
  • AddThinkingPrompt transform (#3027) - Add chain-of-thought reasoning prompts
  • MCPToolTransform (#2993) - Model Context Protocol tool integration
  • PythonInterpreter transform (#2988) - Python code execution in LLM environments
  • ContentBase (#2985) - Base class for structured content in LLM workflows
  • LLM Tooling (#2966) - Comprehensive tool integration framework
  • History API (#2965) - Advanced conversation management system
  • LLM collector (#2879) - Specialized data collection for language models
  • vLLM wrapper (#2830) - High-performance vLLM integration
  • Transformers policy (#2825) - Hugging Face transformers integration

Environment Enhancements

  • IsaacLab wrapper (#2937) - NVIDIA Isaac Lab environment support
  • Complete PettingZooWrapper state support (#2953) - Full state management for multi-agent environments
  • ConditionalPolicySwitch transform (#2711) - Dynamic policy switching based on conditions
  • Async environments (#2864) - Asynchronous environment execution
  • VecNormV2 (#2867) - Improved vector normalization with batched environment support

Algorithm Improvements

  • Async GRPO (#2997) - Asynchronous Group Relative Policy Optimization
  • Expert Iteration and SFT (#3017) - Expert iteration and supervised fine-tuning algorithms
  • Async SAC (#2946) - Asynchronous Soft Actor-Critic implementation
  • Multi-node Ray support for GRPO (#3040) - Distributed GRPO training

Data Management

  • RayReplayBuffer (#2835) - Distributed replay buffer using Ray
  • RayReplayBuffer usage examples (#2949) - Comprehensive usage examples
  • Policy factory for collectors (#2841) - Flexible policy creation in collectors
  • Local and Remote WeightUpdaters (#2848) - Distributed weight synchronization

Performance Optimizations

  • Deactivate vmap in objectives (#2957) - Improved performance by disabling vectorized operations
  • Hold a single copy of low/high in bounded specs (#2977) - Memory optimization for bounded specifications
  • Use TensorDict._new_unsafe in step (#2905) - Performance improvement in environment steps
  • Memoize calls to encode and related methods (#2907) - Caching for improved performance

Utility Features

  • Compose.pop (#3026) - Remove transforms from composition
  • Add optional Explained Variance logging (#3010) - Enhanced logging capabilities
  • Enabling worker level control on frames_per_batch (#3020) - Granular control over data collection
  • collector.start() (#2935) - Explicit collector lifecycle management
  • Timer transform (#2806) - Timing capabilities for environments
  • MultiAction transform (#2779) - Multi-action environment support
  • Transform for partial steps (#2777) - Partial step execution support

🔧 Performance Improvements

  • VecNormV2: Improved vector normalization with better bias correction timing (#2900, #2901)
  • MaskedCategorical cross_entropy: Faster loss computation (#2882)
  • Avoid padding in transformer wrapper: Memory and performance optimization (#2881)
  • Set padded token log-prob to 0.0: Improved numerical stability (#2857)
  • Better device checks: Enhanced device management (#2909)
  • Local dtype maps: Optimized dtype handling (#2936)
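The principle behind VecNorm-style observation normalization is an online mean/variance estimate, updated per step and used to whiten observations. The sketch below uses Welford's algorithm on scalars to show the idea; it is not VecNormV2's implementation, which handles tensors, batched envs, and bias correction.

```python
class RunningNorm:
    """Online mean/variance tracker (Welford's algorithm).

    Illustrative sketch of the idea behind VecNorm-style normalization,
    not VecNormV2's code.
    """

    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean
        self.eps = eps

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def normalize(self, x: float) -> float:
        var = self.m2 / self.count if self.count > 0 else 1.0
        return (x - self.mean) / (var ** 0.5 + self.eps)

norm = RunningNorm()
for obs in [1.0, 2.0, 3.0]:
    norm.update(obs)
# mean is now 2.0, so normalize(2.0) is ~0.0
```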

🐛 Bug Fixes

LLM API Fixes

  • Variable length vllm wrapper answer stacking (#3049) - Fixed stacking issues with variable-length responses
  • LLMCollector trajectory collection methods (#3018) - Fixed trajectory collection when multiple trajectories complete simultaneously
  • Fix IFEval GRPO runs (#3012) - Resolved issues with IFEval dataset runs
  • Fix cuda cache empty in GRPO scripts (#3016) - Memory management improvements
  • Right log-prob size in transformer wrapper (#2856) - Fixed log probability tensor sizing
  • Fix gc import (#2862) - Import error resolution

Environment Fixes

  • Brax memory leak fix (#3052) - Resolved memory leaks in Brax environments
  • Fix behavior of partial, nested dones in PEnv and TEnv (#2959) - Improved done state handling
  • Fix shifted value computation with an LSTM (#2941) - LSTM value computation fixes
  • Fix single action pass to gym when action key is not "action" (#2942) - Action key handling improvements
  • Fix PEnv device copies (#2840) - Device management in parallel environments

Data Management Fixes

-...


v0.8.1: Async collectors patch

16 May 16:32

Async Collector execution

The major upgrade in this release is a patch to collector.start() that allows collectors (single- or multi-process) to run asynchronously. #2935

A demonstration is provided in the async SAC example. #2946

Single-agent reset

Fixes #2958, where partial resets were not handled correctly when a BatchedEnv was transformed, because the "done" checks were inconsistent. We now enforce that root "_reset" entries always precede their respective leaves.

Fix shifted values in GAE using LSTMs

Using an LSTM within GAE is now properly supported: both shifted=True and shifted=False work as expected, with appropriate warnings or errors when other hyperparameters need to be adjusted. #2941
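For reference, the GAE recursion this fix concerns can be written in a few lines of plain Python; values[t + 1] plays the role of the "shifted" value. This is the standard textbook recursion, not TorchRL's vectorized implementation.

```python
def gae(rewards, values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation over a single trajectory.

    `values` has one more entry than `rewards`: the bootstrap value of
    the final state. Plain-Python sketch of the standard recursion.
    """
    advantages = [0.0] * len(rewards)
    gae_t = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # Exponentially weighted sum of future residuals
        gae_t = delta + gamma * lmbda * not_done * gae_t
        advantages[t] = gae_t
    return advantages

advs = gae(rewards=[1.0, 1.0], values=[0.5, 0.5, 0.5], dones=[False, True])
```

With an LSTM critic, the subtlety is that V(s_{t+1}) must come from the correctly shifted hidden state, which is exactly what the shifted argument controls.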

Full Changelog: v0.8.0...v0.8.1

v0.8.0: Async envs and better weight update API

30 Apr 14:58

TorchRL v0.8.0: Async envs and better weight update API

  • Async environments: #2864 introduces asynchronous environments, which can be built using different backends (currently
    "threading" or "multiprocessing"). Instantiating an async env is roughly the same as a parallel one:
    from torchrl.envs import AsyncEnvPool
    env = AsyncEnvPool([partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "Pendulum-v1")], backend="threading")
    These environments support the regular environment methods (reset, step or rollout) but their main advantage lies
    in their new async methods:
    s0 = env.rand_action(env.reset())
    env.async_step_send(s0)
    # receive
    result = env.async_step_recv()
    In this example, result will contain the results of the call to step for one or two environments. The environment indices
    can be found in the result['env_index'] entry (the name of that key is stored in env._env_idx_key).
  • Support for environments with tensorclass attributes (#2788)
  • Distributed RayReplayBuffer (#2835)
  • Gymnasium 1.1 compatibility (#2898): we managed to make TorchRL compatible with Gymnasium 1.1 as this version lets
    users choose how to handle partial resets, which facilitates integration in the library.
  • VecNormV2, a new version of vecnorm which is more numerically stable and easier to handle. This can be created directly
    through the usual VecNorm by passing the new_api keyword argument.
  • policy factory for collectors: you can now pass a factory for your policy instead of passing the real object.
    Given that the collector will update the weights of the policy when asked to, this will in most cases not cause any
    synchronization problem with the copy that is used by the training pipeline.
  • An Update API for policy weights in collectors: we have isolated the weight update API in a torchrl.collectors.WeightUpdaterBase
    abstract class. This should be the entry point for any user wanting to implement their own weight update strategy, alleviating
    the need to subclass or patch the collector or the policy directly.
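The async_step_send / async_step_recv pattern described above can be mimicked with a thread pool. The sketch below is a stand-in using a trivial counter "environment", not TorchRL code; AsyncEnvPool's real API additionally supports partial receives and multiple backends.

```python
from concurrent.futures import ThreadPoolExecutor

class ToyEnv:
    """Stand-in environment whose step just adds the action to a counter."""
    def __init__(self):
        self.state = 0

    def step(self, action):
        self.state += action
        return {"observation": self.state}

class ToyAsyncPool:
    """Mimics the async send/recv pattern with a thread per environment."""
    def __init__(self, envs):
        self.envs = envs
        self.pool = ThreadPoolExecutor(max_workers=len(envs))
        self.futures = []

    def async_step_send(self, actions):
        # Dispatch one step per env without blocking the caller.
        self.futures = [
            self.pool.submit(env.step, action)
            for env, action in zip(self.envs, actions)
        ]

    def async_step_recv(self):
        # Gather results, tagging each with its env index, mirroring
        # the 'env_index' entry AsyncEnvPool puts in its output.
        return [
            {"env_index": i, **future.result()}
            for i, future in enumerate(self.futures)
        ]

pool = ToyAsyncPool([ToyEnv(), ToyEnv()])
pool.async_step_send([1, 2])
results = pool.async_step_recv()
```

The payoff of the split send/recv API is that the caller can do useful work (e.g. run policy inference for other envs) between dispatching steps and collecting their results.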

Packaging

We relaxed TorchRL's PyTorch dependency to make it compatible with any PyTorch version. The current status is:

  • tensordict dependency will from now on be enforced (>=0.8.1,<0.9.0 for this release)
  • For PyTorch prior to 2.7.0, backward compatibility is guaranteed to some extent (most classes should work, unless new features are used), but C++ binaries (for prioritized replay buffers) will not work.
  • For PyTorch >= 2.7.0, C++ binaries should work across versions. In other words, torchrl binaries for 0.8.0 will work with PyTorch 2.7.0, 2.8.0 etc., and the same goes for the future TorchRL 0.9.0... A big thanks to @janeyx99 for enabling this!

New features

[Feature] Add EnvBase.all_actions (#2780) (67c3e9a) by @kurtamohler ghstack-source-id: 7abf9d469f740be5f14daffa2330811f7572dad9
[Feature] Add MCTSForest/Tree.to_string (#2794) (f862669) by @kurtamohler ghstack-source-id: 2127bf24d66e44fb310d12ff5f72e92aa0371cd7
[Feature] Add include_hash_inv arg to ChessEnv (#2766) (3be85c6) by @kurtamohler ghstack-source-id: f6920d781835902a6db02f74c5e5a3041243c5e3
[Feature] Add option for auto-resetting envs in GAE (#2851) (f5f3ae4) by @lin-erica Co-authored-by: Erica Lin elin@theaiinstitute.com
[Feature] Async environments (#2864) (4f00025) by @vmoens ghstack-source-id: 0a70ce0129d2ee6f85bb22adda3c332ff65e7501
[Feature] Capture wrong spec transforms (1/N) (#2805) (d3dca73) by @vmoens ghstack-source-id: f2d938b3dfe88af66622099f60cd7e3026289a02
[Feature] Collectors for async envs (#2893) (4ba5066) by @vmoens ghstack-source-id: 764c21d0f2c3b217440e1a6f12ee797b17820c1d
[Feature] DensifyReward postproc (#2823) (53065cf) by @vmoens ghstack-source-id: ef6a0f52601642c8944f63f9e3ac9e963425734e
[Feature] Dynamic specs for make_composite_from_td (#2829) (413571b) by @vmoens ghstack-source-id: 79e31e737c9f67ff20ce9fe32081e5b0a83de947
[Feature] Enable Hash.inv (#2757) (32c4623) by @kurtamohler ghstack-source-id: 956708121067855e519382a37764f06f53b16aa7
[Feature] Env with tensorclass attributes (#2788) (ab76027) by @vmoens ghstack-source-id: dc00ea3d23e015756974cd5c2ce638b55e5f6f92
[Feature] Gymnasium 1.1 compatibility (#2898) (78cd755) by @vmoens ghstack-source-id: e0891867f4318380f01c15449f9f26070b78536d
[Feature] History API (#2890) (fd10fe2) by @vmoens ghstack-source-id: 5b9723f6e1c327625e1a9be6f6eac68b91ed8492
[Feature] History.default_spec (#2894) (8ce11a8) by @vmoens ghstack-source-id: 40b8a492765a85adaccb591f1bc173754bacc313
[Feature] Local and Remote WeightUpdaters (#2848) (27d3680) by @vmoens ghstack-source-id: 2962530f87b596d038e3a13a934ea09064af2964
[Feature] Make PPO ready for text-based data (#2857) (595ddb4) by @vmoens ghstack-source-id: eeda5e2355e573e74cf7c080994cd47520ecd45b
[Feature] MultiAction transform (#2779) (621776a) by @vmoens ghstack-source-id: 0a6f7f916ee6f9c6d450c511385bdfdb1d911da0
[Feature] NonTensor batched arg (#2816) (b97bdb5) by @vmoens ghstack-source-id: c6de1bd1f1475b8d02df2ff3eb7438a50f2ae450
[Feature] Pass lists of policy_factory (#2888) (82f8ec2) by @vmoens ghstack-source-id: e42b100096c6e38365f8a80681473746f51d8a77
[Feature] RayReplayBuffer (#2835) (50af984) by @vmoens ghstack-source-id: 32eff06494037a1a30e532539794035c035f1e81
[Feature] Set padded token log-prob to 0.0 (#2856) (b9ddfa9) by @vmoens ghstack-source-id: 2b2993e0b15afae17326e6583390d57068712d4f
[Feature] Support lazy tensordict inputs in ppo loss (#2883) (c9caf3d) by @vmoens ghstack-source-id: 89098ba3ca61b1524aeddc68f54c377f29c8dc8b
[Feature] TensorDictPrimer with single default_value callable (#2732) (59e8545) by @vmoens ghstack-source-id: a9a677f24fc1e6a47312d0a96ab60daae543ff78
[Feature] Timer transform (#2806) (104b880) by @vmoens ghstack-source-id: e42f2aece15f90afc457e1fb3e41a1f7be1a6a85
[Feature] Transform for partial steps (#2777) (7c034e3) by @vmoens ghstack-source-id: 587f91e33dfe1d59b73c4b2f2f1c21760ee79d2e
[Feature] VecNormV2 (#2867) (40fcdb6) by @vmoens ghstack-source-id: 639d07ff54be200d54621c2c4619ebd0d3d7d79e
[Feature] VecNormV2: Usage with batched envs (#2901) (b08e7ac) by @vmoens ghstack-source-id: 5e14ed982b71b0e5192b0687c5259a3b49a81157
[Feature] pass policy-factory in mp data collectors (#2859) (31af2c5) by @vmoens ghstack-source-id: bce8abe9853d5ec187f91ffbcd8b940fa18ec8ab
[Feature] policy factory for collectors (#2841) (49a8a42) by @vmoens ghstack-source-id: 96b928e938b8b07fc7de23483358202737571f8e
[Feature] reset_time in Timer (#2807) (5a46379) by @vmoens ghstack-source-id: 36a74fd20b78e1cdde6bca19b4f95c3d9062d761
[Feature] transformers policy (#2825) (eea932c) by @vmoens ghstack-source-id: 870c221b4ebae132a44944f0be0ee78da540d115

Fixes

[BugFix] Apply inverse transform to input of TransformedEnv._reset (#2787) (1ed5d29) by @kurtamohler ghstack-source-id: 5f7c1fbd19b716f2b1602c34cf2ae1362f7bc7f6
[BugFix] Avoid calling reset during env init (#2770) (09e93c1) by @vmoens ghstack-source-id: 5ab8281c34aacfd7dbbfc0e285d88bcae0aededf
[BugFix] Ensure that Composite.set returns self as TensorDict does (#2784) (e084c02) by @vmoens ghstack-source-id: 23fe46b61dc2c9548fd9de7e4100431918fd0370
[BugFix] Fix .item() warning on tensors that require grad (#2885) (b66fcd4) by @vmoens ghstack-source-id: 502bdda3f5700dc900cf5c748839c965b1d67c1b
[BugFix] Fix KL penalty (#2908) (96c3003) by @vmoens ghstack-source-id: 475dccb0bcddbfe3bd2d826c5389834fb95e1ab8
[BugFix] Fix MultiAction reset (#2789) (76aa9bc) by @kurtamohler ghstack-source-id: a2f7bfdd7522a214430182dac65687a977b1a10d
[BugFix] Fix PEnv device copies (#2840) (6e40548) by @vmoens ghstack-source-id: df39fd2e4cd72f24c645b0ac32b46ab3e8d847fc
[BugFix] Fix batch_locked check in check_env_specs + error message callable (#2817) (9c98b82) by @vmoens ghstack-source-id: c722b164133c27c05dd21add3e7f3158189dd515
[BugFix] Fix calls to _reset_env_preprocess (#2798) (ea76ffb) by @vmoens ghstack-source-id: 59925635a87b196a5bcb0fb251afe4cc7b8b103e
[BugFix] Fix collector timeouts (#2774) (f6084b6) by @vmoens ghstack-source-id: cb71d95143beb22db1fe1752e72f70c19f43be79
[BugFix] Fix collector with no buffers and devices (#2809) (d4f8846) by @vmoens ghstack-source-id: 5367df9fcfdf549108be852476b049a0b978e348
[BugFix] Fix compile compatibility of PPO losses (#2889) (9bc85f4) by @vmoens ghstack-source-id: b346033641e5d27560fbfa011a006446e56a4e31
[BugFix] Fix composite setitem (#2778) (c2a149d) by @vmoens ghstack-source-id: f33b49beb4cf8c0c8b156559b1abbee8ac77db20
[BugFix] Fix env.full_done_specs (#2815) (f5c0666) by @vmoens ghstack-source-id: ba0d371d10b3f46ec1172fbec639ccc4d5559659
[BugFix] Fix forced batch-size in _skip_tensordict (#2808) (3acf491) by @vmoens ghstack-source-id: dac84e8b8835e870bce1772d7893c30b6f9af59c
[BugFix] Fix gc import (#2862) (a183f02) by @vmoens ghstack-source-id: b732d4f805d98ceaaa45326d619fce623c10482f
[BugFix] Fix lazy-stack in RBs (#2880) (e80732e) by @vmoens ghstack-source-id: 38399ee991bc065445f4eb1c84b71e7d844d794c
[BugFix] Fix property getter in RayReplayBuffer (#2869) (04d70c1) by @vmoens
[BugFix] Fix slow and flaky non-tensor parallel env test (#2926) by @vmoens ghstack-source-id: fcb5caa56e05176958b3468a7d6f69e363cfe558
[BugFix] Fix update shape mismatch in _skip_tensordict (#2792) (3e42e7a) by @vmoens ghstack-source-id: 27e7d444c126e48fdb70d951a0cc7beaee1db3a8
[BugFix] Fixed VideoRecorder crash when passing fps (#2827) (5ec9bc5) by Alexandre Brown
[BugFix] GAE warning when gamma/lmbda are tensors (#2838) (d561115) by @louisfaury Co-authored-by: Louis Faury louis.faury@helsing.ai
[BugFix] Keep original class in LazyStackStorage through lazy_stack (#2873) (70f5c06) by @vmoens ghstack-source-id: 661cd65c86648ffb2ee6ead40110ac3d57477514
[BugFix] Non...


0.7.2: ParallelEnv fix

10 Mar 16:04

We are releasing TorchRL 0.7.2, a minor update that addresses several important bug fixes to improve the stability and reliability of our library.

This release is particularly crucial as it resolves a critical issue (#2840) where, under certain conditions, the device setting of the parallel environment would prevent the tensors in the buffers from being properly cloned. This resulted in rollouts returning the same tensor instances across steps, potentially leading to incorrect behavior and results.

Due to the severity of this bug, we strongly recommend that all users upgrade to TorchRL 0.7.2 to ensure the accuracy and reliability of their experiments.

The full list of changes can be found below:

Full Changelog: v0.7.1...v0.7.2

0.7.1: Bug fixes and documentation improvements

18 Feb 11:16

We are pleased to announce the release of torchrl v0.7.1, which includes several bug fixes, documentation updates, and backend improvements.

Bug Fixes

  • Fixed collector timeouts (#2774)
  • Fixed composite setitem (#2778)
  • Ensured that Composite.set returns self as TensorDict does (#2784)
  • Fixed PPOs with composite distribution (#2791)
  • Used brackets to get non-tensor data in gym envs (#2769)
  • Avoided calling reset during env init (#2770)
  • NonTensor should not convert anything to numpy (#2771)

Documentation Updates:

  • Fixed tutorials (#2768)
  • Solved ref issues in docstrings (#2776)
  • Fixed formatting errors (#2786)

Backend Improvements:

  • Made better logits in cost tests (#2775)
  • Ensured abstractmethods are implemented for specs (#2790)
  • Removed deprec specs from tests (#2767)

Thank you to @antoinebrl, and @louisfaury for contributing to this release!

Full Changelog: v0.7.0...v0.7.1

0.7.0: Compile compatibility, Chess and better multi-head policies

05 Feb 21:54

As always, we want to warmly thank the RL community who's supporting this project. A special thanks to our first-time
contributors, as well as all the users who wrote issues, suggestions, started discussions here, on Discord,
on the PyTorch forum or elsewhere! We value your feedback!

BC-Breaking changes and Deprecated behaviors

Removed classes

As announced, we removed the following classes:

  • AdditiveGaussianWrapper
  • InPlaceSampler
  • NormalParamWrapper
  • OrnsteinUhlenbeckProcessWrapper

Default MLP config

The default MLP depth has changed from 3 to 0 (i.e., MLP(in_features=3, out_features=4) is now equivalent to a regular
nn.Linear layer).

Locking envs

Environment specs are now carefully locked by default (#2729, #2730). This means that

env.observation_spec = spec

is allowed (specs will be unlocked/re-locked automatically) but

env.observation_spec["value"] = spec

will not work. The core idea is that we want to cache as much information as we can, such as action keys or whether
the env has dynamic specs. We can only do that if we can guarantee that the env has not been modified; locking the specs
provides that guarantee.
Note that a version of this mechanism already existed, but it was not as robust as the new one.

Changes to composite distributions

TL;DR: We're changing the way log-probs and entropies are collected and written in ProbabilisticTensorDictModule and
in CompositeDistribution. The "sample_log_prob" default key will soon be "<value>_log_prob" (or
("path", "to", "<value>_log_prob") for nested keys). For CompositeDistribution, a different log-prob will be
written for each leaf tensor in the distribution. This new behavior is controlled by the
tensordict.nn.set_composite_lp_aggregate(mode: bool) function or by the COMPOSITE_LP_AGGREGATE environment variable.
We strongly encourage users to adopt the new behavior by calling tensordict.nn.set_composite_lp_aggregate(False).set()
at the beginning of their training script.

The behavior of CompositeDistribution and its interaction with on-policy losses such as PPO has changed.
The PPO documentation now includes a section about multi-head policies and the examples also give such information.

See the tensordict v0.7.0 release notes or #2707 to know more.

[Deprecation] Change the default MLP depth (#2746) (12e6bce) by @vmoens ghstack-source-id: bd34b8e9112c4fc3a30bd095e3ac073a7d0b5469
[Deprecation] Gracing old *Spec with v0.8 versioning (#2751) (fa697fe) by @vmoens ghstack-source-id: e7c6e0a4b8520da887fe7e602a351c3c72a08c4c
[Deprecation] Remove AdditiveGaussianWrapper (#2748) (6c7f4fb) by @vmoens ghstack-source-id: 78f248e1239a04fc5213aa4418a158f741679593
[Deprecation] Remove InPlaceSampler (#2750) (0feef11) by @vmoens ghstack-source-id: eeae1bf0611a5d293f533767eee7b9700e720cc8
[Deprecation] Remove NormalParamWrapper (#2747) (a38604e) by @vmoens ghstack-source-id: 4a70178f54f9e25d602c86a0b61248d66f3e39bd
[Deprecation] Remove OrnsteinUhlenbeckProcessWrapper (#2749) (0111a87) by @vmoens ghstack-source-id: 401fdfaca2e27122d5a67fc7177e1015047f0098

New features

Compile compatibility

We placed a strong focus on better compatibility with torch.compile across the SOTA training scripts, which now
all accept a compile=1 argument. Overall speedups range from 1x to 4x.


Loss module speedups are displayed in the README.md.

Replay buffers are also mostly compatible with compile now (with the notable exception of distributed and memory-mapped ones).

Specs: auto_specs_, <attr>_spec_unbatched

You can now use env.auto_specs_ to set the specs automatically based on a dummy rollout.

For batched environments, the unbatched specs can now be accessed via env.<attr>_spec_unbatched. This is useful, for
example, to create random policies.

New transforms

We added TrajCounter (#2532), Hash and Tokenizer (#2648, #2700) and LineariseReward (#2681).

LazyStackStorage

We provide a new ListStorage-based storage, LazyStackStorage, that automatically represents samples as a LazyStackedTensorDict,
which makes it easy to store ragged tensors (although not contiguously in memory) (#2723).

ChessEnv

A new torchrl.envs.ChessEnv allows users to train agents to play chess!

Tutorials on exporting torchrl modules

We also open-sourced a tutorial on exporting TorchRL modules to hardware: #2557

Full list of features

[Feature, Test] Adding tests for envs that have no specs (#2621) (c72583f) by @vmoens ghstack-source-id: 4c75691baa1e70f417e518df15c4208cff189950
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san, action mask (#2702) (d425777) by @vmoens ghstack-source-id: f294a2bc99a17911c9b62558d530b148d3c0350f
[Feature] A2C compatibility with compile (#2464) (507766a) by @vmoens ghstack-source-id: 66a7f0d1dd82d6463d61c1671e8e0a14ac9a55e7
[Feature] ActionDiscretizer custom sampling (#2609) (3da76f0) @oslumbers Co-authored-by: Oliver Slumbers oliver.slumbers@helsing.ai
[Feature] Add Hash transform (#2648) (50011dc) @kurtamohler ghstack-source-id: dccf63fe4f9d5f76947ddb7d5dedcff87ff8cdc5
[Feature] Add Choice spec (#2713) (9368ca6) @kurtamohler ghstack-source-id: afa315a311845ab39ade3e75046f32757f9d94f1
[Feature] Add LossModule.reset_parameters_recursive (#2546) (218d5bf) by @kurtamohler
[Feature] Add Stack transform (#2567) (594462d) by @kurtamohler
[Feature] Add deterministic_sample to masked categorical (#2708) (49d9897) by @vmoens ghstack-source-id: d34fcf9b44d7a7c60dbde80b0835189f990ef226
[Feature] Adds ordinal distributions (#2520) (c851e16) by @louisfaury Co-authored-by: @louisfaury
[Feature] Avoid some recompiles of ReplayBuffer.extend/sample (#2504) (0f29c7e) @kurtamohler
[Feature] CQL compatibility with compile (#2553) (e2be42e) by @vmoens ghstack-source-id: d362d6c17faa0eb609009bce004bb4766e345d5e
[Feature] CROSSQ compatibility with compile (#2554) (01a421e) by @vmoens ghstack-source-id: 98a2b30e8f6a1b0bc583a9f3c51adc2634eb8028
[Feature] CatFrames.make_rb_transform_and_sampler (#2643) (9ee1ae7) by @vmoens ghstack-source-id: 7ecf952ec9f102a831aefdba533027ff8c4c29cc
[Feature] ChessEnv (#2641) (17983d4) by @vmoens ghstack-source-id: 087c3b12cd621ea11a252b34c4896133697bce1a
[Feature] Composite.batch_size (#2597) (2e82cab) by @vmoens ghstack-source-id: 621884a559a71e80a4be36c7ba984fd08be47952
[Feature] Composite.pop (#2598) (8d16c12) by @vmoens ghstack-source-id: 64d5bd736657ef56e37d57726dfcfd25b16b699f
[Feature] Composite.separates (#2599) (83e0b05) by @vmoens ghstack-source-id: fbfc4308a81cd96ecc61723df8c0eb870c442def
[Feature] Custom conversion tool for gym specs (#2726) (dbc8e2e) by @vmoens ghstack-source-id: d38bb02f15267a9b1637b3ed25fb44ef013e2456
[Feature] DDPG compatibility with compile (#2555) (7d7cd95) by @vmoens ghstack-source-id: f18928a419f81794d6870fd4e9fe1205c1b137e1
[Feature] DQN compatibility with compile (#2571) (f149811) by @vmoens ghstack-source-id: 113dc8c4a5562d217ed867ace1942b2f6b8a39f9
[Feature] DT compatibility with compile (#2556) (fbfe104) by @vmoens ghstack-source-id: 362b6e88bad4397f35036391729e58f4f7e4a25d
[Feature] Discrete SAC compatibility with compile (#2569) (9e2d214) by @vmoens ghstack-source-id: ddc131acedbbe451b28758e757a8c240ebd72b80
[Feature] Ensure out-place policy compatibility in rollout and collectors (#2717) (ec370c6) by @vmoens ghstack-source-id: 41a6aa56e0a045a20224b96f9537a7ae3ae14494
[Feature] EnvBase.auto_specs_ (#2601) (d537dcb) by @vmoens ghstack-source-id: 329679238c5172d7ff13097ceaa189479d4f4145
[Feature] EnvBase.check_env_specs (#2600) (00d3199) by @vmoens ghstack-source-id: 332dbf92db496c71c5ce6aba340ad123eac0f5d6
[Feature] GAIL compatibility with compile (#2573) (6482766) by @vmoens ghstack-source-id: 98c7602ec0343d7a83cb19bddeb579484c42e77e
[Feature] IQL compatibility with compile (#2649) (2cfc2ab) by @vmoens ghstack-source-id: 77bca166701d28dd69ef3964f55ab4f3e4b17fed
[Feature] LLMHashingEnv (#2635) (30d21e5) by @vmoens ghstack-source-id: d1a20ecd023008683cf18cf9e694340cfdbdac8a
[Feature] LazyStackStorage (#2723) (fe3f00c) by @vmoens ghstack-source-id: e9c031470aa0bdafbb2b26c73c06b25685a128e5
[Feature] Linearise reward transform (#2681) (ff1ff7e) by @louisfaury Co-authored-by: @louisfaury
[Feature] Log each entropy for composite distributions in PPO (#2707) (319bb68) by @louisfaury Co-authored-by: @louisfaury
[Feature] Log pbar rate in SOTA implementations (#2662) (1ce25f1) by @vmoens ghstack-source-id: 283cc1bb4ad2d60281296d2cfb78ec41c77f4129
[Feature] MCTSForest (#2307) (e9d1677) by @vmoens ghstack-source-id: 9ac5cd3de39a4dbe1c7c33cb71ff6f45a886ae65
[Feature] Make PPO compatible with composite actions and log-probs (#2665) (256a700) by @vmoens ghstack-source-id: c41718e697f9b6edda17d4ddb5bd6d41402b7c30
[Feature] PPO compatibility with compile (#2652) (f5a187d) by @vmoens ghstack...

Read more

v0.6.0: compiled losses and partial steps

22 Oct 21:42
Compare
Choose a tag to compare

What's Changed

We introduce wrappers for ML-Agents and OpenSpiel. See the documentation for the OpenSpiel and ML-Agents wrappers.

We introduce support for partial steps (#2377, #2381), allowing you to run rollouts that end only when all envs are done, without resetting those that have already reached a termination point.

We add the capability of passing replay buffers directly to data collectors, avoiding synchronized inter-process communication and thereby drastically speeding up data collection. See the collector documentation for more info.

The GAIL algorithm has also been integrated in the library (#2273).

We ensure that all loss modules are compatible with torch.compile without graph breaks (for a typical build). Execution of compiled losses is usually around 2x faster than their eager counterparts.

Finally, we have sadly decided not to support Gymnasium v1.0 and future releases as the new autoreset API is fundamentally incompatible with TorchRL. Furthermore, it does not guarantee the same level of reproducibility as previous releases. See this discussion for more information.

We provide wheels for aarch64 machines; since we are unable to upload them to PyPI, they are attached to these release notes.

Deprecations

  • [Deprecation] Deprecate default num_cells in MLP (#2395) by @vmoens
  • [Deprecations] Deprecate in view of v0.6 release #2446 by @vmoens

New environments

New features

  • [Feature] Add group_map support to MLAgents wrappers (#2491) by @kurtamohler
  • [Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler (#2452) Co-authored-by: Vincent Moens by @LTluttmann
  • [Feature] Check number of kwargs matches num_workers (#2465) Co-authored-by: Vincent Moens by @antoine.broyelle
  • [Feature] Compiled and cudagraph for policies #2478 by @vmoens
  • [Feature] Consistent Dropout (#2399) Co-authored-by: Vincent Moens by @depictiger
  • [Feature] Deterministic sample for Masked one-hot #2440 by @vmoens
  • [Feature] Dict specs in vmas (#2415) Co-authored-by: Vincent Moens by @matteobettini
  • [Feature] Ensure transformation keys have the same number of elements (#2466) by @f.broyelle
  • [Feature] Make benchmarked losses compatible with torch.compile #2405 by @vmoens
  • [Feature] Partial steps in batched envs #2377 by @vmoens
  • [Feature] Pass replay buffers to MultiaSyncDataCollector #2387 by @vmoens
  • [Feature] Pass replay buffers to SyncDataCollector #2384 by @vmoens
  • [Feature] Prevent loading existing mmap files in storages if they already exist #2438 by @vmoens
  • [Feature] RNG for RBs (#2379) by @vmoens
  • [Feature] Randint on device for buffers #2470 by @vmoens
  • [Feature] SAC compatibility with composite distributions. (#2447) by @albertbou92
  • [Feature] Store MARL parameters in module (#2351) by @vmoens
  • [Feature] Support wrapping IsaacLab environments with GymEnv (#2380) by @yu-fz
  • [Feature] TensorDictMap #2306 by @vmoens
  • [Feature] TensorDictMap Query module #2305 by @vmoens
  • [Feature] TensorDictMap hashing functions #2304 by @vmoens
  • [Feature] break_when_all_done in rollout #2381 by @vmoens
  • [Feature] inline hold_out_net #2499 by @vmoens
  • [Feature] replay_buffer_chunk #2388 by @vmoens

New Algorithms

  • [Algorithm] GAIL (#2273) Co-authored-by: Vincent Moens by @Sebastian.dittert

Fixes

  • [BugFix, CI] Set TD_GET_DEFAULTS_TO_NONE=1 in all CIs (#2363) by @vmoens
  • [BugFix] Add MultiCategorical support in PettingZoo action masks (#2485) Co-authored-by: Vincent Moens by @matteobettini
  • [BugFix] Allow for composite action distributions in PPO/A2C losses (#2391) by @albertbou92
  • [BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss (#2496) by @kurtamohler
  • [BugFix] Avoid reshape(-1) for inputs to objectives modules (#2494) Co-authored-by: Vincent Moens by @kurtamohler
  • [BugFix] Better dumps/loads (#2343) by @vmoens
  • [BugFix] Extend RB with lazy stack #2453 by @vmoens
  • [BugFix] Extend RB with lazy stack (revamp) #2454 by @vmoens
  • [BugFix] Fix Compose input spec transform (#2463) Co-authored-by: Louis Faury @louisfaury
  • [BugFix] Fix DeviceCastTransform #2471 by @vmoens
  • [BugFix] Fix LSTM in GAE with vmap (#2376) by @vmoens
  • [BugFix] Fix MARL-DDPG tutorial and other MODE usages (#2373) by @vmoens
  • [BugFix] Fix displaying of tensor sizes in buffers #2456 by @vmoens
  • [BugFix] Fix dumps for SamplerWithoutReplacement (#2506) by @vmoens
  • [BugFix] Fix get-related errors (#2361) by @vmoens
  • [BugFix] Fix invalid CUDA ID error when loading Bounded variables across devices (#2421) by @cbhua
  • [BugFix] Fix listing of updated keys in collectors (#2460) by @vmoens
  • [BugFix] Fix old deps tests #2500 by @vmoens
  • [BugFix] Fix support for MiniGrid envs (#2416) by @kurtamohler
  • [BugFix] Fix tictactoeenv.py #2417 by @vmoens
  • [BugFix] Fixes to RenameTransform (#2442) Co-authored-by: Vincent Moens by @thomasbbrunner
  • [BugFix] Make sure keys are exclusive in envs (#1912) by @vmoens
  • [BugFix] TensorDictPrimer updates spec instead of overwriting (#2332) Co-authored-by: Vincent Moens by @matteobettini
  • [BugFix] Use a RL-specific NO_DEFAULT instead of TD's one (#2367) by @vmoens
  • [BugFix] compatibility to new Composite dist log_prob/entropy APIs #2435 by @vmoens
  • [BugFix] torch 2.0 compatibility fix #2475 by @vmoens

Performance

  • [Performance] Faster CatFrames.unfolding with padding="same" (#2407) by @kurtamohler
  • [Performance] Faster PrioritizedSliceSampler._padded_indices (#2433) by @kurtamohler
  • [Performance] Faster SliceSampler._tensor_slices_from_startend (#2423) by @kurtamohler
  • [Performance] Faster target update using foreach (#2046) by @vmoens

Documentation

  • [Doc] Better doc for inverse transform semantic #2459 by @vmoens
  • [Doc] Correct minor erratum in knowledge_base entry (#2383) by @depictiger
  • [Doc] Document losses in README.md #2408 by @vmoens
  • [Doc] Fix README example (#2398) by @vmoens
  • [Doc] Fix links to tutos (#2409) by @vmoens
  • [Doc] Fix pip3install typos in Readme (#2342) by @TheRisenPhoenix
  • [Doc] Fix policy in getting started (#2429) by @vmoens
  • [Doc] Fix tutorials for release #2476 by @vmoens
  • [Doc] Fix wrong default value for flatten_tensordicts in ReplayBufferTrainer (#2502) by @vmoens
  • [Doc] Minor fixes to comments and docstrings (#2443) by @thomasbbrunner
  • [Doc] Refactor README (#2352) by @vmoens
  • [Docs] Use more appropriate ActorValueOperator in PPOLoss documentation (#2350) by @GaetanLepage
  • [Documentation] README rewrite and broken links (#2023) by @vmoens

Not user facing

New Contributors

As always, we want to show how appreciative we are of the vibrant open-source community that keeps TorchRL alive.

Full Changelog: v0.5.0...v0.6.0