@@ -471,7 +471,10 @@ def _weight_update_impl(
             processed_weights = self._extract_weights_if_needed(
                 weights, target_model_id
             )
-            self._weight_senders[target_model_id].update_weights(processed_weights)
+            # Use new send() API with worker_ids support
+            self._weight_senders[target_model_id].send(
+                weights=processed_weights, worker_ids=worker_ids
+            )
         elif self._weight_updater is not None:
             # unreachable
             raise RuntimeError
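For orientation, the old update_weights(weights) call becomes a send(weights=..., worker_ids=...) call that can target a subset of workers. Below is a minimal sketch of a sender exposing that signature, assuming one pipe-like transport per worker; the class name _ExampleWeightSender and its message format are illustrative, not TorchRL API.

from __future__ import annotations

from typing import Any


class _ExampleWeightSender:
    """Hypothetical sender: fans weights out to registered worker transports."""

    def __init__(self) -> None:
        # worker_id -> transport (e.g. the parent end of an mp.Pipe)
        self._transports: dict[int, Any] = {}

    def register_worker(self, worker_id: int, transport: Any) -> None:
        self._transports[worker_id] = transport

    def send(self, *, weights: Any, worker_ids: list[int] | None = None) -> None:
        # worker_ids=None means "broadcast to every worker"; a list restricts
        # the update to the selected workers only.
        targets = (
            self._transports
            if worker_ids is None
            else {i: self._transports[i] for i in worker_ids}
        )
        for transport in targets.values():
            transport.send(("update_weights", weights))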
@@ -2154,6 +2157,33 @@ def getattr_rb(self, attr):
         # send command to rb to return the attr
         return getattr(self.replay_buffer, attr)
 
+    def get_model(self, model_id: str):
+        """Get model instance by ID (for weight sync schemes).
+
+        Args:
+            model_id: Model identifier (e.g., "policy", "value_net")
+
+        Returns:
+            The model instance
+
+        Raises:
+            ValueError: If model_id is not recognized
+        """
+        if model_id == "policy":
+            # Return the wrapped policy instance
+            if hasattr(self, "_wrapped_policy") and self._wrapped_policy is not None:
+                return self._wrapped_policy
+            elif hasattr(self, "policy") and self.policy is not None:
+                return self.policy
+            else:
+                raise ValueError(f"No policy found for model_id '{model_id}'")
+        else:
+            # Try to resolve via attribute access
+            if hasattr(self, model_id):
+                return getattr(self, model_id)
+            else:
+                raise ValueError(f"Unknown model_id: {model_id}")
+
 
 class _MultiDataCollector(DataCollectorBase):
     """Runs a given number of DataCollectors on separate processes.
@@ -2890,15 +2920,7 @@ def _run_processes(self) -> None:
             1, torch.get_num_threads() - total_workers
         )  # 1 more thread for this proc
 
-        # Initialize weight senders for multiprocess collectors
-        if self._weight_sync_schemes:
-            # Create one sender per model using scheme's factory method
-            for model_id, scheme in self._weight_sync_schemes.items():
-                sender = scheme.create_sender()
-                sender._model_id = model_id
-                if hasattr(sender, "set_context"):
-                    sender.set_context(self, model_id)
-                self._weight_senders[model_id] = sender
+        # Weight senders will be initialized after workers are ready (via init_on_sender)
         torch.set_num_threads(self.num_threads)
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
@@ -3010,11 +3032,7 @@ def _run_processes(self) -> None:
             self.procs.append(proc)
             self.pipes.append(pipe_parent)
 
-            # Register worker with senders
-            if self._weight_senders:
-                for _, sender in self._weight_senders.items():
-                    sender.register_worker(i, pipe_parent)
-
+        # Worker registration now handled by init_on_sender() after workers are ready
         for i, pipe_parent in enumerate(self.pipes):
             pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT)
             try:
@@ -3066,30 +3084,20 @@ def _run_processes(self) -> None:
                 # Legacy string error message
                 raise RuntimeError(msg)
 
-        # For SharedMemWeightSyncScheme, pre-register shared weights now that workers are ready
-        # This avoids deadlock when workers are busy collecting and can't respond to registration messages
+        # Initialize all weight sync schemes now that workers are ready
+        # This calls init_on_sender() for each scheme which:
+        # 1. Creates transports for all workers
+        # 2. Creates and configures the sender
+        # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock
         if self._weight_sync_schemes:
             for model_id, scheme in self._weight_sync_schemes.items():
-                if isinstance(scheme, SharedMemWeightSyncScheme):
-                    sender = self._weight_senders[model_id]
-                    # Get the shared memory weights from _policy_weights_dict
-                    # Use prepare_weights with None to trigger cache lookup
-                    from torchrl.weight_update.weight_sync_schemes import _get_strategy
-
-                    strategy = _get_strategy(scheme.strategy)
-                    weights = scheme.prepare_weights(
-                        weights=None,
-                        model_id=model_id,
-                        strategy=strategy,
-                        context=self,
-                    )
-                    if weights is not None:
-                        # Register the shared weights directly with each transport
-                        # This ensures the transports use the same shared memory buffer
-                        # that we'll update later, rather than creating a clone
-                        for transport in sender._iterate_transports():
-                            if hasattr(transport, "register_weights"):
-                                transport.register_weights(model_id, weights)
+                # Check if scheme has new API or legacy API
+                if hasattr(scheme, "init_on_sender"):
+                    # Use new API
+                    scheme.init_on_sender(model_id=model_id, context=self)
+                    # Get the initialized sender
+                    self._weight_senders[model_id] = scheme.get_sender()
+                # else: keep using legacy _weight_senders initialization from before
 
         self.queue_out = queue_out
         self.closed = False
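For context, a sketch of what the sender side of a scheme implementing this protocol could look like, reusing the _ExampleWeightSender above. Only init_on_sender()/get_sender() match the calls in this hunk; ExampleScheme, the use of context.pipes and the _cached_weights attribute are assumptions about how a scheme might wire itself to the collector.

class ExampleScheme:
    """Hypothetical scheme implementing the sender-side init protocol."""

    def __init__(self) -> None:
        self._sender = None

    def init_on_sender(self, *, model_id: str, context) -> None:
        # `context` is the collector. Once workers are up it exposes one pipe
        # per worker plus the get_model()/get_cached_weights() hooks added in
        # this commit.
        sender = _ExampleWeightSender()
        for worker_id, pipe in enumerate(context.pipes):
            sender.register_worker(worker_id, pipe)
        # Reuse shared-memory buffers if the collector caches them, so the
        # workers keep mapping the buffer the trainer writes into.
        get_cached = getattr(context, "get_cached_weights", None)
        sender._cached_weights = get_cached(model_id) if get_cached else None
        self._sender = sender

    def get_sender(self):
        return self._sender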
@@ -3451,6 +3459,52 @@ def getattr_rb(self, attr):
34513459 """Get an attribute from the replay buffer."""
34523460 return getattr (self .replay_buffer , attr )
34533461
3462+ def get_model (self , model_id : str ):
3463+ """Get model instance by ID (for weight sync schemes).
3464+
3465+ Args:
3466+ model_id: Model identifier (e.g., "policy", "value_net")
3467+
3468+ Returns:
3469+ The model instance
3470+
3471+ Raises:
3472+ ValueError: If model_id is not recognized
3473+ """
3474+ if model_id == "policy" :
3475+ # Return the fallback policy instance
3476+ if hasattr (self , "_fallback_policy" ) and self ._fallback_policy is not None :
3477+ return self ._fallback_policy
3478+ elif hasattr (self , "policy" ) and self .policy is not None :
3479+ return self .policy
3480+ else :
3481+ raise ValueError (f"No policy found for model_id '{ model_id } '" )
3482+ else :
3483+ # Try to resolve via attribute access
3484+ if hasattr (self , model_id ):
3485+ return getattr (self , model_id )
3486+ else :
3487+ raise ValueError (f"Unknown model_id: { model_id } " )
3488+
3489+ def get_cached_weights (self , model_id : str ):
3490+ """Get cached shared memory weights if available (for weight sync schemes).
3491+
3492+ Args:
3493+ model_id: Model identifier
3494+
3495+ Returns:
3496+ Cached TensorDict weights or None if not available
3497+ """
3498+ if model_id == "policy" and hasattr (self , "_policy_weights_dict" ):
3499+ # Get the policy device (first device if list)
3500+ policy_device = self .policy_device
3501+ if isinstance (policy_device , (list , tuple )):
3502+ policy_device = policy_device [0 ] if len (policy_device ) > 0 else None
3503+
3504+ # Return cached weights for this device
3505+ return self ._policy_weights_dict .get (policy_device )
3506+ return None
3507+
34543508
34553509@accept_remote_rref_udf_invocation
34563510class MultiSyncDataCollector (_MultiDataCollector ):
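get_cached_weights() exists so shared-memory schemes can hand workers references to the very buffers the trainer writes into, instead of sending clones. A small standalone demonstration of why the in-place path works, using tensordict only (no collector involved):

import torch
from tensordict import TensorDict

# One shared-memory buffer, referenced by both "sides".
shared = TensorDict({"weight": torch.zeros(2, 2)}, []).share_memory_()
worker_view = shared  # what a worker/transport would keep a reference to

# In-place write into the shared storage on the trainer side...
shared["weight"].copy_(torch.ones(2, 2))
# ...is immediately visible through the worker's reference, no copy needed.
assert (worker_view["weight"] == 1).all()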
@@ -4422,13 +4476,21 @@ def _main_async_collector(
     # Set up weight receivers for worker process
     if weight_sync_schemes:
         inner_collector._weight_receivers = {}
+        inner_collector.pipe = pipe_child  # Add pipe attribute for context
         for model_id, scheme in weight_sync_schemes.items():
-            receiver = scheme.create_receiver()
-            receiver.set_context(inner_collector)
-            receiver.register_worker_transport(pipe_child)
+            # Check if scheme has new API or legacy API
+            if hasattr(scheme, "init_on_worker"):
+                # Use new API
+                scheme.init_on_worker(model_id=model_id, context=inner_collector)
+                receiver = scheme.get_receiver()
+            else:
+                # Legacy API
+                receiver = scheme.create_receiver()
+                receiver.set_context(inner_collector)
+                receiver.register_worker_transport(pipe_child)
 
-            model = _resolve_model(inner_collector, model_id)
-            receiver.register_model(model)
+                model = _resolve_model(inner_collector, model_id)
+                receiver.register_model(model)
 
             inner_collector._weight_receivers[model_id] = receiver
     else:
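The worker-side counterpart mirrors the sender-side protocol: init_on_worker() builds a receiver around the worker's end of the pipe and binds it to the model resolved via get_model(). A sketch under the same assumptions as the ExampleScheme above; the receiver class, message format, and load_state_dict application are illustrative, not TorchRL API.

class _ExampleWeightReceiver:
    """Hypothetical receiver: polls a pipe and applies incoming weights."""

    def __init__(self, pipe) -> None:
        self._pipe = pipe
        self._model = None

    def register_model(self, model) -> None:
        self._model = model

    def receive(self, timeout: float = 0.0) -> bool:
        # Effectively non-blocking when timeout is tiny: apply at most one
        # pending update and report whether anything arrived.
        if not self._pipe.poll(timeout):
            return False
        msg, weights = self._pipe.recv()
        if msg == "update_weights" and self._model is not None:
            self._model.load_state_dict(weights)
        return True


class ExampleSchemeWorkerSide:
    """Worker half of the hypothetical scheme."""

    def init_on_worker(self, *, model_id: str, context) -> None:
        # `context` is the inner collector; the diff attaches the child pipe
        # to it precisely so worker-side transports can be built from it.
        receiver = _ExampleWeightReceiver(context.pipe)
        receiver.register_model(context.get_model(model_id))
        self._receiver = receiver

    def get_receiver(self):
        return self._receiver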
@@ -4617,6 +4679,13 @@ def _main_async_collector(
                 inner_collector.init_random_frames = float("inf")
             else:
                 inner_collector.init_random_frames = -1
+
+            # Check for and apply weight updates before collecting next batch
+            if inner_collector._weight_receivers:
+                for receiver in inner_collector._weight_receivers.values():
+                    # Non-blocking check for new weights
+                    receiver.receive(timeout=0.0001)
+
             next_data = next(dc_iter)
             if pipe_child.poll(_MIN_TIMEOUT):
                 # in this case, main send a message to the worker while it was busy collecting trajectories.
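The timeout=0.0001 poll above keeps the collection loop responsive: each iteration drains at most one pending update per model and immediately moves on. The pattern can be exercised locally with the _ExampleWeightReceiver sketched earlier and a plain multiprocessing.Pipe (no worker process needed):

import multiprocessing as mp

from torch import nn

parent_pipe, child_pipe = mp.Pipe()
model = nn.Linear(2, 2)
receiver = _ExampleWeightReceiver(child_pipe)
receiver.register_model(model)

assert receiver.receive(timeout=0.0001) is False   # nothing pending: returns fast
parent_pipe.send(("update_weights", nn.Linear(2, 2).state_dict()))
assert receiver.receive(timeout=0.0001) is True    # update applied in place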