Commit 3ac1ee0

[Refactor] Weight sync schemes refactor

ghstack-source-id: ec75dca
Pull-Request: #3230

1 parent 7fccac8 commit 3ac1ee0
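The refactor replaces the create_sender()/register_worker() and create_receiver()/register_model() calls with a symmetric lifecycle: init_on_sender()/get_sender()/send() on the trainer side, and init_on_worker()/get_receiver()/receive() on the worker side. The following is a minimal sketch of that flow using only the calls exercised in the diff below; the pipe wiring, the nn.Linear policy, the worker function, and the __main__ guard are illustrative and not part of the commit.

import torch
import torch.nn as nn
import torch.multiprocessing as mp

from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def worker(pipe):
    # Worker side: bind the scheme to the local model, then poll for an update.
    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
    receiver = scheme.get_receiver()
    # receive() returns False on timeout and True once weights were applied.
    receiver.receive(timeout=5.0)


if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()

    # Trainer side: register the worker pipes, then push new weights.
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
    sender = scheme.get_sender()

    proc = mp.Process(target=worker, args=(child_pipe,))
    proc.start()
    sender.send({"weight": torch.ones(2, 4), "bias": torch.ones(2)})
    proc.join()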

File tree

4 files changed: +1396 -246 lines changed


test/test_weightsync.py

Lines changed: 294 additions & 26 deletions
@@ -5,6 +5,9 @@
 from __future__ import annotations
 
 import argparse
+import importlib.util
+import pickle
+import time
 
 import pytest
 import torch
@@ -16,14 +19,20 @@
 from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.weight_update.weight_sync_schemes import (
     _resolve_model,
+    DistributedWeightSyncScheme,
     MPTransport,
     MultiProcessWeightSyncScheme,
     NoWeightSyncScheme,
+    RayModuleTransformScheme,
+    RayWeightSyncScheme,
+    RPCWeightSyncScheme,
     SharedMemTransport,
     SharedMemWeightSyncScheme,
     WeightStrategy,
 )
 
+_has_ray = importlib.util.find_spec("ray") is not None
+
 
 def worker_update_policy(pipe, timeout=5.0):
     policy = nn.Linear(4, 2)
@@ -32,9 +41,8 @@ def worker_update_policy(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -52,9 +60,8 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -75,8 +82,6 @@ def worker_shared_mem(pipe, timeout=10.0):
         shared_weights.to_module(policy)
         pipe.send((None, "registered"))
 
-    import time
-
     time.sleep(0.5)
 
     return policy.weight.sum().item(), policy.bias.sum().item()
@@ -192,39 +197,46 @@ def test_cross_format_conversion(self):
 
 
 class TestWeightSyncSchemes:
+    """Tests for weight sync schemes using the new simplified API.
+
+    Lower-level transport and legacy API tests are in TestTransportBackends.
+    """
+
     def test_multiprocess_scheme_state_dict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy, args=(child_pipe,))
-        proc.start()
+        try:
+            proc.start()
 
-        weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        sender.update_weights(weights)
-
-        proc.join(timeout=10.0)
-        assert not proc.is_alive()
+            weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
+            sender.send(weights)
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
 
     def test_multiprocess_scheme_tensordict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,))
-        proc.start()
+        try:
+            proc.start()
 
-        weights = TensorDict(
-            {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
-        )
-        sender.update_weights(weights)
-
-        proc.join(timeout=10.0)
-        assert not proc.is_alive()
+            weights = TensorDict(
+                {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
+            )
+            sender.send(weights)
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
 
     def test_shared_mem_scheme(self):
         shared_buffer = TensorDict(
@@ -270,6 +282,51 @@ def test_no_weight_sync_scheme(self):
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
         transport.send_weights("policy", weights)
 
+    @classmethod
+    def _worker_with_receive(cls, pipe, scheme):
+        policy = nn.Linear(4, 2)
+        with torch.no_grad():
+            policy.weight.fill_(0.0)
+            policy.bias.fill_(0.0)
+
+        scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+        receiver = scheme.get_receiver()
+
+        # Non-blocking receive should return False when no data
+        result = receiver.receive(timeout=0.001)
+        assert result is False
+
+        # Now actually receive the weights
+        result = receiver.receive(timeout=5.0)
+        assert result is True
+
+        # Check weights were applied
+        return policy.weight.sum().item(), policy.bias.sum().item()
+
+    def test_receiver_receive_method(self):
+        """Test the new non-blocking receive() method."""
+
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
+
+        proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme))
+        try:
+            proc.start()
+
+            # Give worker time to call receive with no data
+
+            time.sleep(0.1)
+
+            weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
+            sender.send(weights)
+
+        finally:
+            proc.join(timeout=10.0)
+            assert not proc.is_alive()
+
 
 class TestCollectorIntegration:
     @pytest.fixture
@@ -568,6 +625,217 @@ def test_weight_strategy_parametrized(strategy):
     assert torch.allclose(policy.bias, target.bias)
 
 
+class TestSerializeScheme:
+    """Test that WeightSyncScheme instances can be serialized after initialization.
+
+    This is critical for multiprocessing and Ray, where schemes may be pickled
+    and sent across process boundaries. The _sender and _receiver attributes
+    contain non-serializable objects (pipes, weak references, etc.) and must
+    be excluded from serialization.
+    """
+
+    def test_multiprocess_scheme_serialize_before_init(self):
+        """Test that uninitialized scheme can be pickled."""
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+        assert not restored._initialized_on_sender
+        assert not restored._initialized_on_worker
+
+    def test_multiprocess_scheme_serialize_after_sender_init(self):
+        """Test that initialized sender can be pickled (excluding runtime state)."""
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        # Scheme now has _sender with non-serializable pipes
+        assert scheme._sender is not None
+        assert scheme._initialized_on_sender
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved but runtime state is cleared
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None  # Runtime state excluded
+        assert restored._receiver is None
+        assert not restored._initialized_on_sender  # Reset
+        assert not restored._initialized_on_worker
+
+        # Clean up
+        parent_pipe.close()
+        child_pipe.close()
+
+    def test_shared_mem_scheme_serialize_before_init(self):
+        """Test that uninitialized SharedMemWeightSyncScheme can be pickled."""
+        scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    def test_shared_mem_scheme_serialize_after_init(self):
+        """Test that initialized SharedMemWeightSyncScheme can be pickled."""
+        parent_pipe, child_pipe = mp.Pipe()
+
+        # Create shared buffer
+        shared_buffer = TensorDict(
+            {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
+        ).share_memory_()
+
+        scheme = SharedMemWeightSyncScheme(
+            policy_weights={"policy": shared_buffer},
+            strategy="tensordict",
+            auto_register=False,
+        )
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        # Scheme now has _sender with non-serializable state
+        assert scheme._sender is not None
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved but runtime state is cleared
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert not restored._initialized_on_sender
+
+        # Note: policy_weights dict is preserved (but may need re-sharing)
+        assert "policy" in restored.policy_weights
+
+        # Clean up
+        parent_pipe.close()
+        child_pipe.close()
+
+    def test_no_weight_sync_scheme_serialize(self):
+        """Test that NoWeightSyncScheme can be pickled."""
+        scheme = NoWeightSyncScheme()
+        scheme.init_on_sender(model_id="policy")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that it's still a no-op scheme
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(
+        not torch.distributed.is_available(), reason="torch.distributed not available"
+    )
+    def test_distributed_scheme_serialize_before_init(self):
+        """Test that uninitialized DistributedWeightSyncScheme can be pickled."""
+
+        scheme = DistributedWeightSyncScheme(backend="gloo", sync=True)
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.backend == "gloo"
+        assert restored.sync is True
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
+    def test_ray_weight_sync_scheme_serialize_before_init(self):
+        """Test that uninitialized RayWeightSyncScheme can be pickled."""
+        scheme = RayWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
+    def test_ray_module_transform_scheme_serialize_before_init(self):
+        """Test that uninitialized RayModuleTransformScheme can be pickled."""
+
+        scheme = RayModuleTransformScheme(strategy="tensordict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "tensordict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    @pytest.mark.skipif(
+        not torch.distributed.is_available(), reason="torch.distributed not available"
+    )
+    def test_rpc_weight_sync_scheme_serialize_before_init(self):
+        """Test that uninitialized RPCWeightSyncScheme can be pickled."""
+
+        scheme = RPCWeightSyncScheme(strategy="state_dict")
+
+        # Serialize and deserialize
+        pickled = pickle.dumps(scheme)
+        restored = pickle.loads(pickled)
+
+        # Check that configuration is preserved
+        assert restored.strategy == "state_dict"
+        assert restored._sender is None
+        assert restored._receiver is None
+
+    def test_scheme_reinitialization_after_unpickle(self):
+        """Test that a scheme can be re-initialized after unpickling.
+
+        This is the expected workflow: pickle a scheme, unpickle it in a worker,
+        then call init_on_worker() to establish new runtime resources.
+        """
+        # Initialize and pickle a scheme
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+
+        pickled = pickle.dumps(scheme)
+
+        # Clean up original
+        parent_pipe.close()
+
+        # Unpickle and re-initialize
+        restored = pickle.loads(pickled)
+
+        # Should be able to initialize again with new pipes
+        new_parent, new_child = mp.Pipe()
+
+        # Re-initialize on sender
+        restored.init_on_sender(model_id="policy", pipes=[new_parent])
+        sender = restored.get_sender()
+
+        assert sender is not None
+        assert restored._initialized_on_sender
+
+        # Clean up
+        new_parent.close()
+        new_child.close()
+        child_pipe.close()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown)
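The TestSerializeScheme cases above pin down the pickling contract for schemes: configuration (strategy, backend, policy_weights) survives a round trip, while _sender, _receiver and the _initialized_on_* flags are dropped so the unpickled copy can be re-initialized via init_on_sender()/init_on_worker(). One plausible way a scheme could meet that contract is to strip runtime state in __getstate__; the sketch below is illustrative only (the class name is hypothetical, and this is not the torchrl implementation).

class _SerializableSchemeSketch:
    """Illustrative only: mirrors the pickling behavior the tests assert."""

    def __init__(self, strategy="state_dict"):
        self.strategy = strategy           # configuration: survives pickling
        self._sender = None                # runtime state: pipes, weakrefs, ...
        self._receiver = None
        self._initialized_on_sender = False
        self._initialized_on_worker = False

    def __getstate__(self):
        state = self.__dict__.copy()
        # Drop non-serializable runtime objects and reset the init flags,
        # so the unpickled copy can call init_on_sender()/init_on_worker() again.
        state["_sender"] = None
        state["_receiver"] = None
        state["_initialized_on_sender"] = False
        state["_initialized_on_worker"] = False
        return state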
