
Commit 2c6e7df

[Doc] Huge doc refactoring
ghstack-source-id: 5747424 Pull-Request: #3231
1 parent bc6d282 commit 2c6e7df


52 files changed (+3220 / -5264 lines)

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
@@ -50,3 +50,10 @@ repos:
         entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
         language: system
         types: [python]
+      - id: check-sphinx-section-underline
+        name: Check Sphinx section underline lengths
+        entry: ./scripts/check-sphinx-section-underline --fix
+        language: script
+        files: ^docs/.*\.rst$
+        pass_filenames: true
+        description: Ensure Sphinx section underline lengths match section titles.

docs/source/reference/collectors.rst

Lines changed: 45 additions & 570 deletions
Large diffs are not rendered by default.
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@

.. currentmodule:: torchrl.collectors

.. _ref_collectors:

Collector Basics
================

Data collectors are somewhat equivalent to PyTorch dataloaders, except that (1) they
collect data over non-static data sources and (2) the data is collected using a model
(likely a version of the model that is being trained).

TorchRL's data collectors accept two main arguments: an environment (or a list of
environment constructors) and a policy. They will iteratively execute an environment
step and a policy query over a defined number of steps before delivering a stack of
the data collected to the user. Environments will be reset whenever they reach a done
state, and/or after a predefined number of steps.

Because data collection is a potentially compute-heavy process, it is crucial to
configure the execution hyperparameters appropriately.
The first parameter to take into consideration is whether the data collection should
occur serially with the optimization step or in parallel. The :class:`SyncDataCollector`
class will execute the data collection on the training worker. The :class:`MultiSyncDataCollector`
will split the workload across a number of workers and aggregate the results that
will be delivered to the training worker. Finally, the :class:`MultiaSyncDataCollector` will
execute the data collection on several workers and deliver the first batch of results
that it can gather. This execution will occur continuously and concomitantly with
the training of the networks: this implies that the weights of the policy that
is used for the data collection may slightly lag the configuration of the policy
on the training worker. Therefore, although this class may be the fastest to collect
data, it comes at the price of being suitable only in settings where it is acceptable
to gather data asynchronously (e.g. off-policy RL or curriculum RL).
For remotely executed rollouts (:class:`MultiSyncDataCollector` or :class:`MultiaSyncDataCollector`)
it is necessary to synchronise the weights of the remote policy with the weights
from the training worker, either by calling :meth:`collector.update_policy_weights_` or
by setting ``update_at_each_batch=True`` in the constructor.
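For instance, a minimal (hypothetical) asynchronous training loop might refresh the remote
weights at every iteration. Here, ``make_env``, ``policy`` and ``train_step`` are placeholders
that are assumed to exist:

>>> from torchrl.collectors import MultiaSyncDataCollector
>>> collector = MultiaSyncDataCollector(
...     [make_env] * 4,                      # four inference workers
...     policy,
...     frames_per_batch=1000,
...     total_frames=1_000_000,
... )
>>> for data in collector:
...     train_step(data)                     # update the training copy of the policy
...     collector.update_policy_weights_()   # push the new weights to the workers
>>> collector.shutdown()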
The second parameter to consider (in the remote settings) is the device where the
data will be collected and the device where the environment and policy operations
will be executed. For instance, a policy executed on CPU may be slower than one
executed on CUDA. When multiple inference workers run concomitantly, dispatching
the compute workload across the available devices may speed up the collection or
avoid OOM errors. Finally, the choice of the batch size and passing device (i.e. the
device where the data will be stored while waiting to be passed to the collection
worker) may also impact memory management. The key parameters to control are
``devices``, which controls the execution devices (i.e. the device of the policy),
and ``storing_device``, which controls the device where the environment and
data are stored during a rollout. A good heuristic is usually to use the same device
for storage and compute, which is the default behavior when only the ``devices`` argument
is passed.
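As a sketch (device strings are placeholders, and multiprocessed collectors accept the
analogous, possibly per-worker, arguments described above):

>>> collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=500,
...     total_frames=-1,
...     device="cuda:0",          # execution device for the policy and env step
...     storing_device="cpu",     # device on which the collected batch is stored
... )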
Besides those compute parameters, users may choose to configure the following parameters
(a short usage sketch follows the list):

- ``max_frames_per_traj``: the number of frames after which a :meth:`env.reset` is called;
- ``frames_per_batch``: the number of frames delivered at each iteration over the collector;
- ``init_random_frames``: the number of random steps (steps where :meth:`env.rand_step` is being called);
- ``reset_at_each_iter``: if ``True``, the environment(s) will be reset after each batch collection;
- ``split_trajs``: if ``True``, the trajectories will be split and delivered in a padded tensordict
  along with a ``"mask"`` key pointing to a boolean mask that represents the valid values;
- ``exploration_type``: the exploration strategy to be used with the policy;
- ``reset_when_done``: whether environments should be reset when reaching a done state.
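A minimal sketch combining several of these options (values are placeholders and
``env``/``policy`` are assumed to exist):

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.utils import ExplorationType
>>> collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=200,                      # frames delivered per iteration
...     total_frames=10_000,                       # stop iterating after this many frames
...     max_frames_per_traj=50,                    # call env.reset() every 50 steps at most
...     init_random_frames=1_000,                  # first 1000 frames come from env.rand_step
...     split_trajs=True,                          # deliver padded trajectories with a "mask" entry
...     exploration_type=ExplorationType.RANDOM,   # exploration strategy used by the policy
... )
>>> for data in collector:
...     pass  # training / buffering logic goes here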
Collectors and batch size
-------------------------

Because each collector has its own way of organizing the environments that are
run within it, the data will come with a different batch size depending on
the specificities of the collector. The following table summarizes what is to
be expected when collecting data:

+--------------------+---------------------+--------------------------------------------+------------------------------+
|                    | SyncDataCollector   | MultiSyncDataCollector (n=B)               | MultiaSyncDataCollector (n=B)|
+====================+=====================+=============+==============+===============+==============================+
| `cat_results`      | NA                  | `"stack"`   | `0`          | `-1`          | NA                           |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Single env         | [T]                 | `[B, T]`    | `[B*(T//B)]` | `[B*(T//B)]`  | [T]                          |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Batched env (n=P)  | [P, T]              | `[B, P, T]` | `[B * P, T]` | `[P, T * B]`  | [P, T]                       |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+

In each of these cases, the last dimension (``T`` for ``time``) is adapted such
that the batch size equals the ``frames_per_batch`` argument passed to the
collector.

.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be
   used with ``cat_results=0``, as the data will be stacked along the batch
   dimension with batched environments, or along the time dimension for single environments,
   which can introduce some confusion when swapping one for the other.
   ``cat_results="stack"`` is a better and more consistent way of interacting
   with the environments, as it will keep each dimension separate and provide
   better interchangeability between configurations, collector classes and other
   components.

Whereas :class:`~torchrl.collectors.MultiSyncDataCollector`
has a dimension corresponding to the number of sub-collectors being run (``B``),
:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This
is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector`
delivers batches of data on a first-come, first-served basis, whereas
:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.
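A quick way to check the layout produced by a given configuration is to inspect
``data.batch_size`` on a delivered batch (``make_env`` and ``policy`` are assumed to exist):

>>> collector = MultiSyncDataCollector(
...     [make_env] * 4,            # B = 4 sub-collectors, each running a single env
...     policy,
...     frames_per_batch=400,
...     total_frames=-1,
...     cat_results="stack",
... )
>>> for data in collector:
...     print(data.batch_size)     # torch.Size([4, 100]), i.e. [B, T] with T = 400 // 4
...     break
>>> collector.shutdown()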
Collectors and policy copies
----------------------------

When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to
keep the training version of the policy on one device and the inference version on another. For example, if you have two
CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the
case, :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one
device to the other (if no copy is required, this method is a no-op).
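A sketch of this two-device pattern (device strings and the ``train_step`` helper are
placeholders):

>>> policy = policy.to("cuda:0")                 # training copy lives on cuda:0
>>> collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=256,
...     total_frames=-1,
...     device="cuda:1",                         # inference copy is executed on cuda:1
... )
>>> for data in collector:
...     loss = train_step(data.to("cuda:0"))     # hypothetical optimization step on cuda:0
...     collector.update_policy_weights_()       # sync the cuda:0 parameters to the cuda:1 copy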
Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the
policy structure and copy the parameters placed on the new device during instantiation if necessary.
Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we
try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur.

.. figure:: /_static/img/collector-copy.png

    Policy copy decision tree in Collectors.
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
.. currentmodule:: torchrl.collectors.distributed

Distributed Collectors
======================

TorchRL provides a set of distributed data collectors. These tools support
multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector`
or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``,
``submitit`` or ``torch.multiprocessing``).
They can be efficiently used in synchronous or asynchronous mode, on a single
node or across multiple nodes.
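As a rough sketch of what instantiating such a collector might look like (the keyword
arguments below are indicative only and should be checked against the
:class:`~.DistributedDataCollector` docstring; ``make_env`` and ``policy`` are assumed to exist):

>>> from torchrl.collectors.distributed import DistributedDataCollector
>>> collector = DistributedDataCollector(
...     [make_env] * 8,              # one environment constructor per remote worker
...     policy,
...     frames_per_batch=2000,
...     total_frames=-1,
...     backend="gloo",              # or "nccl" / "mpi"
...     launcher="submitit",         # launcher used to spawn the remote jobs
...     sync=True,                   # deliver batches synchronously
... )
>>> for data in collector:
...     ...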
*Resources*: Find examples for these collectors in the
`dedicated folder <https://github.com/pytorch/rl/examples/distributed/collectors>`_.

.. note::
   *Choosing the sub-collector*: All distributed collectors support the various single-machine collectors.
   One may wonder whether to use a :class:`MultiSyncDataCollector` or a :class:`~torchrl.envs.ParallelEnv`
   as the sub-collector. In general, multiprocessed collectors have a lower IO footprint than
   parallel environments, which need to communicate at each step. Yet, the model specs
   play a role in the opposite direction, since using parallel environments will
   result in a faster execution of the policy (and/or transforms), as these
   operations will be vectorized.

.. note::
   *Choosing the device of a collector (or a parallel environment)*: Sharing data
   among processes is achieved via shared-memory buffers with parallel environments
   and multiprocessed environments executed on CPU. Depending on the capabilities
   of the machine being used, this may be prohibitively slow compared to sharing
   data on GPU, which is natively supported by CUDA drivers.
   In practice, this means that using the ``device="cpu"`` keyword argument when
   building a parallel environment or collector can result in a slower collection
   than using ``device="cuda"`` when available.

.. note::
   Given the library's many optional dependencies (e.g., Gym, Gymnasium, and many others),
   warnings can quickly become quite annoying in multiprocessed / distributed settings.
   By default, TorchRL filters out these warnings in sub-processes. If one still wishes to
   see these warnings, they can be displayed by setting ``torchrl.filter_warnings_subprocess=False``.
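For instance, to surface warnings raised in sub-processes (attribute name as given in the
note above):

>>> import torchrl
>>> torchrl.filter_warnings_subprocess = False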
.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    DistributedDataCollector
    RPCDataCollector
    DistributedSyncDataCollector
    submitit_delayed_launcher
    RayCollector
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
.. currentmodule:: torchrl.collectors

Collectors and Replay Buffers
=============================

Collectors and replay buffers interoperability
----------------------------------------------

In the simplest scenario, where single transitions have to be sampled
from the replay buffer, little attention has to be given to the way
the collector is built. Flattening the data after collection will
be a sufficient preprocessing step before populating the storage:

>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     transform=lambda data: data.reshape(-1))
>>> for data in collector:
...     memory.extend(data)

If trajectory slices have to be collected, the recommended way to achieve this is to create
a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler`
sampler class. One must ensure that the data passed to the buffer is properly shaped, with the
``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations
will work:

>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack"
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=3),
...     sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)

Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector`
isn't currently fully supported, as the data batches can come from any worker and in most cases consecutive
batches written to the buffer won't come from the same source (thereby interrupting the trajectories).

Helper functions
----------------

.. currentmodule:: torchrl.collectors.utils

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    split_trajectories
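As a sketch of what this helper returns (``data`` being a batch delivered by a collector,
and the ``"mask"`` entry following the ``split_trajs`` description in the collector basics
section):

>>> from torchrl.collectors.utils import split_trajectories
>>> split_data = split_trajectories(data)  # one padded trajectory per row
>>> split_data["mask"]                     # boolean mask flagging valid (non-padded) steps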
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
.. currentmodule:: torchrl.collectors

Single Node Collectors
======================

TorchRL provides several collector classes for single-node data collection, each with different execution strategies.

Single node data collectors
---------------------------

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    DataCollectorBase
    SyncDataCollector
    MultiSyncDataCollector
    MultiaSyncDataCollector
    aSyncDataCollector
22+
------------------------------------
23+
24+
Passing replay buffers to a collector allows us to start the collection and get rid of the iterative nature of the
25+
collector.
26+
If you want to run a data collector in the background, simply run :meth:`~torchrl.DataCollectorBase.start`:
27+
28+
>>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer
29+
>>> collector.start()
30+
>>> # little pause
31+
>>> time.sleep(10)
32+
>>> # Start training
33+
>>> for i in range(optim_steps):
34+
... data = rb.sample() # Sampling from the replay buffer
35+
... # rest of the training loop
36+
37+
Single-process collectors (:class:`~torchrl.collectors.SyncDataCollector`) will run the process using multithreading,
38+
so be mindful of Python's GIL and related multithreading restrictions.
39+
40+
Multiprocessed collectors will on the other hand let the child processes handle the filling of the buffer on their own,
41+
which truly decouples the data collection and training.
42+
43+
Data collectors that have been started with `start()` should be shut down using
44+
:meth:`~torchrl.DataCollectorBase.async_shutdown`.
45+
46+
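Continuing the sketch above, the background collector is stopped once training is done:

>>> collector.async_shutdown()  # stop the collection started with start()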
.. warning:: Running a collector asynchronously decouples the collection from training, which means that the training
   performance may be drastically different depending on the hardware, load and other factors (although it is generally
   expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and whether it
   is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously
   unless properly benchmarked.)
