Skip to content

Commit 00baa00

Browse files
committed
Update
[ghstack-poisoned]
1 parent 3852c9f commit 00baa00

File tree

3 files changed

+37
-32
lines changed

3 files changed

+37
-32
lines changed

docs/source/reference/collectors.rst

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,15 @@ Usage Examples
162162
Using Weight Update Schemes Independently
163163
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
164164

165-
Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:
165+
Weight update schemes can be used outside of collectors for custom synchronization scenarios.
166+
The new simplified API provides four core methods for weight synchronization:
167+
168+
- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
169+
- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
170+
- ``get_sender()`` - Get the configured sender instance
171+
- ``get_receiver()`` - Get the configured receiver instance
172+
173+
Here's a basic example:
166174

167175
.. code-block:: python
168176
@@ -182,39 +190,37 @@ Weight update schemes can be used outside of collectors for custom synchronizati
182190
# --------------------------------------------------------------
183191
# On the main process side (trainer):
184192
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
185-
sender = scheme.create_sender()
186-
187-
# Register worker pipes
193+
194+
# Initialize scheme with pipes
188195
parent_pipe, child_pipe = mp.Pipe()
189-
sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
190-
191-
# Send weights to workers
196+
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
197+
198+
# Get the sender and send weights
199+
sender = scheme.get_sender()
192200
weights = policy.state_dict()
193-
sender.update_weights(weights)
201+
sender.send(weights) # Synchronous send
202+
# or sender.send_async(weights); sender.wait_async() # Asynchronous send
194203
195204
# On the worker process side:
196-
# receiver = scheme.create_receiver()
197-
# receiver.register_model(policy)
198-
# receiver.register_worker_transport(child_pipe)
199-
# # Receive and apply weights
200-
# result = receiver._transport.receive_weights(timeout=5.0)
201-
# if result is not None:
202-
# model_id, weights = result
203-
# receiver.apply_weights(weights)
205+
# scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
206+
# receiver = scheme.get_receiver()
207+
# # Non-blocking check for new weights
208+
# if receiver.receive(timeout=0.001):
209+
# # Weights were received and applied
204210
205211
# Example 2: Shared memory weight synchronization
206212
# ------------------------------------------------
207213
# Create shared memory scheme with auto-registration
208214
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
209-
shared_sender = shared_scheme.create_sender()
210-
211-
# Register worker pipes for lazy registration
215+
216+
# Initialize with pipes for lazy registration
212217
parent_pipe2, child_pipe2 = mp.Pipe()
213-
shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)
214-
215-
# Send weights (automatically creates shared buffer on first send)
218+
shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2])
219+
220+
# Get sender and send weights (automatically creates shared buffer on first send)
221+
shared_sender = shared_scheme.get_sender()
216222
weights_td = TensorDict.from_module(policy)
217-
shared_sender.update_weights(weights_td)
223+
shared_sender.send(weights_td)
218224
219225
# Workers automatically see updates via shared memory!
220226

torchrl/collectors/collectors.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,9 +2118,10 @@ def __repr__(self) -> str:
21182118
try:
21192119
env_str = indent(f"env={self.env}", 4 * " ")
21202120
policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ")
2121-
td_out_str = indent(
2122-
f"td_out={getattr(self, '_final_rollout', None)}", 4 * " "
2123-
)
2121+
td_out_str = repr(getattr(self, "_final_rollout", None))
2122+
if len(td_out_str) > 50:
2123+
td_out_str = td_out_str[:50] + "..."
2124+
td_out_str = indent(f"td_out={td_out_str}", 4 * " ")
21242125
string = (
21252126
f"{self.__class__.__name__}("
21262127
f"\n{env_str},"

torchrl/weight_update/weight_sync_schemes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,10 +1873,9 @@ def init_on_sender(
18731873
sender = WeightSender(self)
18741874
sender._model_id = model_id
18751875

1876-
# Create transports for each Ray actor and register them
1876+
# Register each Ray actor - _register_worker will create the transport
18771877
for worker_idx, remote_collector in enumerate(remote_collectors):
1878-
transport = self.create_transport(remote_collector)
1879-
sender._register_worker(worker_idx, transport)
1878+
sender._register_worker(worker_idx, remote_collector)
18801879

18811880
# Set context with weak reference to avoid circular refs
18821881
if context is not None:
@@ -2012,10 +2011,9 @@ def init_on_sender(
20122011
sender = self.create_sender()
20132012
sender._model_id = model_id
20142013

2015-
# Register all actors as workers
2014+
# Register all actors - _register_worker will create the transport
20162015
for worker_idx, actor_ref in enumerate(actor_refs):
2017-
transport = self.create_transport(actor_ref)
2018-
sender._register_worker(worker_idx, transport)
2016+
sender._register_worker(worker_idx, actor_ref)
20192017

20202018
# Set context with weak reference
20212019
if context is not None:

0 commit comments

Comments
 (0)