Skip to content

Commit 4786d54

Browse files
author
Vincent Moens
committed
[Performance] Faster clone
ghstack-source-id: 6eecbac Pull Request resolved: #1040
1 parent 088d953 commit 4786d54

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

tensordict/_td.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
30093009
return all([value.is_contiguous() for _, value in self.items()])
30103010

30113011
def _clone(self, recurse: bool = True) -> T:
3012+
if recurse and self.device is not None and self.device.type == "cuda":
3013+
return self._clone_recurse()
3014+
30123015
result = TensorDict._new_unsafe(
30133016
source={key: _clone_value(value, recurse) for key, value in self.items()},
30143017
batch_size=self.batch_size,

tensordict/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8156,6 +8156,23 @@ def cosh_(self) -> T:
81568156
torch._foreach_cosh_(self._values_list(True, True))
81578157
return self
81588158

8159+
def _clone_recurse(self) -> TensorDictBase: # noqa: D417
8160+
keys, vals = self._items_list(True, True)
8161+
vals = torch._foreach_add(vals, 0)
8162+
items = dict(zip(keys, vals))
8163+
result = self._fast_apply(
8164+
lambda name, val: items.pop(name, None),
8165+
named=True,
8166+
nested_keys=True,
8167+
is_leaf=_NESTED_TENSORS_AS_LISTS,
8168+
propagate_lock=True,
8169+
filter_empty=True,
8170+
default=None,
8171+
)
8172+
if items:
8173+
result.update(items)
8174+
return result
8175+
81598176
def add(
81608177
self,
81618178
other: TensorDictBase | torch.Tensor,

tensordict/nn/cudagraphs.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302302
torch._foreach_copy_(dests, srcs)
303303
torch.cuda.synchronize()
304304
self.graph.replay()
305-
if self._return_unchanged == "clone":
306-
result = self._out.clone()
307-
elif self._return_unchanged:
305+
if self._return_unchanged:
308306
result = self._out
309307
else:
310-
result = tree_map(
311-
lambda x: x.detach().clone() if x is not None else x,
312-
self._out,
308+
result = tree_unflatten(
309+
torch._foreach_add(self._out, 0.0), self._out_struct
313310
)
314311
return result
315312

@@ -340,7 +337,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340337
self.graph = torch.cuda.CUDAGraph()
341338
with torch.cuda.graph(self.graph):
342339
out = self.module(*self._args, **self._kwargs)
343-
self._out = out
340+
self._out, self._out_struct = tree_flatten(out)
344341
self.counter += 1
345342
# Check that there is no intersection between the identity of inputs and outputs, otherwise warn
346343
# user.
@@ -356,10 +353,8 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356353
f"and the identity between input and output will not match anymore. "
357354
f"Make sure you don't rely on input-output identity further in the code."
358355
)
359-
if isinstance(self._out, torch.Tensor) or self._out is None:
360-
self._return_unchanged = (
361-
"clone" if self._out is not None else True
362-
)
356+
if not self._out:
357+
self._return_unchanged = True
363358
else:
364359
self._return_unchanged = False
365360
return this_out

0 commit comments

Comments
 (0)