
Commit 709bec7

[Refactor] Weight sync schemes refactor
ghstack-source-id: f44e83d Pull-Request: #3230
1 parent 509e77d commit 709bec7
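
Summary (as reflected in the diff below): sender/receiver wiring moves from the per-object legacy calls (`create_sender()` / `register_worker()` / `update_weights()` on the trainer side, `create_receiver()` / `register_model()` / `register_worker_transport()` on the worker side) to scheme-level entry points: `init_on_sender()` + `get_sender()` + `send()` and `init_on_worker()` + `get_receiver()` + `receive()`. Collectors gain `get_model()` and `get_cached_weights()` so a scheme can resolve what it needs from the collector passed as `context`, sender setup is deferred until workers are ready, and the async worker loop polls `receive()` non-blockingly before each batch.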

3 files changed (+711, -103 lines)

test/test_weightsync.py

Lines changed: 63 additions & 12 deletions
@@ -32,9 +32,9 @@ def worker_update_policy(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
+    # Use new API
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -52,9 +52,9 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
         policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-    receiver = scheme.create_receiver()
-    receiver.register_model(policy)
-    receiver.register_worker_transport(pipe)
+    # Use new API
+    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
         data, msg = receiver._transport.pipe.recv()
@@ -192,18 +192,24 @@ def test_cross_format_conversion(self):
 
 
 class TestWeightSyncSchemes:
+    """Tests for weight sync schemes using the new simplified API.
+
+    Lower-level transport and legacy API tests are in TestTransportBackends.
+    """
+
     def test_multiprocess_scheme_state_dict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
+        # Use new API
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy, args=(child_pipe,))
         proc.start()
 
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        sender.update_weights(weights)
+        sender.send(weights)
 
         proc.join(timeout=10.0)
         assert not proc.is_alive()
@@ -212,16 +218,17 @@ def test_multiprocess_scheme_tensordict(self):
         parent_pipe, child_pipe = mp.Pipe()
 
         scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-        sender = scheme.create_sender()
-        sender.register_worker(0, parent_pipe)
+        # Use new API
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
 
         proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,))
         proc.start()
 
         weights = TensorDict(
             {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
         )
-        sender.update_weights(weights)
+        sender.send(weights)
 
         proc.join(timeout=10.0)
         assert not proc.is_alive()
@@ -270,6 +277,50 @@ def test_no_weight_sync_scheme(self):
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
         transport.send_weights("policy", weights)
 
+    def test_receiver_receive_method(self):
+        """Test the new non-blocking receive() method."""
+
+        def worker_with_receive(pipe):
+            policy = nn.Linear(4, 2)
+            with torch.no_grad():
+                policy.weight.fill_(0.0)
+                policy.bias.fill_(0.0)
+
+            scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+            scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+            receiver = scheme.get_receiver()
+
+            # Non-blocking receive should return False when no data
+            result = receiver.receive(timeout=0.001)
+            assert result is False
+
+            # Now actually receive the weights
+            result = receiver.receive(timeout=5.0)
+            assert result is True
+
+            # Check weights were applied
+            return policy.weight.sum().item(), policy.bias.sum().item()
+
+        parent_pipe, child_pipe = mp.Pipe()
+
+        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
+        sender = scheme.get_sender()
+
+        proc = mp.Process(target=worker_with_receive, args=(child_pipe,))
+        proc.start()
+
+        # Give worker time to call receive with no data
+        import time
+
+        time.sleep(0.1)
+
+        weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
+        sender.send(weights)
+
+        proc.join(timeout=10.0)
+        assert not proc.is_alive()
+
 
 class TestCollectorIntegration:
     @pytest.fixture

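For reference, the round trip exercised by the tests above can be sketched as a standalone script. This is a minimal illustration of the calls visible in the diff (`init_on_worker` / `get_receiver` / `receive` on the worker side, `init_on_sender` / `get_sender` / `send` on the trainer side); the import path `torchrl.weight_update.weight_sync_schemes` is taken from the collectors diff below, and exact signatures should be checked against the PR.

```python
# Minimal sketch of the new scheme API: one trainer process pushes a state_dict
# to one worker over a multiprocessing pipe. Illustrative only.
import multiprocessing as mp

import torch
from torch import nn

from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def worker(pipe):
    # Worker side: bind the scheme to this process's pipe and local model,
    # then wait (up to a timeout) for weights to arrive and be applied.
    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
    receiver = scheme.get_receiver()
    assert receiver.receive(timeout=5.0)  # True once weights were applied


if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()

    # Trainer side: bind the scheme to the worker pipes, then send weights.
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
    sender = scheme.get_sender()

    proc = mp.Process(target=worker, args=(child_pipe,))
    proc.start()
    sender.send({"weight": torch.ones(2, 4), "bias": torch.ones(2)})
    proc.join(timeout=10.0)
```
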
torchrl/collectors/collectors.py

Lines changed: 111 additions & 42 deletions
@@ -471,7 +471,10 @@ def _weight_update_impl(
             processed_weights = self._extract_weights_if_needed(
                 weights, target_model_id
             )
-            self._weight_senders[target_model_id].update_weights(processed_weights)
+            # Use new send() API with worker_ids support
+            self._weight_senders[target_model_id].send(
+                weights=processed_weights, worker_ids=worker_ids
+            )
         elif self._weight_updater is not None:
             # unreachable
             raise RuntimeError
@@ -2154,6 +2157,33 @@ def getattr_rb(self, attr):
         # send command to rb to return the attr
         return getattr(self.replay_buffer, attr)
 
+    def get_model(self, model_id: str):
+        """Get model instance by ID (for weight sync schemes).
+
+        Args:
+            model_id: Model identifier (e.g., "policy", "value_net")
+
+        Returns:
+            The model instance
+
+        Raises:
+            ValueError: If model_id is not recognized
+        """
+        if model_id == "policy":
+            # Return the wrapped policy instance
+            if hasattr(self, "_wrapped_policy") and self._wrapped_policy is not None:
+                return self._wrapped_policy
+            elif hasattr(self, "policy") and self.policy is not None:
+                return self.policy
+            else:
+                raise ValueError(f"No policy found for model_id '{model_id}'")
+        else:
+            # Try to resolve via attribute access
+            if hasattr(self, model_id):
+                return getattr(self, model_id)
+            else:
+                raise ValueError(f"Unknown model_id: {model_id}")
+
 
 class _MultiDataCollector(DataCollectorBase):
     """Runs a given number of DataCollectors on separate processes.
@@ -2890,15 +2920,7 @@ def _run_processes(self) -> None:
             1, torch.get_num_threads() - total_workers
         )  # 1 more thread for this proc
 
-        # Initialize weight senders for multiprocess collectors
-        if self._weight_sync_schemes:
-            # Create one sender per model using scheme's factory method
-            for model_id, scheme in self._weight_sync_schemes.items():
-                sender = scheme.create_sender()
-                sender._model_id = model_id
-                if hasattr(sender, "set_context"):
-                    sender.set_context(self, model_id)
-                self._weight_senders[model_id] = sender
+        # Weight senders will be initialized after workers are ready (via init_on_sender)
         torch.set_num_threads(self.num_threads)
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
@@ -3010,11 +3032,7 @@ def _run_processes(self) -> None:
             self.procs.append(proc)
             self.pipes.append(pipe_parent)
 
-            # Register worker with senders
-            if self._weight_senders:
-                for _, sender in self._weight_senders.items():
-                    sender.register_worker(i, pipe_parent)
-
+            # Worker registration now handled by init_on_sender() after workers are ready
         for i, pipe_parent in enumerate(self.pipes):
             pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT)
             try:
@@ -3066,30 +3084,20 @@ def _run_processes(self) -> None:
                     # Legacy string error message
                     raise RuntimeError(msg)
 
-        # For SharedMemWeightSyncScheme, pre-register shared weights now that workers are ready
-        # This avoids deadlock when workers are busy collecting and can't respond to registration messages
+        # Initialize all weight sync schemes now that workers are ready
+        # This calls init_on_sender() for each scheme which:
+        # 1. Creates transports for all workers
+        # 2. Creates and configures the sender
+        # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock
         if self._weight_sync_schemes:
             for model_id, scheme in self._weight_sync_schemes.items():
-                if isinstance(scheme, SharedMemWeightSyncScheme):
-                    sender = self._weight_senders[model_id]
-                    # Get the shared memory weights from _policy_weights_dict
-                    # Use prepare_weights with None to trigger cache lookup
-                    from torchrl.weight_update.weight_sync_schemes import _get_strategy
-
-                    strategy = _get_strategy(scheme.strategy)
-                    weights = scheme.prepare_weights(
-                        weights=None,
-                        model_id=model_id,
-                        strategy=strategy,
-                        context=self,
-                    )
-                    if weights is not None:
-                        # Register the shared weights directly with each transport
-                        # This ensures the transports use the same shared memory buffer
-                        # that we'll update later, rather than creating a clone
-                        for transport in sender._iterate_transports():
-                            if hasattr(transport, "register_weights"):
-                                transport.register_weights(model_id, weights)
+                # Check if scheme has new API or legacy API
+                if hasattr(scheme, "init_on_sender"):
+                    # Use new API
+                    scheme.init_on_sender(model_id=model_id, context=self)
+                    # Get the initialized sender
+                    self._weight_senders[model_id] = scheme.get_sender()
+                # else: keep using legacy _weight_senders initialization from before
 
         self.queue_out = queue_out
         self.closed = False
@@ -3451,6 +3459,52 @@ def getattr_rb(self, attr):
         """Get an attribute from the replay buffer."""
         return getattr(self.replay_buffer, attr)
 
+    def get_model(self, model_id: str):
+        """Get model instance by ID (for weight sync schemes).
+
+        Args:
+            model_id: Model identifier (e.g., "policy", "value_net")
+
+        Returns:
+            The model instance
+
+        Raises:
+            ValueError: If model_id is not recognized
+        """
+        if model_id == "policy":
+            # Return the fallback policy instance
+            if hasattr(self, "_fallback_policy") and self._fallback_policy is not None:
+                return self._fallback_policy
+            elif hasattr(self, "policy") and self.policy is not None:
+                return self.policy
+            else:
+                raise ValueError(f"No policy found for model_id '{model_id}'")
+        else:
+            # Try to resolve via attribute access
+            if hasattr(self, model_id):
+                return getattr(self, model_id)
+            else:
+                raise ValueError(f"Unknown model_id: {model_id}")
+
+    def get_cached_weights(self, model_id: str):
+        """Get cached shared memory weights if available (for weight sync schemes).
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            Cached TensorDict weights or None if not available
+        """
+        if model_id == "policy" and hasattr(self, "_policy_weights_dict"):
+            # Get the policy device (first device if list)
+            policy_device = self.policy_device
+            if isinstance(policy_device, (list, tuple)):
+                policy_device = policy_device[0] if len(policy_device) > 0 else None
+
+            # Return cached weights for this device
+            return self._policy_weights_dict.get(policy_device)
+        return None
+
 
 @accept_remote_rref_udf_invocation
 class MultiSyncDataCollector(_MultiDataCollector):
@@ -4422,13 +4476,21 @@ def _main_async_collector(
     # Set up weight receivers for worker process
     if weight_sync_schemes:
         inner_collector._weight_receivers = {}
+        inner_collector.pipe = pipe_child  # Add pipe attribute for context
         for model_id, scheme in weight_sync_schemes.items():
-            receiver = scheme.create_receiver()
-            receiver.set_context(inner_collector)
-            receiver.register_worker_transport(pipe_child)
+            # Check if scheme has new API or legacy API
+            if hasattr(scheme, "init_on_worker"):
+                # Use new API
+                scheme.init_on_worker(model_id=model_id, context=inner_collector)
+                receiver = scheme.get_receiver()
+            else:
+                # Legacy API
+                receiver = scheme.create_receiver()
+                receiver.set_context(inner_collector)
+                receiver.register_worker_transport(pipe_child)
 
-            model = _resolve_model(inner_collector, model_id)
-            receiver.register_model(model)
+                model = _resolve_model(inner_collector, model_id)
+                receiver.register_model(model)
 
             inner_collector._weight_receivers[model_id] = receiver
     else:
@@ -4617,6 +4679,13 @@ def _main_async_collector(
             inner_collector.init_random_frames = float("inf")
         else:
             inner_collector.init_random_frames = -1
+
+        # Check for and apply weight updates before collecting next batch
+        if inner_collector._weight_receivers:
+            for receiver in inner_collector._weight_receivers.values():
+                # Non-blocking check for new weights
+                receiver.receive(timeout=0.0001)
+
         next_data = next(dc_iter)
         if pipe_child.poll(_MIN_TIMEOUT):
             # in this case, main send a message to the worker while it was busy collecting trajectories.

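The collector changes combine two patterns: the trainer targets specific workers via `send(weights=..., worker_ids=...)`, and each worker polls `receiver.receive()` with a tiny timeout before collecting the next batch. A rough standalone sketch of that shape, under the same assumptions as above (import path taken from this PR; whether `send()` without `worker_ids` broadcasts to all workers is assumed, not shown in the diff):

```python
# Sketch of the collector-side usage pattern: periodic non-blocking receive()
# in the worker loop, targeted send(worker_ids=...) from the trainer.
import multiprocessing as mp

import torch
from torch import nn

from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def worker_loop(pipe, n_iters=100):
    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
    receiver = scheme.get_receiver()
    for _ in range(n_iters):
        # Mirrors the _main_async_collector change: check for fresh weights
        # before each collection step; returns almost immediately if none arrived.
        receiver.receive(timeout=0.0001)
        # ... collect a batch with `policy` here ...


if __name__ == "__main__":
    pipes = [mp.Pipe() for _ in range(2)]

    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_sender(model_id="policy", pipes=[parent for parent, _ in pipes])
    sender = scheme.get_sender()

    procs = [mp.Process(target=worker_loop, args=(child,)) for _, child in pipes]
    for proc in procs:
        proc.start()

    weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
    sender.send(weights)                  # assumed: no worker_ids -> all workers
    sender.send(weights, worker_ids=[0])  # target only worker 0, as in _weight_update_impl

    for proc in procs:
        proc.join(timeout=10.0)
```
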