Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,15 +929,11 @@ def _grad(
)

if grad_outputs is not None:
tup_grad_outputs = tuple(
grad_outputs._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
)
tup_grad_outputs = tuple(grad_outputs[k] for k in outputs.keys(True, True))
else:
tup_grad_outputs = None

tup_outputs = tuple(
outputs._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
)
tup_outputs = tuple(outputs[k] for k in outputs.keys(True, True))

keys, all_inputs = inputs._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)

Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10608,6 +10608,13 @@ def test_update_batch_size(self, source_is_lazy):
assert td.batch_size == (2, 4)
assert td.batch_size == td2.batch_size

def test_autograd_grad_mixed_types(self):
    """Check that torch.autograd.grad accepts a LazyStackedTensorDict as
    grad_outputs when outputs/inputs are plain TensorDicts.

    The gradient of ``inputs + 1`` w.r.t. ``inputs`` under all-ones
    grad_outputs must be all ones.
    """
    leaf = torch.randn(2, 3, requires_grad=True)
    inputs = TensorDict(a=leaf)
    outputs = inputs + 1
    # Stack two (3,)-shaped entries along dim 0 to match the (2, 3) output.
    grad_outputs = LazyStackedTensorDict(
        TensorDict(a=torch.ones(3)),
        TensorDict(a=torch.ones(3)),
        stack_dim=0,
    )
    grads = torch.autograd.grad(outputs, inputs, grad_outputs)
    assert (grads == 1).all()


@pytest.mark.skipif(
not _has_torchsnapshot, reason=f"torchsnapshot not found: err={TORCHSNAPSHOT_ERR}"
Expand Down
Loading