
Commit 41e0c22

[Refactor] Weight sync schemes refactor

ghstack-source-id: bbae696
Pull-Request: #3230
1 parent e06c3fe

File tree

5 files changed: +1427 −272 lines

docs/source/reference/collectors.rst

Lines changed: 29 additions & 23 deletions
@@ -162,7 +162,15 @@ Usage Examples
 Using Weight Update Schemes Independently
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:
+Weight update schemes can be used outside of collectors for custom synchronization scenarios.
+The new simplified API provides four core methods for weight synchronization:
+
+- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
+- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
+- ``get_sender()`` - Get the configured sender instance
+- ``get_receiver()`` - Get the configured receiver instance
+
+Here's a basic example:
 
 .. code-block:: python
 
@@ -182,39 +190,37 @@ Weight update schemes can be used outside of collectors for custom synchronizati
     # --------------------------------------------------------------
     # On the main process side (trainer):
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    sender = scheme.create_sender()
-
-    # Register worker pipes
+
+    # Initialize scheme with pipes
     parent_pipe, child_pipe = mp.Pipe()
-    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
-
-    # Send weights to workers
+    scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+    # Get the sender and send weights
+    sender = scheme.get_sender()
     weights = policy.state_dict()
-    sender.update_weights(weights)
+    sender.send(weights)  # Synchronous send
+    # or sender.send_async(weights); sender.wait_async()  # Asynchronous send
 
     # On the worker process side:
-    # receiver = scheme.create_receiver()
-    # receiver.register_model(policy)
-    # receiver.register_worker_transport(child_pipe)
-    # # Receive and apply weights
-    # result = receiver._transport.receive_weights(timeout=5.0)
-    # if result is not None:
-    #     model_id, weights = result
-    #     receiver.apply_weights(weights)
+    # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
+    # receiver = scheme.get_receiver()
+    # # Non-blocking check for new weights
+    # if receiver.receive(timeout=0.001):
+    #     # Weights were received and applied
 
     # Example 2: Shared memory weight synchronization
     # ------------------------------------------------
     # Create shared memory scheme with auto-registration
     shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
-    shared_sender = shared_scheme.create_sender()
-
-    # Register worker pipes for lazy registration
+
+    # Initialize with pipes for lazy registration
     parent_pipe2, child_pipe2 = mp.Pipe()
-    shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)
-
-    # Send weights (automatically creates shared buffer on first send)
+    shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2])
+
+    # Get sender and send weights (automatically creates shared buffer on first send)
+    shared_sender = shared_scheme.get_sender()
     weights_td = TensorDict.from_module(policy)
-    shared_sender.update_weights(weights_td)
+    shared_sender.send(weights_td)
 
     # Workers automatically see updates via shared memory!

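The pipe-based sender/receiver flow in the diff above can be sketched with the standard library alone. This is an illustrative analogue, not TorchRL code: the `worker` function, the plain-dict stand-in for a `state_dict`, and the checksum echo are all hypothetical, standing in for what `MultiProcessWeightSyncScheme` manages internally.

```python
# Stdlib-only stand-in for the pipe-based flow (no TorchRL): the trainer
# sends a state_dict-like payload down a pipe; the worker receives and
# "applies" it. All names here are hypothetical.
import multiprocessing as mp


def worker(pipe):
    # Worker side: block until new weights arrive (cf. receiver.receive(...)).
    weights = pipe.recv()
    # "Apply" the weights, then echo a checksum back so the demo can verify.
    pipe.send(sum(weights.values()))


if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_pipe,))
    proc.start()
    # Trainer side: push updated weights (cf. sender.send(weights)).
    parent_pipe.send({"layer.weight": 1.0, "layer.bias": 2.0})
    print(parent_pipe.recv())  # -> 3.0
    proc.join()
```

The real scheme adds model registration, strategies (`state_dict` vs `tensordict`), and async sends on top of this basic pipe handshake.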
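The shared-memory variant's "workers automatically see updates" behavior can likewise be illustrated with a stdlib sketch. The `mp.Array` buffer is a hypothetical stand-in for the shared weight buffer that `SharedMemWeightSyncScheme` allocates; the event only sequences the demo.

```python
# Stdlib-only stand-in for shared-memory weight sync (no TorchRL): once the
# buffer is shared, the trainer updates weights in place and the worker
# observes them without any explicit send/receive.
import multiprocessing as mp


def worker(buf, updated, out):
    updated.wait()       # wait until the trainer has written new weights
    out.put(list(buf))   # read directly from the shared buffer


if __name__ == "__main__":
    buf = mp.Array("d", [0.0, 0.0])  # shared "weight" buffer
    updated = mp.Event()
    out = mp.Queue()
    proc = mp.Process(target=worker, args=(buf, updated, out))
    proc.start()
    buf[0], buf[1] = 1.5, 2.5        # in-place update; no message is sent
    updated.set()
    print(out.get())                 # -> [1.5, 2.5]
    proc.join()
```

This is why the diff's second example needs no per-update message: after the first send establishes the shared buffer, in-place writes are immediately visible to all workers.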