@@ -267,6 +267,8 @@ def _call(
267267 "The output of the function must be a tensordict, a tensorclass or None. Got "
268268 f"type(out)={ type (out )} ."
269269 )
270+ if is_tensor_collection (out ):
271+ out .lock_ ()
270272 self ._out = out
271273 self .counter += 1
272274 if self ._out_matches_in :
@@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302304 torch ._foreach_copy_ (dests , srcs )
303305 torch .cuda .synchronize ()
304306 self .graph .replay ()
305- if self ._return_unchanged == "clone" :
306- result = self ._out .clone ()
307- elif self ._return_unchanged :
307+ if self ._return_unchanged :
308308 result = self ._out
309309 else :
310- result = tree_map (
311- lambda x : x .detach ().clone () if x is not None else x ,
312- self ._out ,
310+ result = tree_unflatten (
311+ [out .clone () for out in self ._out ], self ._out_struct
313312 )
314313 return result
315314
@@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340339 self .graph = torch .cuda .CUDAGraph ()
341340 with torch .cuda .graph (self .graph ):
342341 out = self .module (* self ._args , ** self ._kwargs )
343- self ._out = out
342+ self ._out , self . _out_struct = tree_flatten ( out )
344343 self .counter += 1
345344 # Check that there is no intersection between the identity of inputs and outputs, otherwise warn
346345 # the user.
@@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356355 f"and the identity between input and output will not match anymore. "
357356 f"Make sure you don't rely on input-output identity further in the code."
358357 )
359- if isinstance (self ._out , torch .Tensor ) or self ._out is None :
360- self ._return_unchanged = (
361- "clone" if self ._out is not None else True
362- )
358+ if not self ._out :
359+ self ._return_unchanged = True
363360 else :
361+ self ._out = [out .lock_ () if is_tensor_collection (out ) else out for out in self ._out ]
364362 self ._return_unchanged = False
365363 return this_out
366364
0 commit comments