@@ -302,14 +302,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302302 torch ._foreach_copy_ (dests , srcs )
303303 torch .cuda .synchronize ()
304304 self .graph .replay ()
305- if self ._return_unchanged == "clone" :
306- result = self ._out .clone ()
307- elif self ._return_unchanged :
305+ if self ._return_unchanged :
308306 result = self ._out
309307 else :
310- result = tree_map (
311- lambda x : x .detach ().clone () if x is not None else x ,
312- self ._out ,
308+ result = tree_unflatten (
309+ torch ._foreach_add (self ._out , 0.0 ), self ._out_struct
313310 )
314311 return result
315312
@@ -340,7 +337,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340337 self .graph = torch .cuda .CUDAGraph ()
341338 with torch .cuda .graph (self .graph ):
342339 out = self .module (* self ._args , ** self ._kwargs )
343- self ._out = out
340+ self ._out , self . _out_struct = tree_flatten ( out )
344341 self .counter += 1
345342 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
346343 # user.
@@ -356,10 +353,8 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356353 f"and the identity between input and output will not match anymore. "
357354 f"Make sure you don't rely on input-output identity further in the code."
358355 )
359- if isinstance (self ._out , torch .Tensor ) or self ._out is None :
360- self ._return_unchanged = (
361- "clone" if self ._out is not None else True
362- )
356+ if not self ._out :
357+ self ._return_unchanged = True
363358 else :
364359 self ._return_unchanged = False
365360 return this_out