Skip to content

Commit 910c953

Browse files
author
Vincent Moens
committed
[BugFix] Fix .item() warning on tensors that require grad
ghstack-source-id: 3857ef3 Pull Request resolved: #1283
1 parent 6ad496b commit 910c953

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tensordict/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,22 +1440,27 @@ def istensor(cls):
14401440
raise RuntimeError(
14411441
f"Failed to compare key {prefix + (key,)}. Scroll up for more details."
14421442
) from err
1443-
mse = mse.div(input1.numel()).sqrt().item()
1443+
mse = mse.data.div(input1.numel()).sqrt().item()
14441444

14451445
local_msg = f"key {prefix + (key,)} does not match, got mse = {mse:4.4f}"
14461446
new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
14471447
if input1.is_nested:
14481448
torch.testing.assert_close(
1449-
input1v,
1450-
input2v,
1449+
input1v.data,
1450+
input2v.data,
14511451
rtol=rtol,
14521452
atol=atol,
14531453
equal_nan=equal_nan,
14541454
msg=new_msg,
14551455
)
14561456
else:
14571457
torch.testing.assert_close(
1458-
input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
1458+
input1.data,
1459+
input2.data,
1460+
rtol=rtol,
1461+
atol=atol,
1462+
equal_nan=equal_nan,
1463+
msg=new_msg,
14591464
)
14601465
local_msg = f"key {prefix + (key,)} matches"
14611466
msg = "\t".join([local_msg, msg]) if len(msg) else local_msg

0 commit comments

Comments
 (0)