Skip to content

Commit afcbcec

Browse files
authored
[BugFix] make device arg in TensorDict constructor respect cuda current device (#1369)
1 parent 5e67cb3 commit afcbcec

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tensordict/_td.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,13 @@ def __init__(
268268
sub_non_blocking = non_blocking
269269
device = torch.device(device)
270270
# Auto-index the device
271-
if device.type not in ("cpu", "meta") and device.index is None:
272-
device = torch.device(device.type, index=0)
271+
if device.index is None:
272+
if device.type == "cuda":
273+
device = torch.device(
274+
device.type, index=torch.cuda.current_device()
275+
)
276+
elif device.type not in ("cpu", "meta"):
277+
device = torch.device(device.type, index=0)
273278
if device.type == "cuda":
274279
# CUDA does its sync by itself
275280
call_sync = False

0 commit comments

Comments
 (0)