From 0501f2fa35471dd26226eae2d17ac84e93297f19 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Sat, 18 Oct 2025 15:53:37 -0700
Subject: [PATCH 1/4] Update

[ghstack-poisoned]
---
 examples/collectors/weight_sync_collectors.py |  68 +++++------
 examples/collectors/weight_sync_standalone.py | 108 +++++++++---------
 test/test_env.py                              |   3 +
 3 files changed, 94 insertions(+), 85 deletions(-)

diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py
index fbb1a8a1166..a3962966c8c 100644
--- a/examples/collectors/weight_sync_collectors.py
+++ b/examples/collectors/weight_sync_collectors.py
@@ -17,7 +17,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
-from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
+from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.envs import GymEnv
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
@@ -27,25 +27,24 @@
 
 def example_single_collector_multiprocess():
     """Example 1: Single collector with multiprocess scheme."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 1: Single Collector with Multiprocess Scheme")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
         ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-    
+
     # Create weight sync scheme
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    
+
     print("Creating collector with multiprocess weight sync...")
     collector = SyncDataCollector(
         create_env_fn=lambda: GymEnv("CartPole-v1"),
@@ -54,46 +53,45 @@ def example_single_collector_multiprocess():
         total_frames=200,
         weight_sync_schemes={"policy": scheme},
     )
-    
+
     # Collect data and update weights periodically
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-        
+
         # Update policy weights every 2 iterations
         if i % 2 == 0:
             new_weights = policy.state_dict()
             collector.update_policy_weights_(new_weights)
             print(" → Updated policy weights")
-        
+
         if i >= 2:  # Just run a few iterations for demo
             break
-    
+
     collector.shutdown()
     print("✓ Single collector example completed!\n")
 
 
 def example_multi_collector_shared_memory():
     """Example 2: Multiple collectors with shared memory."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 2: Multiple Collectors with Shared Memory")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
        ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-    
+
     # Shared memory is more efficient for frequent updates
     scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
-    
+
     print("Creating multi-collector with shared memory...")
     collector = MultiSyncDataCollector(
         create_env_fn=[
@@ -106,49 +104,51 @@ def example_multi_collector_shared_memory():
         total_frames=400,
         weight_sync_schemes={"policy": scheme},
     )
-    
+
     # Workers automatically see weight updates via shared memory
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-        
+
         # Update weights frequently (shared memory makes this very fast)
         collector.update_policy_weights_(TensorDict.from_module(policy))
         print(" → Updated policy weights via shared memory")
-        
+
         if i >= 1:  # Just run a couple iterations for demo
             break
-    
+
     collector.shutdown()
     print("✓ Multi-collector with shared memory example completed!\n")
 
 
 def main():
     """Run all examples."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Weight Synchronization Schemes - Collector Integration Examples")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Set multiprocessing start method
     import torch.multiprocessing as mp
+
     try:
-        mp.set_start_method('spawn')
+        mp.set_start_method("spawn")
     except RuntimeError:
         pass  # Already set
-    
+
     # Run examples
     example_single_collector_multiprocess()
     example_multi_collector_shared_memory()
-    
-    print("\n" + "="*70)
+
+    print("\n" + "=" * 70)
     print("All examples completed successfully!")
-    print("="*70)
+    print("=" * 70)
     print("\nKey takeaways:")
     print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
-    print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers")
-    print("="*70 + "\n")
+    print(
+        " • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers"
+    )
+    print("=" * 70 + "\n")
 
 
 if __name__ == "__main__":
     main()
-
diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py
index 83492256412..69d9947bdc7 100644
--- a/examples/collectors/weight_sync_standalone.py
+++ b/examples/collectors/weight_sync_standalone.py
@@ -16,8 +16,8 @@
 import torch
 import torch.nn as nn
-from torch import multiprocessing as mp
 from tensordict import TensorDict
+from torch import multiprocessing as mp
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
     SharedMemWeightSyncScheme,
@@ -27,21 +27,21 @@
 def worker_process_mp(child_pipe, model_state):
     """Worker process that receives weights via multiprocessing pipe."""
     print("Worker: Starting...")
-    
+
     # Create a policy on the worker side
     policy = nn.Linear(4, 2)
     with torch.no_grad():
         policy.weight.fill_(0.0)
         policy.bias.fill_(0.0)
-    
+
     # Create receiver and register the policy
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
     receiver = scheme.create_receiver()
     receiver.register_model(policy)
     receiver.register_worker_transport(child_pipe)
-    
+
     print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}")
-    
+
     # Receive and apply weights
     result = receiver._transport.receive_weights(timeout=5.0)
     if result is not None:
@@ -50,19 +50,19 @@ def worker_process_mp(child_pipe, model_state):
         print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}")
     else:
         print("Worker: No weights received")
-    
+
     # Store final state for verification
-    model_state['weight_sum'] = policy.weight.sum().item()
-    model_state['bias_sum'] = policy.bias.sum().item()
+    model_state["weight_sum"] = policy.weight.sum().item()
+    model_state["bias_sum"] = policy.bias.sum().item()
 
 
 def worker_process_shared_mem(child_pipe, model_state):
     """Worker process that receives shared memory buffer reference."""
     print("SharedMem Worker: Starting...")
-    
+
     # Create a policy on the worker side
     policy = nn.Linear(4, 2)
-    
+
     # Wait for shared memory buffer registration
     if child_pipe.poll(timeout=10.0):
         data, msg = child_pipe.recv()
@@ -73,129 +73,135 @@ def worker_process_shared_mem(child_pipe, model_state):
         shared_weights.to_module(policy)
         # Send acknowledgment
         child_pipe.send((None, "registered"))
-    
+
     # Small delay to ensure main process updates shared memory
     import time
+
     time.sleep(0.5)
-    
+
     print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}")
-    
+
     # Store final state for verification
-    model_state['weight_sum'] = policy.weight.sum().item()
-    model_state['bias_sum'] = policy.bias.sum().item()
+    model_state["weight_sum"] = policy.weight.sum().item()
+    model_state["bias_sum"] = policy.bias.sum().item()
 
 
 def example_multiprocess_sync():
     """Example 1: Multiprocess weight synchronization with state_dict."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 1: Multiprocess Weight Synchronization")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Create a simple policy on main process
     policy = nn.Linear(4, 2)
     with torch.no_grad():
         policy.weight.fill_(1.0)
         policy.bias.fill_(0.5)
-    
+
     print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}")
-    
+
     # Create scheme and sender
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
     sender = scheme.create_sender()
-    
+
     # Create pipe for communication
     parent_pipe, child_pipe = mp.Pipe()
     sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
-    
+
     # Start worker process
     manager = mp.Manager()
     model_state = manager.dict()
     process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state))
     process.start()
-    
+
     # Send weights to worker
     weights = policy.state_dict()
     print("Main: Sending weights to worker...")
     sender.update_weights(weights)
-    
+
     # Wait for worker to complete
     process.join(timeout=10.0)
-    
+
     if process.is_alive():
         print("Warning: Worker process did not terminate in time")
         process.terminate()
     else:
-        print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}")
-        print(f"✓ Weight synchronization successful!")
+        print(
+            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
+        )
+        print("✓ Weight synchronization successful!")
 
 
 def example_shared_memory_sync():
     """Example 2: Shared memory weight synchronization."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 2: Shared Memory Weight Synchronization")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Create a simple policy
     policy = nn.Linear(4, 2)
-    
+
     # Create shared memory scheme with auto-registration
     scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
     sender = scheme.create_sender()
-    
+
     # Create pipe for lazy registration
     parent_pipe, child_pipe = mp.Pipe()
     sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
-    
+
     # Start worker process
     manager = mp.Manager()
     model_state = manager.dict()
-    process = mp.Process(target=worker_process_shared_mem, args=(child_pipe, model_state))
+    process = mp.Process(
+        target=worker_process_shared_mem, args=(child_pipe, model_state)
+    )
     process.start()
-    
+
     # Send weights (automatically creates shared buffer on first send)
     weights_td = TensorDict.from_module(policy)
     with torch.no_grad():
         weights_td["weight"].fill_(2.0)
         weights_td["bias"].fill_(1.0)
-    
-    print(f"Main: Sending weights via shared memory...")
+
+    print("Main: Sending weights via shared memory...")
     sender.update_weights(weights_td)
-    
+
     # Workers automatically see updates via shared memory!
     print("Main: Weights are now in shared memory, workers can access them")
-    
+
     # Wait for worker to complete
     process.join(timeout=10.0)
-    
+
     if process.is_alive():
         print("Warning: Worker process did not terminate in time")
         process.terminate()
     else:
-        print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}")
-        print(f"✓ Shared memory synchronization successful!")
+        print(
+            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
+        )
+        print("✓ Shared memory synchronization successful!")
 
 
 def main():
     """Run all examples."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Weight Synchronization Schemes - Standalone Usage Examples")
-    print("="*70)
-    
+    print("=" * 70)
+
     # Set multiprocessing start method
     try:
-        mp.set_start_method('spawn')
+        mp.set_start_method("spawn")
     except RuntimeError:
         pass  # Already set
-    
+
     # Run examples
     example_multiprocess_sync()
     example_shared_memory_sync()
-    
-    print("\n" + "="*70)
+
+    print("\n" + "=" * 70)
     print("All examples completed successfully!")
-    print("="*70 + "\n")
+    print("=" * 70 + "\n")
 
 
 if __name__ == "__main__":
     main()
-
diff --git a/test/test_env.py b/test/test_env.py
index b092be80e1c..b27e5d1b696 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -13,6 +13,7 @@
 import pickle
 import random
 import re
+import time
 from collections import defaultdict
 from functools import partial
 from sys import platform
@@ -3822,6 +3823,8 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
         finally:
             env.close(raise_if_closed=False)
             del env
+            time.sleep(0.1)
+            gc.collect()
 
     class AddString(Transform):
         def __init__(self):

From 5c2d8a80b3c8c6bbac3ee400063bd2754b7e752c Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 22 Oct 2025 21:06:06 -0700
Subject: [PATCH 2/4] Update

[ghstack-poisoned]
---
 test/test_collector.py           |  1 +
 torchrl/collectors/collectors.py | 13 +++++++++++++
 torchrl/envs/batched_envs.py     |  4 ++--
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/test/test_collector.py b/test/test_collector.py
index bb0c0330bf7..dd62c063006 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -1512,6 +1512,7 @@ def create_env():
         cudagraph_policy=cudagraph,
         weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
     )
+    assert "policy" in collector._weight_senders, collector._weight_senders.keys()
     try:
         # collect state_dict
         state_dict = collector.state_dict()
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 355e6e98db0..d4ca4c0d872 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -307,6 +307,19 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
             else None
         )
 
+        # If no weights were provided and a sync scheme exists, extract the latest
+        # weights from the current model using the scheme strategy (state_dict or tensordict).
+        # This ensures we don't return stale cached weights.
+        if weights is None and scheme is not None:
+            from torchrl.weight_update.weight_sync_schemes import (
+                _resolve_model,
+                WeightStrategy,
+            )
+
+            strategy = WeightStrategy(extract_as=scheme.strategy)
+            model = _resolve_model(self, model_id)
+            return strategy.extract_weights(model)
+
         if weights is None:
             if model_id == "policy" and hasattr(self, "policy_weights"):
                 return self.policy_weights
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 9ab545d286a..c6bdaf6383e 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -2492,7 +2492,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             # Set event before sending non-tensor data so parent knows worker is done
             # The recv() call itself will provide synchronization for the pipe
             mp_event.set()
-            
+
             if _non_tensor_keys:
                 child_pipe.send(
                     ("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
@@ -2534,7 +2534,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             # Set event before sending non-tensor data so parent knows worker is done
             # The recv() call itself will provide synchronization for the pipe
             mp_event.set()
-            
+
             if _non_tensor_keys:
                 ntd = root_next_td.select(*_non_tensor_keys)
                 ntd.set("next", td_next.select(*_non_tensor_keys))

From bbd8b930c8c38f4d193df2b0933834e3d9d74267 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 23 Oct 2025 10:10:14 -0700
Subject: [PATCH 3/4] Update

[ghstack-poisoned]
---
 torchrl/collectors/collectors.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index d4ca4c0d872..ffdbe1e740c 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -475,6 +475,18 @@ def update_policy_weights_(
             # Apply to local policy
             if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
                 strategy.apply_weights(self.policy, weights)
+            elif (
+                hasattr(self, "_original_policy")
+                and isinstance(self._original_policy, nn.Module)
+                and hasattr(self, "policy")
+                and isinstance(self.policy, nn.Module)
+            ):
+                # If no weights were provided, mirror weights from the original (trainer) policy
+                from torchrl.weight_update.weight_sync_schemes import WeightStrategy
+
+                strategy = WeightStrategy(extract_as="tensordict")
+                weights = strategy.extract_weights(self._original_policy)
+                strategy.apply_weights(self.policy, weights)
         # Otherwise, no action needed - policy is local and changes are immediately visible
 
     def __iter__(self) -> Iterator[TensorDictBase]:

From c9e9b981c930a0b08b5de356d75b270024dc5146 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 23 Oct 2025 11:42:43 -0700
Subject: [PATCH 4/4] Update

[ghstack-poisoned]
---
 test/test_collector.py           | 2 +-
 torchrl/collectors/collectors.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/test_collector.py b/test/test_collector.py
index dd62c063006..921299d4c40 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -162,7 +162,7 @@ def forward(self, observation):
         output = self.linear(observation)
         if self.multiple_outputs:
             return output, output.sum(), output.min(), output.max()
-        return self.linear(observation)
+        return output
 
 
 class UnwrappablePolicy(nn.Module):
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index ffdbe1e740c..c14b1490cd6 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -486,6 +486,9 @@ def update_policy_weights_(
 
                 strategy = WeightStrategy(extract_as="tensordict")
                 weights = strategy.extract_weights(self._original_policy)
+                # Cast weights to the policy device before applying
+                if self.policy_device is not None:
+                    weights = weights.to(self.policy_device)
                 strategy.apply_weights(self.policy, weights)
         # Otherwise, no action needed - policy is local and changes are immediately visible