@@ -162,7 +162,15 @@ Usage Examples
162162Using Weight Update Schemes Independently
163163^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
164164
165- Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:
165+ Weight update schemes can be used outside of collectors for custom synchronization scenarios.
166+ The new simplified API provides four core methods for weight synchronization:
167+
168+ - ``init_on_sender(model_id, **kwargs) `` - Initialize on the main process (trainer) side
169+ - ``init_on_worker(model_id, **kwargs) `` - Initialize on worker process side
170+ - ``get_sender() `` - Get the configured sender instance
171+ - ``get_receiver() `` - Get the configured receiver instance
172+
173+ Here's a basic example:
166174
167175.. code-block :: python
168176
@@ -182,39 +190,37 @@ Weight update schemes can be used outside of collectors for custom synchronizati
182190 # --------------------------------------------------------------
183191 # On the main process side (trainer):
184192 scheme = MultiProcessWeightSyncScheme(strategy = " state_dict" )
185- sender = scheme.create_sender()
186-
187- # Register worker pipes
193+
194+ # Initialize scheme with pipes
188195 parent_pipe, child_pipe = mp.Pipe()
189- sender.register_worker(worker_idx = 0 , pipe_or_context = parent_pipe)
190-
191- # Send weights to workers
196+ scheme.init_on_sender(model_id = " policy" , pipes = [parent_pipe])
197+
198+ # Get the sender and send weights
199+ sender = scheme.get_sender()
192200 weights = policy.state_dict()
193- sender.update_weights(weights)
201+ sender.send(weights) # Synchronous send
202+ # or sender.send_async(weights); sender.wait_async() # Asynchronous send
194203
195204 # On the worker process side:
196- # receiver = scheme.create_receiver()
197- # receiver.register_model(policy)
198- # receiver.register_worker_transport(child_pipe)
199- # # Receive and apply weights
200- # result = receiver._transport.receive_weights(timeout=5.0)
201- # if result is not None:
202- # model_id, weights = result
203- # receiver.apply_weights(weights)
205+ # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
206+ # receiver = scheme.get_receiver()
207+ # # Non-blocking check for new weights
208+ # if receiver.receive(timeout=0.001):
209+ # # Weights were received and applied
204210
205211 # Example 2: Shared memory weight synchronization
206212 # ------------------------------------------------
207213 # Create shared memory scheme with auto-registration
208214 shared_scheme = SharedMemWeightSyncScheme(strategy = " tensordict" , auto_register = True )
209- shared_sender = shared_scheme.create_sender()
210-
211- # Register worker pipes for lazy registration
215+
216+ # Initialize with pipes for lazy registration
212217 parent_pipe2, child_pipe2 = mp.Pipe()
213- shared_sender.register_worker(worker_idx = 0 , pipe_or_context = parent_pipe2)
214-
215- # Send weights (automatically creates shared buffer on first send)
218+ shared_scheme.init_on_sender(model_id = " policy" , pipes = [parent_pipe2])
219+
220+ # Get sender and send weights (automatically creates shared buffer on first send)
221+ shared_sender = shared_scheme.get_sender()
216222 weights_td = TensorDict.from_module(policy)
217- shared_sender.update_weights (weights_td)
223+ shared_sender.send (weights_td)
218224
219225 # Workers automatically see updates via shared memory!
220226
0 commit comments