Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions .github/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,28 @@ eval "$(${conda_dir}/bin/conda shell.bash hook)"
printf "python: ${PYTHON_VERSION}\n"
if [ ! -d "${env_dir}" ]; then
printf "* Creating a test environment\n"
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
if [ "${PYTHON_VERSION}" == "3.14t" ]; then
# Install free-threaded Python 3.14 from conda-forge
conda create --prefix "${env_dir}" -y -c conda-forge python-freethreading
# Set PYTHON_GIL=0 to keep GIL disabled
export PYTHON_GIL=0
else
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
fi
fi
conda activate "${env_dir}"

# For free-threaded Python, ensure PYTHON_GIL=0 is set
if [ "${PYTHON_VERSION}" == "3.14t" ]; then
export PYTHON_GIL=0
fi

# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
# Don't add python version constraint for free-threaded builds
if [ "${PYTHON_VERSION}" != "3.14t" ]; then
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
fi
cat "${this_dir}/environment.yml"

pip install pip --upgrade
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
test-cpu:
strategy:
matrix:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
Expand Down
2 changes: 0 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,8 +2828,6 @@ def _memmap_(
if inplace:
self._is_memmap = True
self._is_shared = False # since they are mutually exclusive
if self._validate_value_cached is not None:
delattr(self, "_validate_value_cached")
self._device = torch.device("cpu")
else:
dest._is_memmap = True
Expand Down
32 changes: 10 additions & 22 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,6 @@ def __getstate__(self) -> dict[str, Any]:
"_last_op",
"_cache",
"__lock_parents_weakrefs",
"_validate_value_cached",
):
result.pop(key, None)
return result
Expand Down Expand Up @@ -3491,8 +3490,6 @@ def dtype(self):
return self._dtype()

def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
if self._validate_value_cached is not None:
delattr(self, "_validate_value_cached")
if new_batch_size == self.batch_size:
return
if self._lazy:
Expand Down Expand Up @@ -5240,17 +5237,13 @@ def clear_device_(self) -> Self:

"""
self._device = None
if self._validate_value_cached is not None:
delattr(self, "_validate_value_cached")
for value in self.values():
if _is_tensor_collection(type(value)):
value.clear_device_()
return self

def _set_device(self, device: torch.device) -> Self:
self._device = device
if self._validate_value_cached is not None:
delattr(self, "_validate_value_cached")
for value in self.values():
if _is_tensor_collection(type(value)):
value._set_device(device=device)
Expand Down Expand Up @@ -12399,26 +12392,21 @@ def _validate_key(self, key: NestedKey) -> NestedKey:
raise KeyError(_GENERIC_NESTED_ERR.format(key))
return key

_validate_value_cached: str | None = None

@property
def _validate_value(self):
if is_compiling():
return self._validate_value_generic
_validate_value_cached = self._validate_value_cached
if _validate_value_cached is None:
if self.device:
if self.batch_size:
_validate_value_cached = "_validate_value_generic"
else:
_validate_value_cached = "_validate_value_batchfree"
if self.device:
if self.batch_size:
method_name = "_validate_value_generic"
else:
if self.batch_size:
_validate_value_cached = "_validate_value_devicefree"
else:
_validate_value_cached = "_validate_value_batchfree_devicefree"
self._validate_value_cached = _validate_value_cached
return getattr(self, _validate_value_cached)
method_name = "_validate_value_batchfree"
else:
if self.batch_size:
method_name = "_validate_value_devicefree"
else:
method_name = "_validate_value_batchfree_devicefree"
return getattr(self, method_name)

def _validate_value_generic(
self,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace py = pybind11;

PYBIND11_MODULE(_C, m) {
PYBIND11_MODULE(_C, m, py::mod_gil_not_used()) {
m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat
m.def("unravel_key", &unravel_key, py::arg("key"));
m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
Expand Down