Skip to content

Commit c858a0a

Browse files
committed
Update
[ghstack-poisoned]
1 parent 3cf1674 commit c858a0a

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

torchrl/collectors/collectors.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2187,10 +2187,11 @@ def get_model(self, model_id: str):
             ValueError: If model_id is not recognized
         """
         if model_id == "policy":
-            # Return the wrapped policy instance
-            if hasattr(self, "_wrapped_policy") and self._wrapped_policy is not None:
-                return self._wrapped_policy
-            elif hasattr(self, "policy") and self.policy is not None:
+            # Return the unwrapped policy instance for weight synchronization
+            # The unwrapped policy has the same parameter structure as what's
+            # extracted in the main process, avoiding key mismatches when
+            # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
+            if hasattr(self, "policy") and self.policy is not None:
                 return self.policy
             else:
                 raise ValueError(f"No policy found for model_id '{model_id}'")
@@ -4680,12 +4681,21 @@ def _main_async_collector(
             # Only apply if the model is an nn.Module (has learnable parameters)
             try:
                 model = receiver._resolve_model_ref()
-                if isinstance(model, nn.Module):
-                    receiver.apply_weights(shared_buffer)
-            except (ValueError, AttributeError):
-                # Model not registered or not an nn.Module (e.g., RandomPolicy)
-                # Skip weight application - this is expected for policies without parameters
-                pass
+            except (ValueError, AttributeError) as e:
+                # Model not registered or reference is invalid
+                if verbose:
+                    torchrl_logger.warning(
+                        f"worker {idx} could not resolve model '{model_id}': {e}"
+                    )
+                continue
+
+            if isinstance(model, nn.Module):
+                receiver.apply_weights(shared_buffer)
+            else:
+                if verbose:
+                    torchrl_logger.info(
+                        f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'"
+                    )

             if verbose:
                 torchrl_logger.info(

torchrl/weight_update/weight_sync_schemes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,12 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
         if isinstance(weights, TensorDictBase):
             # Apply TensorDict format
             if isinstance(destination, TensorDictBase):
-                destination.data.update_(weights.data)
+                try:
+                    destination.data.update_(weights.data)
+                except Exception as e:
+                    raise KeyError(
+                        f"Error updating destination: {e}. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}"
+                    )
             else:
                 raise ValueError(
                     f"Unsupported destination type for TensorDict: {type(destination)}"

0 commit comments

Comments
 (0)