@@ -2187,10 +2187,11 @@ def get_model(self, model_id: str):
2187 2187             ValueError: If model_id is not recognized
2188 2188         """
2189 2189         if model_id == "policy":
2190      -            # Return the wrapped policy instance
2191      -            if hasattr(self, "_wrapped_policy") and self._wrapped_policy is not None:
2192      -                return self._wrapped_policy
2193      -            elif hasattr(self, "policy") and self.policy is not None:
     2190 +            # Return the unwrapped policy instance for weight synchronization
     2191 +            # The unwrapped policy has the same parameter structure as what's
     2192 +            # extracted in the main process, avoiding key mismatches when
     2193 +            # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
     2194 +            if hasattr(self, "policy") and self.policy is not None:
2194 2195                 return self.policy
2195 2196             else:
2196 2197                 raise ValueError(f"No policy found for model_id '{model_id}'")
@@ -4680,12 +4681,21 @@ def _main_async_collector(
4680 4681                 # Only apply if the model is an nn.Module (has learnable parameters)
4681 4682                 try:
4682 4683                     model = receiver._resolve_model_ref()
4683      -                    if isinstance(model, nn.Module):
4684      -                        receiver.apply_weights(shared_buffer)
4685      -                except (ValueError, AttributeError):
4686      -                    # Model not registered or not an nn.Module (e.g., RandomPolicy)
4687      -                    # Skip weight application - this is expected for policies without parameters
4688      -                    pass
     4684 +                except (ValueError, AttributeError) as e:
     4685 +                    # Model not registered or reference is invalid
     4686 +                    if verbose:
     4687 +                        torchrl_logger.warning(
     4688 +                            f"worker {idx} could not resolve model '{model_id}': {e}"
     4689 +                        )
     4690 +                    continue
     4691 +
     4692 +                if isinstance(model, nn.Module):
     4693 +                    receiver.apply_weights(shared_buffer)
     4694 +                else:
     4695 +                    if verbose:
     4696 +                        torchrl_logger.info(
     4697 +                            f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'"
     4698 +                        )
4689 4699
4690 4700                 if verbose:
4691 4701                     torchrl_logger.info(
0 commit comments