@@ -5,6 +5,9 @@
 from __future__ import annotations
 
 import argparse
+import importlib.util
+import pickle
+import time
 
 import pytest
 import torch
@@ -16,14 +19,20 @@
 from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.weight_update.weight_sync_schemes import (
     _resolve_model,
+    DistributedWeightSyncScheme,
     MPTransport,
     MultiProcessWeightSyncScheme,
     NoWeightSyncScheme,
+    RayModuleTransformScheme,
+    RayWeightSyncScheme,
+    RPCWeightSyncScheme,
     SharedMemTransport,
     SharedMemWeightSyncScheme,
     WeightStrategy,
)
 
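+# Detect an optional Ray install without importing it; the skipif markers on
+# the Ray-specific tests below rely on this flag.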
+_has_ray = importlib.util.find_spec("ray") is not None
+
 
 def worker_update_policy(pipe, timeout=5.0):
     policy = nn.Linear(4, 2)
@@ -32,9 +41,8 @@ def worker_update_policy(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
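+    # New simplified API: a single init_on_worker() call wires the pipe and
+    # model to the scheme, which then hands back a ready-to-use receiver.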
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -52,9 +60,8 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
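+    # Same simplified worker-side setup as above, with the tensordict strategy.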
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -75,8 +82,6 @@ def worker_shared_mem(pipe, timeout=10.0):
         shared_weights.to_module(policy)
         pipe.send((None, "registered"))
 
-    import time
-
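+    # The pause below presumably keeps the worker alive long enough for the
+    # parent to push updates into the shared buffer before the sums are read.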
     time.sleep(0.5)
 
     return policy.weight.sum().item(), policy.bias.sum().item()
@@ -192,39 +197,46 @@ def test_cross_format_conversion(self):
 
 
 class TestWeightSyncSchemes:
+    """Tests for weight sync schemes using the new simplified API.
+
+    Lower-level transport and legacy API tests are in TestTransportBackends.
+    """
+
     def test_multiprocess_scheme_state_dict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
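+        # Sender-side setup mirrors the worker: init_on_sender() registers the
+        # pipes and get_sender() returns the configured sender.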
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy, args=(child_pipe,))
-        proc.start()
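+        # try/finally guarantees the worker is joined even if send() raises.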
+        try:
+            proc.start()
 
-        weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        sender.update_weights(weights)
-
-        proc.join(timeout=10.0)
-        assert not proc.is_alive()
+            weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
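+            # send() supersedes the legacy update_weights() call removed above.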
+            sender.send(weights)
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
 
     def test_multiprocess_scheme_tensordict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,))
-        proc.start()
+        try:
+            proc.start()
 
-        weights = TensorDict(
-            {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
-        )
-        sender.update_weights(weights)
-
-        proc.join(timeout=10.0)
-        assert not proc.is_alive()
+            weights = TensorDict(
+                {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
+            )
+            sender.send(weights)
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
 
     def test_shared_mem_scheme(self):
         shared_buffer = TensorDict(
@@ -270,6 +282,51 @@ def test_no_weight_sync_scheme(self):
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
         transport.send_weights("policy", weights)
 
+    @classmethod
+    def _worker_with_receive(cls, pipe, scheme):
+        policy = nn.Linear(4, 2)
+        with torch.no_grad():
+            policy.weight.fill_(0.0)
+            policy.bias.fill_(0.0)
+
+        scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+        receiver = scheme.get_receiver()
+
+        # A short-timeout receive should return False while no data is available
+        result = receiver.receive(timeout=0.001)
+        assert result is False
+
+        # Now actually receive the weights
+        result = receiver.receive(timeout=5.0)
+        assert result is True
+
+        # Check that the weights were applied
+        return policy.weight.sum().item(), policy.bias.sum().item()
+
+    def test_receiver_receive_method(self):
+        """Test the new non-blocking receive() method."""
+
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
+
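+        # Handing the scheme to mp.Process pickles it; runtime state (pipes,
+        # senders) is dropped in transit, as exercised in TestSerializeScheme.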
+        proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme))
+        try:
+            proc.start()
+
+            # Give the worker time to call receive() while no data is available
+            time.sleep(0.1)
+
+            weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
+            sender.send(weights)
+
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
+
 
 class TestCollectorIntegration:
     @pytest.fixture
@@ -568,6 +625,217 @@ def test_weight_strategy_parametrized(strategy):
     assert torch.allclose(policy.bias, target.bias)
 
 
+class TestSerializeScheme:
+    """Test that WeightSyncScheme instances can be serialized after initialization.
+
+    This is critical for multiprocessing and Ray, where schemes may be pickled
+    and sent across process boundaries. The _sender and _receiver attributes
+    contain non-serializable objects (pipes, weak references, etc.) and must
+    be excluded from serialization.
+    """
+
+    def test_multiprocess_scheme_serialize_before_init(self):
+        """Test that an uninitialized scheme can be pickled."""
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+        assert not restored._initialized_on_sender
+        assert not restored._initialized_on_worker
+
+    def test_multiprocess_scheme_serialize_after_sender_init(self):
+        """Test that an initialized sender can be pickled (excluding runtime state)."""
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        # The scheme now has a _sender holding non-serializable pipes
+        assert scheme._sender is not None
+        assert scheme._initialized_on_sender
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved but runtime state is cleared
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None  # Runtime state excluded
+        assert restored._receiver is None
+        assert not restored._initialized_on_sender  # Reset
+        assert not restored._initialized_on_worker
+
+        # Clean up
+        parent_pipe.close()
+        child_pipe.close()
+
+    def test_shared_mem_scheme_serialize_before_init(self):
+        """Test that an uninitialized SharedMemWeightSyncScheme can be pickled."""
+        scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    def test_shared_mem_scheme_serialize_after_init(self):
+        """Test that an initialized SharedMemWeightSyncScheme can be pickled."""
+        parent_pipe, child_pipe = mp.Pipe()
+
+        # Create a shared buffer
+        shared_buffer = TensorDict(
+            {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
+        ).share_memory_()
+
+        scheme = SharedMemWeightSyncScheme(
+            policy_weights={"policy": shared_buffer},
+            strategy="tensordict",
+            auto_register=False,
+        )
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        # The scheme now has a _sender with non-serializable state
+        assert scheme._sender is not None
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved but runtime state is cleared
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert not restored._initialized_on_sender
+
+        # Note: the policy_weights dict is preserved (but may need re-sharing)
+        assert "policy" in restored.policy_weights
+
+        # Clean up
+        parent_pipe.close()
+        child_pipe.close()
+
+    def test_no_weight_sync_scheme_serialize(self):
+        """Test that NoWeightSyncScheme can be pickled."""
+        scheme = NoWeightSyncScheme()
+        scheme.init_on_sender(model_id="policy")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that it's still a no-op scheme
+        assert restored._sender is None
+        assert restored._receiver is None
+
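+    # The distributed/Ray/RPC schemes are only exercised before initialization:
+    # presumably a full init would require a live process group or Ray cluster.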
+    @pytest.mark.skipif(
+        not torch.distributed.is_available(), reason="torch.distributed not available"
+    )
+    def test_distributed_scheme_serialize_before_init(self):
+        """Test that an uninitialized DistributedWeightSyncScheme can be pickled."""
+
+        scheme = DistributedWeightSyncScheme(backend="gloo", sync=True)
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.backend == "gloo"
+        assert restored.sync is True
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
+    def test_ray_weight_sync_scheme_serialize_before_init(self):
+        """Test that an uninitialized RayWeightSyncScheme can be pickled."""
+        scheme = RayWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
+    def test_ray_module_transform_scheme_serialize_before_init(self):
+        """Test that an uninitialized RayModuleTransformScheme can be pickled."""
+
+        scheme = RayModuleTransformScheme(strategy="tensordict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(
+        not torch.distributed.is_available(), reason="torch.distributed not available"
+    )
+    def test_rpc_weight_sync_scheme_serialize_before_init(self):
+        """Test that an uninitialized RPCWeightSyncScheme can be pickled."""
+
+        scheme = RPCWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    def test_scheme_reinitialization_after_unpickle(self):
+        """Test that a scheme can be re-initialized after unpickling.
+
+        This is the expected workflow: pickle a scheme, unpickle it in a worker,
+        then call init_on_worker() (or init_on_sender(), as exercised here) to
+        establish new runtime resources.
+        """
+        # Initialize and pickle a scheme
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        pickled = pickle.dumps(scheme)
+
+        # Clean up original
+        parent_pipe.close()
+
+        # Unpickle and re-initialize
+        restored = pickle.loads(pickled)
+
+        # Should be able to initialize again with new pipes
+        new_parent, new_child = mp.Pipe()
+
+        # Re-initialize on sender
+        restored.init_on_sender(model_id="policy", pipes=[new_parent])
+        sender = restored.get_sender()
+
+        assert sender is not None
+        assert restored._initialized_on_sender
+
+        # Clean up
+        new_parent.close()
+        new_child.close()
+        child_pipe.close()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown)