From 108ba8cd70db32f266a577543ce71de418ab0e09 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Tue, 25 Nov 2025 13:25:55 +0000 Subject: [PATCH 1/2] Add free threading declaration --- .github/unittest/linux/scripts/setup_env.sh | 19 +++++++++++++++++-- .github/workflows/test-linux.yml | 2 +- tensordict/csrc/pybind.cpp | 2 +- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux/scripts/setup_env.sh b/.github/unittest/linux/scripts/setup_env.sh index ec85563a2..5ed9e630c 100755 --- a/.github/unittest/linux/scripts/setup_env.sh +++ b/.github/unittest/linux/scripts/setup_env.sh @@ -40,13 +40,28 @@ eval "$(${conda_dir}/bin/conda shell.bash hook)" printf "python: ${PYTHON_VERSION}\n" if [ ! -d "${env_dir}" ]; then printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" + if [ "${PYTHON_VERSION}" == "3.14t" ]; then + # Install free-threaded Python 3.14 from conda-forge + conda create --prefix "${env_dir}" -y -c conda-forge python-freethreading + # Set PYTHON_GIL=0 to keep GIL disabled + export PYTHON_GIL=0 + else + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" + fi fi conda activate "${env_dir}" +# For free-threaded Python, ensure PYTHON_GIL=0 is set +if [ "${PYTHON_VERSION}" == "3.14t" ]; then + export PYTHON_GIL=0 +fi + # 3. Install Conda dependencies printf "* Installing dependencies (except PyTorch)\n" -echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +# Don't add python version constraint for free-threaded builds +if [ "${PYTHON_VERSION}" != "3.14t" ]; then + echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +fi cat "${this_dir}/environment.yml" pip install pip --upgrade diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 62e591e8f..774ff610e 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -58,7 +58,7 @@ jobs: test-cpu: strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: diff --git a/tensordict/csrc/pybind.cpp b/tensordict/csrc/pybind.cpp index b10238220..2dce2b6e1 100644 --- a/tensordict/csrc/pybind.cpp +++ b/tensordict/csrc/pybind.cpp @@ -13,7 +13,7 @@ namespace py = pybind11; -PYBIND11_MODULE(_C, m) { +PYBIND11_MODULE(_C, m, py::mod_gil_not_used()) { m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat m.def("unravel_key", &unravel_key, py::arg("key")); m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key")); From 3db9a7c56fc8eb3a56a6f04f7bbd66e3c3934dce Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Tue, 25 Nov 2025 15:56:27 +0000 Subject: [PATCH 2/2] Fix race condition --- tensordict/_td.py | 2 -- tensordict/base.py | 32 ++++++++++---------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 57529590d..67c60feb6 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2828,8 +2828,6 @@ def _memmap_( if inplace: self._is_memmap = True self._is_shared = False # since they are mutually exclusive - if self._validate_value_cached is not None: - delattr(self, "_validate_value_cached") self._device = torch.device("cpu") else: dest._is_memmap = True diff --git a/tensordict/base.py b/tensordict/base.py index bca715032..e9d414e5d 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -653,7 +653,6 @@ def __getstate__(self) -> dict[str, Any]: "_last_op", "_cache", "__lock_parents_weakrefs", - "_validate_value_cached", ): result.pop(key, None) return result @@ -3491,8 +3490,6 @@ def dtype(self): return self._dtype() def _batch_size_setter(self, new_batch_size: torch.Size) -> None: - if self._validate_value_cached is not None: - delattr(self, "_validate_value_cached") if new_batch_size == self.batch_size: return if self._lazy: @@ -5240,8 +5237,6 @@ def clear_device_(self) -> Self: """ self._device = None - if self._validate_value_cached is not None: - delattr(self, "_validate_value_cached") for value in self.values(): if _is_tensor_collection(type(value)): value.clear_device_() @@ -5249,8 +5244,6 @@ def clear_device_(self) -> Self: def _set_device(self, device: torch.device) -> Self: self._device = device - if self._validate_value_cached is not None: - delattr(self, "_validate_value_cached") for value in self.values(): if _is_tensor_collection(type(value)): value._set_device(device=device) @@ -12399,26 +12392,21 @@ def _validate_key(self, key: NestedKey) -> NestedKey: raise KeyError(_GENERIC_NESTED_ERR.format(key)) return key - _validate_value_cached: str | None = None - @property def _validate_value(self): if is_compiling(): return self._validate_value_generic - _validate_value_cached = self._validate_value_cached - if _validate_value_cached is None: - if self.device: - if self.batch_size: - _validate_value_cached = "_validate_value_generic" - else: - _validate_value_cached = "_validate_value_batchfree" + if self.device: + if self.batch_size: + method_name = "_validate_value_generic" else: - if self.batch_size: - _validate_value_cached = "_validate_value_devicefree" - else: - _validate_value_cached = "_validate_value_batchfree_devicefree" - self._validate_value_cached = _validate_value_cached - return getattr(self, _validate_value_cached) + method_name = "_validate_value_batchfree" + else: + if self.batch_size: + method_name = "_validate_value_devicefree" + else: + method_name = "_validate_value_batchfree_devicefree" + return getattr(self, method_name) def _validate_value_generic( self,