Skip to content

Commit 44cfcca

Browse files
committed
Update
[ghstack-poisoned]
1 parent 00baa00 commit 44cfcca

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

test/test_weightsync.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,29 @@ def test_shared_mem_scheme_serialize_after_init(self):
702702
strategy="tensordict",
703703
auto_register=False,
704704
)
705-
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
705+
706+
def init_on_sender(scheme, child_pipe):
    """Play the worker's part of the shared-weights registration handshake.

    Despite its name, this helper runs on the receiving end of the pipe (it is
    the target of the ``future_receiver`` thread): it reads one message from
    ``child_pipe`` and, when that message is a ``"register_shared_weights"``
    request, replies with the ``"registered"`` acknowledgement.

    Args:
        scheme: Unused here; accepted because the thread passes it as a
            keyword argument.
        child_pipe: Connection end the request is read from and the
            acknowledgement is sent on.

    Raises:
        ValueError: If the received message is not
            ``"register_shared_weights"``.
    """
    # The payload is unpacked as ((model_id, data), msg); only msg is inspected,
    # so the other two are bound to throwaway names.
    (_model_id, _data), msg = child_pipe.recv()
    if msg != "register_shared_weights":
        raise ValueError(f"Expected 'register_shared_weights' but got {msg}")
    child_pipe.send((None, "registered"))
712+
713+
# Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker
714+
import threading
715+
716+
future_sender = threading.Thread(
717+
target=scheme.init_on_sender,
718+
kwargs={"model_id": "policy", "pipes": [parent_pipe]},
719+
)
720+
future_receiver = threading.Thread(
721+
target=init_on_sender,
722+
kwargs={"scheme": scheme, "child_pipe": child_pipe},
723+
)
724+
future_receiver.start()
725+
future_sender.start()
726+
future_receiver.join()
727+
future_sender.join()
706728

707729
# Scheme now has _sender with non-serializable state
708730
assert scheme._sender is not None

torchrl/weight_update/weight_sync_schemes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None:
202202
f"Model '{model_id}' has already been registered with workers."
203203
)
204204

205-
def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None:
205+
def _send_buffer_to_workers(
206+
self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0
207+
) -> None:
206208
"""Send shared memory buffer reference to all workers via pipes.
207209
208210
This is called once per model_id when lazy registration occurs.
@@ -219,6 +221,8 @@ def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None
219221

220222
# Wait for acknowledgments from all workers
221223
for pipe in self._pipes:
224+
if not pipe.poll(timeout):
225+
raise TimeoutError(f"Timeout waiting for acknowledgment from worker")
222226
_, msg = pipe.recv()
223227
if msg != "registered":
224228
raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'")

0 commit comments

Comments
 (0)