diff --git a/test/test_collector.py b/test/test_collector.py
index bb0c0330bf7..921299d4c40 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -162,7 +162,7 @@ def forward(self, observation):
         output = self.linear(observation)
         if self.multiple_outputs:
             return output, output.sum(), output.min(), output.max()
-        return self.linear(observation)
+        return output


 class UnwrappablePolicy(nn.Module):
@@ -1512,6 +1512,7 @@ def create_env():
             cudagraph_policy=cudagraph,
             weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
         )
+        assert "policy" in collector._weight_senders, collector._weight_senders.keys()
         try:
             # collect state_dict
             state_dict = collector.state_dict()
diff --git a/test/test_env.py b/test/test_env.py
index 7cd3c61ae0c..7aa00e98d2d 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -3836,6 +3836,8 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
             finally:
                 env.close(raise_if_closed=False)
                 del env
+            time.sleep(0.1)
+            gc.collect()

     class AddString(Transform):
         def __init__(self):
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 355e6e98db0..c14b1490cd6 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -307,6 +307,19 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
             else None
         )

+        # If no weights were provided and a sync scheme exists, extract the latest
+        # weights from the current model using the scheme strategy (state_dict or tensordict).
+        # This ensures we don't return stale cached weights.
+        if weights is None and scheme is not None:
+            from torchrl.weight_update.weight_sync_schemes import (
+                _resolve_model,
+                WeightStrategy,
+            )
+
+            strategy = WeightStrategy(extract_as=scheme.strategy)
+            model = _resolve_model(self, model_id)
+            return strategy.extract_weights(model)
+
         if weights is None:
             if model_id == "policy" and hasattr(self, "policy_weights"):
                 return self.policy_weights
@@ -462,6 +475,21 @@ def update_policy_weights_(
             # Apply to local policy
             if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
                 strategy.apply_weights(self.policy, weights)
+            elif (
+                hasattr(self, "_original_policy")
+                and isinstance(self._original_policy, nn.Module)
+                and hasattr(self, "policy")
+                and isinstance(self.policy, nn.Module)
+            ):
+                # If no weights were provided, mirror weights from the original (trainer) policy
+                from torchrl.weight_update.weight_sync_schemes import WeightStrategy
+
+                strategy = WeightStrategy(extract_as="tensordict")
+                weights = strategy.extract_weights(self._original_policy)
+                # Cast weights to the policy device before applying
+                if self.policy_device is not None:
+                    weights = weights.to(self.policy_device)
+                strategy.apply_weights(self.policy, weights)
             # Otherwise, no action needed - policy is local and changes are immediately visible

     def __iter__(self) -> Iterator[TensorDictBase]:
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 2baa465b74b..c6bdaf6383e 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -2489,14 +2489,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             # Make sure the root is updated
             root_shared_tensordict.update_(env._step_mdp(input))

+            # Set event before sending non-tensor data so parent knows worker is done
+            # The recv() call itself will provide synchronization for the pipe
+            mp_event.set()
+
             if _non_tensor_keys:
                 child_pipe.send(
                     ("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
                 )

-            # Set event only after non-tensor data is sent to avoid race condition
-            mp_event.set()
-
             del next_td

         elif cmd == "step_and_maybe_reset":
@@ -2530,14 +2531,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                 event.record()
                 event.synchronize()

+            # Set event before sending non-tensor data so parent knows worker is done
+            # The recv() call itself will provide synchronization for the pipe
+            mp_event.set()
+
             if _non_tensor_keys:
                 ntd = root_next_td.select(*_non_tensor_keys)
                 ntd.set("next", td_next.select(*_non_tensor_keys))
                 child_pipe.send(("non_tensor", ntd))

-            # Set event only after non-tensor data is sent to avoid race condition
-            mp_event.set()
-
             del td, root_next_td

         elif cmd == "close":
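
Note on the torchrl/envs/batched_envs.py reordering: below is a minimal standalone sketch of the parent/worker handshake the change relies on (plain multiprocessing, illustrative names only, not TorchRL internals). Setting the event before the pipe send is safe because the parent's recv() blocks until the non-tensor payload arrives, so the pipe transfer synchronizes itself; the event only has to cover the shared-memory writes.

import multiprocessing as mp


def worker(pipe, done_event):
    # Stand-in for the worker's shared-memory writes, which happen first.
    result = {"step": 1}
    # Signal completion before the pipe send: the shared buffers are ready.
    done_event.set()
    # The non-tensor payload goes through the pipe; the parent's recv()
    # blocks until it arrives, so no extra event ordering is needed.
    pipe.send(("non_tensor", result))
    pipe.close()


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    done = mp.Event()
    proc = mp.Process(target=worker, args=(child_end, done))
    proc.start()
    done.wait()  # worker has finished its shared-memory writes
    msg, payload = parent_end.recv()  # blocks until the pipe message lands
    assert msg == "non_tensor" and payload == {"step": 1}
    proc.join()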