Skip to content

Commit cb104a1

Browse files
authored
[BugFix] Fix "none"/"None" env vars (#1372)
1 parent 0bb94c0 commit cb104a1

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tensordict/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,7 +2030,10 @@ def capture_non_tensor_stack(allow_none=False):
20302030
return None
20312031
elif _CAPTURE_NONTENSOR_STACK is None:
20322032
return _DEFAULT_CAPTURE_NONTENSOR_STACK
2033-
elif _CAPTURE_NONTENSOR_STACK == "none":
2033+
elif (
2034+
isinstance(_CAPTURE_NONTENSOR_STACK, str)
2035+
and _CAPTURE_NONTENSOR_STACK.lower() == "none"
2036+
):
20342037
return _DEFAULT_CAPTURE_NONTENSOR_STACK
20352038
return (
20362039
strtobool(_CAPTURE_NONTENSOR_STACK)
@@ -2126,7 +2129,7 @@ def list_to_stack(allow_none=False):
21262129
return None
21272130
elif _LIST_TO_STACK is None:
21282131
return _DEFAULT_LIST_TO_STACK
2129-
elif _LIST_TO_STACK == "none":
2132+
elif isinstance(_LIST_TO_STACK, str) and _LIST_TO_STACK.lower() == "none":
21302133
return _DEFAULT_LIST_TO_STACK
21312134
return (
21322135
strtobool(_LIST_TO_STACK) if isinstance(_LIST_TO_STACK, str) else _LIST_TO_STACK
@@ -2860,7 +2863,7 @@ def parse_tensor_dict_string(s: str):
28602863
device_matches = re.findall(device_pattern, s)
28612864
if device_matches:
28622865
device = device_matches[-1] # Take the last match
2863-
if device == "None":
2866+
if isinstance(device, str) and device.lower() == "none":
28642867
device = None
28652868
else:
28662869
device = torch.device(device)

0 commit comments

Comments
 (0)