@@ -3085,13 +3085,30 @@ def update(
30853085 if input_dict_or_td .ndim <= self_upd .stack_dim or input_dict_or_td .batch_size [
30863086 self_upd .stack_dim
30873087 ] != len (self_upd .tensordicts ):
3088+ # We receive a tensordict with a different batch-size than self.
3089+ # If update_batch_size is True, we can just convert the input tensordict to a lazy stack and update.
3090+ # This will change self, most likely removing a bunch of sub-tensordicts but we're good because the
3091+ # user is interested in modifying the batch-size.
3092+ # If update_batch_size is False, we need to try to change the batch-size of the input tensordict.
3093+ # That can only be done in restricted cases, so we raise an error if the batch-size of self (which must
3094+ # remain unchanged) is incompatible with the content of the input tensordict.
3095+ if update_batch_size :
3096+ return self .update (
3097+ input_dict_or_td .to_lazystack (self .stack_dim ),
3098+ clone = clone ,
3099+ keys_to_update = keys_to_update ,
3100+ non_blocking = non_blocking ,
3101+ is_leaf = is_leaf ,
3102+ update_batch_size = update_batch_size ,
3103+ ** kwargs ,
3104+ )
3105+ # if the batch-size does not permit unbinding, let's first try to reset the batch-size.
3106+ input_dict_or_td = input_dict_or_td .copy ()
3107+ batch_size = self_upd .batch_size
3108+ if self_upd .hook_out is not None :
3109+ batch_size = list (batch_size )
3110+ batch_size .insert (self_upd .stack_dim , len (self_upd .tensordicts ))
30883111 try :
3089- # if the batch-size does not permit unbinding, let's first try to reset the batch-size.
3090- input_dict_or_td = input_dict_or_td .copy ()
3091- batch_size = self_upd .batch_size
3092- if self_upd .hook_out is not None :
3093- batch_size = list (batch_size )
3094- batch_size .insert (self_upd .stack_dim , len (self_upd .tensordicts ))
30953112 input_dict_or_td .batch_size = batch_size
30963113 except RuntimeError as err :
30973114 raise ValueError (
0 commit comments