.. currentmodule:: torchrl.collectors

.. _ref_collectors:

Collector Basics
================

Data collectors are somewhat equivalent to PyTorch dataloaders, except that (1) they
collect data over non-static data sources and (2) the data is collected using a model
(likely a version of the model that is being trained).

TorchRL's data collectors accept two main arguments: an environment (or a list of
environment constructors) and a policy. They will iteratively execute an environment
step and a policy query over a defined number of steps before delivering a stack of
the data collected to the user. Environments will be reset whenever they reach a done
state, and/or after a predefined number of steps.
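
For illustration, here is a minimal collection loop with a :class:`SyncDataCollector`,
assuming a Gym backend with ``Pendulum-v1`` installed (the toy policy below is an
arbitrary linear module, not a trained model):

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    # A trivial policy mapping the 3-dim observation to a 1-dim action.
    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=200,  # each iteration delivers a stack of 200 frames
        total_frames=1_000,    # the collector is exhausted after 1000 frames
    )
    for data in collector:
        print(data)  # a TensorDict with batch size [200]
    collector.shutdown()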

Because data collection is a potentially compute heavy process, it is crucial to
configure the execution hyperparameters appropriately.
The first parameter to take into consideration is whether the data collection should
occur serially with the optimization step or in parallel. The :class:`SyncDataCollector`
class will execute the data collection on the training worker. The :class:`MultiSyncDataCollector`
will split the workload across a number of workers and aggregate the results before
delivering them to the training worker. Finally, the :class:`MultiaSyncDataCollector` will
execute the data collection on several workers and deliver the first batch of results
that it can gather. This execution will occur continuously and concomitantly with
the training of the networks: this implies that the weights of the policy that
is used for the data collection may slightly lag the configuration of the policy
on the training worker. Therefore, although this class may be the fastest to collect
data, it comes at the price of being suitable only in settings where it is acceptable
to gather data asynchronously (e.g. off-policy RL or curriculum RL).
For remotely executed rollouts (:class:`MultiSyncDataCollector` or :class:`MultiaSyncDataCollector`),
it is necessary to synchronize the weights of the remote policy with the weights
from the training worker, either by calling :meth:`collector.update_policy_weights_` or
by setting ``update_at_each_batch=True`` in the constructor.
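
As a sketch, a multiprocess collection loop with explicit weight synchronization
(same toy ``Pendulum-v1`` setup as above; the multiprocessing workers require the
usual ``__main__`` guard):

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    if __name__ == "__main__":
        policy = TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        )
        collector = MultiSyncDataCollector(
            create_env_fn=[lambda: GymEnv("Pendulum-v1")] * 2,  # two workers
            policy=policy,
            frames_per_batch=200,
            total_frames=2_000,
        )
        for data in collector:
            # ... run an optimization step over `data` on the training worker ...
            # then push the updated weights to the remote policy copies:
            collector.update_policy_weights_()
        collector.shutdown()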

The second parameter to consider (in the remote settings) is the device where the
data will be collected and the device where the environment and policy operations
will be executed. For instance, a policy executed on CPU may be slower than one
executed on CUDA. When multiple inference workers run concomitantly, dispatching
the compute workload across the available devices may speed up the collection or
avoid OOM errors. Finally, the choice of the batch size and passing device (i.e., the
device where the data will be stored while waiting to be passed to the collection
worker) may also impact memory management. The key parameters to control are
``devices``, which controls the execution devices (i.e., the device of the policy),
and ``storing_device``, which controls the device where the environment and
data are stored during a rollout. A good heuristic is usually to use the same device
for storage and compute, which is the default behavior when only the ``devices`` argument
is being passed.
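
As a sketch, here is a single-process collector executing on GPU while storing the
collected batches on CPU (assuming a CUDA device is available; the single-process
:class:`SyncDataCollector` exposes this logic through its ``device`` and
``storing_device`` arguments):

.. code-block:: python

    import torch
    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )
    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy,
        frames_per_batch=200,
        total_frames=1_000,
        device="cuda:0",       # policy and env operations execute on GPU
        storing_device="cpu",  # collected batches are kept in CPU memory
    )
    for data in collector:
        assert data.device == torch.device("cpu")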

Besides those compute parameters, users may choose to configure the following
parameters (see the configuration sketch after this list):

- ``max_frames_per_traj``: the number of frames after which a :meth:`env.reset` is called
- ``frames_per_batch``: the number of frames delivered at each iteration over the collector
- ``init_random_frames``: the number of random steps (steps where :meth:`env.rand_step` is being called)
- ``reset_at_each_iter``: if ``True``, the environment(s) will be reset after each batch collection
- ``split_trajs``: if ``True``, the trajectories will be split and delivered in a padded tensordict
  along with a ``"mask"`` key that will point to a boolean mask representing the valid values.
- ``exploration_type``: the exploration strategy to be used with the policy.
- ``reset_when_done``: whether environments should be reset when reaching a done state.
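
A sketch putting these arguments together (the values are arbitrary and only meant
to illustrate the signature):

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import ExplorationType

    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )
    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy,
        total_frames=10_000,
        frames_per_batch=200,      # frames delivered at each iteration
        max_frames_per_traj=50,    # env.reset is called every 50 steps at most
        init_random_frames=500,    # the first 500 frames use env.rand_step
        reset_at_each_iter=False,  # do not force a reset between batches
        split_trajs=True,          # padded trajectories with a "mask" key
        exploration_type=ExplorationType.RANDOM,
        reset_when_done=True,      # reset environments that reach a done state
    )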

Collectors and batch size
-------------------------

Because each collector has its own way of organizing the environments that are
run within it, the data will come with different batch sizes depending on the
specificities of the collector. The following table summarizes what is to
be expected when collecting data:

+--------------------+---------------------+--------------------------------------------+------------------------------+
|                    | SyncDataCollector   | MultiSyncDataCollector (n=B)               |MultiaSyncDataCollector (n=B) |
+====================+=====================+=============+==============+===============+==============================+
| `cat_results`      | NA                  | `"stack"`   | `0`          | `-1`          | NA                           |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Single env         | [T]                 | `[B, T]`    | `[B*(T//B)]` | `[B*(T//B)]`  | [T]                          |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+
| Batched env (n=P)  | [P, T]              | `[B, P, T]` | `[B * P, T]` | `[P, T * B]`  | [P, T]                       |
+--------------------+---------------------+-------------+--------------+---------------+------------------------------+

In each of these cases, the last dimension (``T`` for ``time``) is adapted such
that the batch size equals the ``frames_per_batch`` argument passed to the
collector.

.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be
    used with ``cat_results=0``, as the data will be stacked along the batch
    dimension with batched environments, or the time dimension for single environments,
    which can introduce some confusion when swapping one with the other.
    ``cat_results="stack"`` is a better and more consistent way of interacting
    with the environments as it will keep each dimension separate, and provide
    better interchangeability between configurations, collector classes and other
    components.

Whereas :class:`~torchrl.collectors.MultiSyncDataCollector`
has a dimension corresponding to the number of sub-collectors being run (``B``),
:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This
is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector`
delivers batches of data on a first-come, first-served basis, whereas
:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.
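
To make the table concrete, a sketch checking the delivered batch sizes (same toy
setup as above; with ``B=2`` sub-collectors running one single env each and
``frames_per_batch=200``, the table predicts ``[2, 100]`` for the synced collector
with ``cat_results="stack"`` and ``[200]`` for the asynchronous one):

.. code-block:: python

    import torch
    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    if __name__ == "__main__":
        policy = TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        )
        make_env = lambda: GymEnv("Pendulum-v1")

        sync = MultiSyncDataCollector(
            [make_env] * 2,  # B=2 sub-collectors, one single env each
            policy,
            frames_per_batch=200,
            total_frames=400,
            cat_results="stack",
        )
        for data in sync:
            assert data.batch_size == torch.Size([2, 100])  # [B, T]
            break
        sync.shutdown()

        async_collector = MultiaSyncDataCollector(
            [make_env] * 2,
            policy,
            frames_per_batch=200,
            total_frames=400,
        )
        for data in async_collector:
            assert data.batch_size == torch.Size([200])  # [T], first-come batch
            break
        async_collector.shutdown()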

Collectors and policy copies
----------------------------

When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to
keep the training version of the policy on one device and the inference version on another. For example, if you have two
CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the
case, a call to :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one
device to the other (if no copy is required, this method is a no-op).
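
A sketch of this two-device setup (assuming two CUDA devices; the ``policy_device``
argument pins the inference copy, while the training parameters stay on the first
device):

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    # The training copy of the policy lives on the first GPU.
    train_policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    ).to("cuda:0")

    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        train_policy,
        frames_per_batch=200,
        total_frames=10_000,
        policy_device="cuda:1",  # the inference copy is created on the second GPU
        env_device="cpu",
        storing_device="cpu",
    )
    for data in collector:
        # ... optimize train_policy on cuda:0 ...
        collector.update_policy_weights_()  # refresh the cuda:1 copy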

Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the
policy structure and copy the parameters placed on the new device during instantiation if necessary.
Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we
try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur.

.. figure:: /_static/img/collector-copy.png

    Policy copy decision tree in Collectors.