Skip to content

Commit 3852c9f

Browse files
committed
Update
[ghstack-poisoned]
1 parent c858a0a commit 3852c9f

File tree

3 files changed

+488
-84
lines changed

3 files changed

+488
-84
lines changed

test/test_weightsync.py

Lines changed: 264 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from __future__ import annotations
66

77
import argparse
8+
import importlib.util
9+
import pickle
10+
import time
811

912
import pytest
1013
import torch
@@ -16,14 +19,20 @@
1619
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
1720
from torchrl.weight_update.weight_sync_schemes import (
1821
_resolve_model,
22+
DistributedWeightSyncScheme,
1923
MPTransport,
2024
MultiProcessWeightSyncScheme,
2125
NoWeightSyncScheme,
26+
RayModuleTransformScheme,
27+
RayWeightSyncScheme,
28+
RPCWeightSyncScheme,
2229
SharedMemTransport,
2330
SharedMemWeightSyncScheme,
2431
WeightStrategy,
2532
)
2633

34+
_has_ray = importlib.util.find_spec("ray") is not None
35+
2736

2837
def worker_update_policy(pipe, timeout=5.0):
2938
policy = nn.Linear(4, 2)
@@ -73,8 +82,6 @@ def worker_shared_mem(pipe, timeout=10.0):
7382
shared_weights.to_module(policy)
7483
pipe.send((None, "registered"))
7584

76-
import time
77-
7885
time.sleep(0.5)
7986

8087
return policy.weight.sum().item(), policy.bias.sum().item()
@@ -203,13 +210,14 @@ def test_multiprocess_scheme_state_dict(self):
203210
sender = scheme.get_sender()
204211

205212
proc = mp.Process(target=worker_update_policy, args=(child_pipe,))
206-
proc.start()
207-
208-
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
209-
sender.send(weights)
213+
try:
214+
proc.start()
210215

211-
proc.join(timeout=10.0)
212-
assert not proc.is_alive()
216+
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
217+
sender.send(weights)
218+
finally:
219+
proc.join(timeout=10.0)
220+
assert not proc.is_alive()
213221

214222
def test_multiprocess_scheme_tensordict(self):
215223
parent_pipe, child_pipe = mp.Pipe()
@@ -219,15 +227,16 @@ def test_multiprocess_scheme_tensordict(self):
219227
sender = scheme.get_sender()
220228

221229
proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,))
222-
proc.start()
230+
try:
231+
proc.start()
223232

224-
weights = TensorDict(
225-
{"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
226-
)
227-
sender.send(weights)
228-
229-
proc.join(timeout=10.0)
230-
assert not proc.is_alive()
233+
weights = TensorDict(
234+
{"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
235+
)
236+
sender.send(weights)
237+
finally:
238+
proc.join(timeout=10.0)
239+
assert not proc.is_alive()
231240

232241
def test_shared_mem_scheme(self):
233242
shared_buffer = TensorDict(
@@ -273,49 +282,50 @@ def test_no_weight_sync_scheme(self):
273282
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
274283
transport.send_weights("policy", weights)
275284

276-
@classmethod
def _worker_with_receive(cls, pipe, scheme, ready=None):
    """Child-process entry point that exercises ``receiver.receive()``.

    Args:
        pipe: worker end of the pipe handed to ``scheme.init_on_worker``.
        scheme: weight sync scheme to initialize on the worker side.
        ready: optional event-like object; when given, ``ready.set()`` is
            called once the empty (no data yet) poll has completed, so the
            parent knows it may start sending.

    Returns:
        A ``(weight_sum, bias_sum)`` tuple of Python floats. When this
        function is used as a process target the return value is discarded;
        failures are reported through the assertions, which make the child
        exit with a non-zero exit code.
    """
    policy = nn.Linear(4, 2)
    with torch.no_grad():
        policy.weight.fill_(0.0)
        policy.bias.fill_(0.0)

    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
    receiver = scheme.get_receiver()

    # Non-blocking receive should return False when no data
    result = receiver.receive(timeout=0.001)
    assert result is False

    # Signal the parent only after the empty poll, so the send cannot race it
    if ready is not None:
        ready.set()

    # Now actually receive the weights
    result = receiver.receive(timeout=5.0)
    assert result is True

    return policy.weight.sum().item(), policy.bias.sum().item()

def test_receiver_receive_method(self):
    """Test the new non-blocking receive() method."""
    parent_pipe, child_pipe = mp.Pipe()

    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
    sender = scheme.get_sender()

    # Explicit handshake instead of a fixed sleep: the worker sets this once
    # its first (empty) receive() call has returned.
    ready = mp.Event()
    proc = mp.Process(
        target=self._worker_with_receive, args=(child_pipe, scheme, ready)
    )
    # Started outside the try block so cleanup never touches an unstarted process
    proc.start()
    try:
        assert ready.wait(timeout=10.0), "worker did not complete its empty poll"

        weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
        sender.send(weights)

        proc.join(timeout=10.0)
        assert not proc.is_alive()
        # Assertions inside the worker only surface through its exit code
        assert proc.exitcode == 0
    finally:
        # Never leave a stuck worker behind, and never assert in ``finally``
        # (that would hide the exception that got us here).
        if proc.is_alive():
            proc.terminate()
            proc.join()
319329

320330

321331
class TestCollectorIntegration:
@@ -615,6 +625,217 @@ def test_weight_strategy_parametrized(strategy):
615625
assert torch.allclose(policy.bias, target.bias)
616626

617627

628+
class TestSerializeScheme:
    """Test that WeightSyncScheme instances can be serialized after initialization.

    This is critical for multiprocessing and Ray, where schemes may be pickled
    and sent across process boundaries. The _sender and _receiver attributes
    contain non-serializable objects (pipes, weak references, etc.) and must
    be excluded from serialization.

    Tests that open pipes use the connections as context managers so both
    ends are closed even when an assertion fails.
    """

    @staticmethod
    def _roundtrip(scheme):
        """Pickle ``scheme`` and return the unpickled copy."""
        return pickle.loads(pickle.dumps(scheme))

    def test_multiprocess_scheme_serialize_before_init(self):
        """Test that uninitialized scheme can be pickled."""
        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.strategy == "state_dict"
        assert restored._sender is None
        assert restored._receiver is None
        assert not restored._initialized_on_sender
        assert not restored._initialized_on_worker

    def test_multiprocess_scheme_serialize_after_sender_init(self):
        """Test that initialized sender can be pickled (excluding runtime state)."""
        parent_pipe, child_pipe = mp.Pipe()

        # Both pipe ends are closed on exit, including on assertion failure
        with parent_pipe, child_pipe:
            scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
            scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])

            # Scheme now has _sender with non-serializable pipes
            assert scheme._sender is not None
            assert scheme._initialized_on_sender

            restored = self._roundtrip(scheme)

            # Check that configuration is preserved but runtime state is cleared
            assert restored.strategy == "state_dict"
            assert restored._sender is None  # Runtime state excluded
            assert restored._receiver is None
            assert not restored._initialized_on_sender  # Reset
            assert not restored._initialized_on_worker

    def test_shared_mem_scheme_serialize_before_init(self):
        """Test that uninitialized SharedMemWeightSyncScheme can be pickled."""
        scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.strategy == "tensordict"
        assert restored._sender is None
        assert restored._receiver is None

    def test_shared_mem_scheme_serialize_after_init(self):
        """Test that initialized SharedMemWeightSyncScheme can be pickled."""
        parent_pipe, child_pipe = mp.Pipe()

        # Both pipe ends are closed on exit, including on assertion failure
        with parent_pipe, child_pipe:
            # Create shared buffer
            shared_buffer = TensorDict(
                {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
            ).share_memory_()

            scheme = SharedMemWeightSyncScheme(
                policy_weights={"policy": shared_buffer},
                strategy="tensordict",
                auto_register=False,
            )
            scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])

            # Scheme now has _sender with non-serializable state
            assert scheme._sender is not None

            restored = self._roundtrip(scheme)

            # Check that configuration is preserved but runtime state is cleared
            assert restored.strategy == "tensordict"
            assert restored._sender is None
            assert not restored._initialized_on_sender

            # Note: policy_weights dict is preserved (but may need re-sharing)
            assert "policy" in restored.policy_weights

    def test_no_weight_sync_scheme_serialize(self):
        """Test that NoWeightSyncScheme can be pickled."""
        scheme = NoWeightSyncScheme()
        scheme.init_on_sender(model_id="policy")

        restored = self._roundtrip(scheme)

        # Check that it's still a no-op scheme
        assert restored._sender is None
        assert restored._receiver is None

    @pytest.mark.skipif(
        not torch.distributed.is_available(), reason="torch.distributed not available"
    )
    def test_distributed_scheme_serialize_before_init(self):
        """Test that uninitialized DistributedWeightSyncScheme can be pickled."""
        scheme = DistributedWeightSyncScheme(backend="gloo", sync=True)

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.backend == "gloo"
        assert restored.sync is True
        assert restored._sender is None
        assert restored._receiver is None

    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
    def test_ray_weight_sync_scheme_serialize_before_init(self):
        """Test that uninitialized RayWeightSyncScheme can be pickled."""
        scheme = RayWeightSyncScheme(strategy="state_dict")

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.strategy == "state_dict"
        assert restored._sender is None
        assert restored._receiver is None

    @pytest.mark.skipif(not _has_ray, reason="Ray not available")
    def test_ray_module_transform_scheme_serialize_before_init(self):
        """Test that uninitialized RayModuleTransformScheme can be pickled."""
        scheme = RayModuleTransformScheme(strategy="tensordict")

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.strategy == "tensordict"
        assert restored._sender is None
        assert restored._receiver is None

    @pytest.mark.skipif(
        not torch.distributed.is_available(), reason="torch.distributed not available"
    )
    def test_rpc_weight_sync_scheme_serialize_before_init(self):
        """Test that uninitialized RPCWeightSyncScheme can be pickled."""
        scheme = RPCWeightSyncScheme(strategy="state_dict")

        restored = self._roundtrip(scheme)

        # Check that configuration is preserved
        assert restored.strategy == "state_dict"
        assert restored._sender is None
        assert restored._receiver is None

    def test_scheme_reinitialization_after_unpickle(self):
        """Test that a scheme can be re-initialized after unpickling.

        This is the expected workflow: pickle a scheme, unpickle it in a worker,
        then call init_on_worker() to establish new runtime resources.
        """
        # Initialize and pickle a scheme
        parent_pipe, child_pipe = mp.Pipe()

        # Both pipe ends are closed on exit, including on assertion failure
        with parent_pipe, child_pipe:
            scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
            scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])

            pickled = pickle.dumps(scheme)

            # Close the original sender end before unpickling, so the restored
            # scheme is exercised without it; the later context-manager exit
            # closes it a second time, which is harmless.
            parent_pipe.close()

            # Unpickle and re-initialize
            restored = pickle.loads(pickled)

            # Should be able to initialize again with new pipes
            new_parent, new_child = mp.Pipe()
            with new_parent, new_child:
                # Re-initialize on sender
                restored.init_on_sender(model_id="policy", pipes=[new_parent])
                sender = restored.get_sender()

                assert sender is not None
                assert restored._initialized_on_sender
837+
838+
618839
if __name__ == "__main__":
    # Direct-run entry point: any command-line argument the empty parser does
    # not recognize is forwarded to pytest after the fixed options.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", "-v", *unknown])

0 commit comments

Comments
 (0)