diff --git a/docs/source/overview.rst b/docs/source/overview.rst
index 7bbd5a745..833c38df3 100644
--- a/docs/source/overview.rst
+++ b/docs/source/overview.rst
@@ -466,3 +466,30 @@ output keys by querying the in_keys and out_keys attributes. It is also possible
 :class:`~tensordict.nn.TensorDictSequential` with only the modules that are indispensable to satisfy those
 requirements. The :class:`~tensordict.nn.TensorDictModule` is also compatible with :func:`~torch.vmap` and other
 ``torch.func`` capabilities.
+
+TorchScript fork/wait interop (TorchBind)
+-----------------------------------------
+
+TorchScript futures can only carry JIT-visible types. To pass a TensorDict-like payload through
+``torch.jit.fork``/``torch.jit.wait``, a lightweight TorchBind class is provided. Use a scripted factory that
+returns the TorchBind class, then wrap the result back into a :class:`~tensordict.TensorDict` in Python:
+
+.. code-block:: python
+
+    from typing import List
+
+    import torch
+    from tensordict import TensorDict
+
+    @torch.jit.script
+    def make_td_scripted(x: torch.Tensor) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
+        keys = ["x", "y"]
+        vals = [x, x + 1]
+        # A bare [] defaults to List[Tensor] in TorchScript, so annotate the empty batch size
+        batch_size: List[int] = []
+        return torch.classes.tensordict.TensorDict.from_pairs(keys, vals, batch_size, x.device)  # type: ignore[attr-defined]
+
+    f1 = torch.jit.fork(make_td_scripted, torch.tensor(1))
+    f2 = torch.jit.fork(make_td_scripted, torch.tensor(2))
+    obj1 = torch.jit.wait(f1)
+    obj2 = torch.jit.wait(f2)
+    td1 = TensorDict.from_torchbind(obj1)
+    td2 = TensorDict.from_torchbind(obj2)
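+
+The conversion also runs the other way: ``to_torchbind()`` can be used as a context manager that yields the
+TorchBind object and writes any changes back into the original :class:`~tensordict.TensorDict` on exit. A
+minimal sketch (``td`` must have a non-None device):
+
+.. code-block:: python
+
+    td = TensorDict({"x": torch.tensor(1.0)}, batch_size=[], device="cpu")
+    with td.to_torchbind() as tb:
+        tb.set("y", tb.get("x") + 1)
+    assert td["y"].item() == 2.0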
+ """ + if out is None: + return self + # Convert the TorchBind object back to TensorDict + from tensordict._td import TensorDict + + td_from_tb = TensorDict.from_torchbind(out) + # Update self with the values from the TorchBind object + return self.update(td_from_tb, inplace=False) + + +LAST_OP_MAPS["to_torchbind"] = _reverse_to_torchbind diff --git a/tensordict/_jit_ops.py b/tensordict/_jit_ops.py new file mode 100644 index 000000000..329018f6c --- /dev/null +++ b/tensordict/_jit_ops.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""TorchScript-compatible ops for TensorDict operations.""" + +import torch + +# Ensure C++ extension is loaded +try: + import tensordict._C # noqa: F401 +except ImportError: + pass + + +def _get_tb_class(): + """Get the TorchBind TensorDict class if available.""" + try: + clsns = getattr(torch, "classes", None) + tdns = getattr(clsns, "tensordict", None) if clsns is not None else None + return getattr(tdns, "TensorDict", None) if tdns is not None else None + except Exception: + return None + + +_TB_TensorDict = _get_tb_class() + +# Script functions that work on TorchBind class +if _TB_TensorDict is not None: + + @torch.jit.script + def _td_get_scripted(td: torch.classes.tensordict.TensorDict, key: str) -> torch.Tensor: # type: ignore[attr-defined] + """Scripted version of TensorDict.get().""" + return torch.ops.tensordict.get(td, key) + + @torch.jit.script + def _td_set_scripted(td: torch.classes.tensordict.TensorDict, key: str, value: torch.Tensor) -> torch.classes.tensordict.TensorDict: # type: ignore[attr-defined] + """Scripted version of TensorDict.set().""" + return torch.ops.tensordict.set(td, key, value) + + @torch.jit.script + def _td_has_scripted(td: torch.classes.tensordict.TensorDict, key: str) -> bool: # type: ignore[attr-defined] + """Scripted version of TensorDict.has().""" + return torch.ops.tensordict.has(td, key) + + @torch.jit.script + def _td_keys_scripted(td: torch.classes.tensordict.TensorDict) -> list[str]: # type: ignore[attr-defined] + """Scripted version of TensorDict.keys().""" + return torch.ops.tensordict.keys(td) + + @torch.jit.script + def _td_device_scripted(td: torch.classes.tensordict.TensorDict) -> torch.device: # type: ignore[attr-defined] + """Scripted version of TensorDict.device.""" + return torch.ops.tensordict.device(td) + + @torch.jit.script + def _td_to_scripted(td: torch.classes.tensordict.TensorDict, device: torch.device) -> torch.classes.tensordict.TensorDict: # type: ignore[attr-defined] + """Scripted version of TensorDict.to().""" + return torch.ops.tensordict.to(td, device) + + @torch.jit.script + def _td_clone_scripted(td: torch.classes.tensordict.TensorDict) -> torch.classes.tensordict.TensorDict: # type: ignore[attr-defined] + """Scripted version of TensorDict.clone().""" + return torch.ops.tensordict.clone(td) + diff --git a/tensordict/csrc/CMakeLists.txt b/tensordict/csrc/CMakeLists.txt index 85b5519d6..c2ee7dcaa 100644 --- a/tensordict/csrc/CMakeLists.txt +++ b/tensordict/csrc/CMakeLists.txt @@ -12,6 +12,7 @@ endif() find_package(Python3 REQUIRED COMPONENTS Interpreter Development) find_package(pybind11 2.13 REQUIRED) +find_package(Torch REQUIRED) file(GLOB SOURCES "*.cpp") @@ -39,10 +40,10 @@ endif() find_package(Python COMPONENTS Development.Static) target_link_libraries(_C ${Python_STATIC_LIBRARIES}) -target_include_directories(_C PRIVATE 
diff --git a/tensordict/csrc/CMakeLists.txt b/tensordict/csrc/CMakeLists.txt
index 85b5519d6..c2ee7dcaa 100644
--- a/tensordict/csrc/CMakeLists.txt
+++ b/tensordict/csrc/CMakeLists.txt
@@ -12,6 +12,7 @@ endif()

 find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
 find_package(pybind11 2.13 REQUIRED)
+find_package(Torch REQUIRED)

 file(GLOB SOURCES "*.cpp")
@@ -39,10 +40,10 @@ endif()
 find_package(Python COMPONENTS Development.Static)
 target_link_libraries(_C ${Python_STATIC_LIBRARIES})

-target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
+target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR} ${TORCH_INCLUDE_DIRS})

 #if(APPLE OR WIN32) # Check if the target OS is OSX/macOS
-target_link_libraries(_C PRIVATE pybind11::module)
+target_link_libraries(_C PRIVATE pybind11::module ${TORCH_LIBRARIES})
 #else()
 #    target_link_libraries(_C PRIVATE Python3::Python pybind11::module)
 #endif()
diff --git a/tensordict/csrc/tensordict_bind.cpp b/tensordict/csrc/tensordict_bind.cpp
new file mode 100644
index 000000000..8e9f340f7
--- /dev/null
+++ b/tensordict/csrc/tensordict_bind.cpp
@@ -0,0 +1,135 @@
+#include <torch/custom_class.h>
+#include <torch/library.h>
+#include <ATen/ATen.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tensordict_bind_internal {
+
+using Tensor = at::Tensor;
+
+// Flat, single-device TensorDict TorchBind class
+struct TensorDictBind : torch::CustomClassHolder {
+  explicit TensorDictBind(const c10::Device& device) : device_(device) {}
+
+  // Factory from key/value pairs with explicit batch_size and device
+  static c10::intrusive_ptr<TensorDictBind> from_pairs(
+      const std::vector<std::string>& keys,
+      const std::vector<Tensor>& values,
+      const std::vector<int64_t>& batch_size,
+      const c10::Device& device) {
+    TORCH_CHECK(
+        keys.size() == values.size(),
+        "keys and values must have the same length");
+    auto obj = c10::make_intrusive<TensorDictBind>(device);
+    obj->batch_size_ = batch_size;
+    for (size_t i = 0; i < keys.size(); ++i) {
+      obj->validate_tensor(values[i]);
+      obj->validate_batch_size_prefix(values[i].sizes());
+      obj->data_.emplace(keys[i], values[i]);
+    }
+    return obj;
+  }
+
+  bool has(const std::string& key) const {
+    return data_.find(key) != data_.end();
+  }
+
+  Tensor get(const std::string& key) const {
+    auto it = data_.find(key);
+    TORCH_CHECK(it != data_.end(), "Key not found: ", key);
+    return it->second;
+  }
+
+  void set(const std::string& key, const Tensor& value) {
+    validate_tensor(value);
+    validate_batch_size_prefix(value.sizes());
+    data_[key] = value;
+  }
+
+  std::vector<std::string> keys() const {
+    std::vector<std::string> out;
+    out.reserve(data_.size());
+    for (const auto& kv : data_) {
+      out.push_back(kv.first);
+    }
+    return out;
+  }
+
+  std::vector<int64_t> batch_size() const { return batch_size_; }
+
+  c10::Device device() const { return device_; }
+
+  c10::intrusive_ptr<TensorDictBind> to(const c10::Device& new_device) const {
+    auto obj = c10::make_intrusive<TensorDictBind>(new_device);
+    obj->batch_size_ = batch_size_;
+    for (const auto& kv : data_) {
+      obj->data_.emplace(kv.first, kv.second.to(new_device));
+    }
+    return obj;
+  }
+
+  c10::intrusive_ptr<TensorDictBind> clone() const {
+    auto obj = c10::make_intrusive<TensorDictBind>(device_);
+    obj->batch_size_ = batch_size_;
+    for (const auto& kv : data_) {
+      obj->data_.emplace(kv.first, kv.second.clone());
+    }
+    return obj;
+  }
+
+ private:
+  void validate_tensor(const Tensor& t) const {
+    TORCH_CHECK(
+        t.defined(),
+        "TensorDictBind: cannot store an undefined tensor");
+    TORCH_CHECK(
+        t.device() == device_,
+        "All tensors must be on device ", device_.str(),
+        ", but got ", t.device().str());
+  }
+
+  void validate_batch_size_prefix(at::IntArrayRef sizes) const {
+    // If no batch_size is specified, accept any shape.
+    if (batch_size_.empty()) {
+      return;
+    }
+    TORCH_CHECK(
+        sizes.size() >= batch_size_.size(),
+        "Tensor has fewer dims (", sizes.size(), ") than batch_size prefix (",
+        batch_size_.size(), ")");
+    for (size_t i = 0; i < batch_size_.size(); ++i) {
+      TORCH_CHECK(
+          sizes[i] == batch_size_[i],
+          "Tensor batch dim ", i, " mismatch: expected ", batch_size_[i],
+          ", got ", sizes[i]);
+    }
+  }
+
+  std::unordered_map<std::string, Tensor> data_;
+  std::vector<int64_t> batch_size_;
+  c10::Device device_;
+};
+
+} // namespace tensordict_bind_internal
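+
+// Registration: this exposes the class to TorchScript and Python as
+// torch.classes.tensordict.TensorDict. def_static makes from_pairs callable
+// as a factory from scripted code; the remaining defs bind instance methods.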
+TORCH_LIBRARY(tensordict, m) {
+  using tensordict_bind_internal::TensorDictBind;
+  m.class_<TensorDictBind>("TensorDict")
+      .def(torch::init<c10::Device>())
+      .def_static("from_pairs", &TensorDictBind::from_pairs)
+      .def("has", &TensorDictBind::has)
+      .def("get", &TensorDictBind::get)
+      .def("set", &TensorDictBind::set)
+      .def("keys", &TensorDictBind::keys)
+      .def("batch_size", &TensorDictBind::batch_size)
+      .def("device", &TensorDictBind::device)
+      .def("to", &TensorDictBind::to)
+      .def("clone", &TensorDictBind::clone);
+}
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 4b043f416..bab3bfd0e 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -34,3 +34,69 @@
     lock_blocked,
     NestedKey,
 )
+
+# TorchBind interop helpers (flat, single-device)
+try:
+    import weakref
+
+    import torch  # local import to avoid issues if torch is unavailable at parse time
+
+    _clsns = getattr(torch, "classes", None)
+    _tdns = getattr(_clsns, "tensordict", None) if _clsns is not None else None
+    _TB_TensorDict = getattr(_tdns, "TensorDict", None) if _tdns is not None else None
+    if _TB_TensorDict is not None:
+
+        class _TorchBindContextManager:
+            """Wrapper around the TorchBind object returned by to_torchbind().
+
+            Attribute access is delegated to the wrapped TorchBind object, so
+            the wrapper can be used directly. When used as a context manager,
+            the original TensorDict is updated from the TorchBind object on exit.
+            """
+
+            def __init__(self, td, tb):
+                self._td_ref = weakref.ref(td)
+                self._tb = tb
+
+            def __getattr__(self, name):
+                # Delegate everything else (get, set, keys, device, ...) to the
+                # underlying TorchBind object
+                return getattr(self._tb, name)
+
+            def __enter__(self):
+                # Hand out the raw TorchBind object: scripted functions require
+                # the actual ScriptObject, not this Python wrapper
+                return self._tb
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                if exc_type is not None:
+                    return False
+                # Convert the TorchBind object back and update the original
+                td = self._td_ref()
+                if td is not None:
+                    from tensordict._td import TensorDict as _TD_class
+
+                    td_from_tb = _TD_class.from_torchbind(self._tb)
+                    td.update(td_from_tb, inplace=False)
+                return False
+
+        def to_torchbind(td):  # type: ignore[no-redef]
+            """Convert a TensorDict to the TorchBind format.
+
+            Can be used as a context manager to automatically convert back:
+
+                with td.to_torchbind() as tb:
+                    ...  # use tb (the raw TorchBind object)
+                # td is automatically updated from tb on exit
+            """
+            if td.device is None:
+                raise RuntimeError(
+                    "TensorDict.to_torchbind requires a non-None device; call td = td.to(device) first"
+                )
+            keys = list(td.keys())
+            values = [td.get(k) for k in keys]
+            bs = list(td.batch_size)
+            tb = _TB_TensorDict.from_pairs(keys, values, bs, td.device)
+            # Wrap so the result works both directly and as a context manager
+            return _TorchBindContextManager(td, tb)
+
+        def _from_torchbind(cls, obj):  # type: ignore[no-redef]
+            keys = list(obj.keys())
+            d = {k: obj.get(k) for k in keys}
+            return cls(d, batch_size=tuple(obj.batch_size()), device=obj.device())
+
+        # Expose on the class for convenience
+        from tensordict._td import TensorDict as _TD  # local import to avoid cycles
+
+        _TD.to_torchbind = to_torchbind  # type: ignore[attr-defined]
+        _TD.from_torchbind = classmethod(_from_torchbind)  # type: ignore[attr-defined]
+except Exception:
+    # If the extension is not built/available, leave the helpers undefined
+    pass
diff --git a/test/test_tensordict_torchbind.py b/test/test_tensordict_torchbind.py
new file mode 100644
index 000000000..aa7f0e51c
--- /dev/null
+++ b/test/test_tensordict_torchbind.py
@@ -0,0 +1,265 @@
+from typing import List
+
+import pytest
+import torch
+
+from tensordict import TensorDict
+
+# Ensure the C++ extension is loaded so the TorchBind registration runs
+try:
+    import tensordict._C  # noqa: F401
+except ImportError:
+    pass
+
+
+def _tb_class_available():
+    try:
+        return torch.classes.tensordict.TensorDict is not None
+    except Exception:
+        return False
+
+
+pytestmark = pytest.mark.skipif(
+    not _tb_class_available(), reason="TorchBind tensordict class not available"
+)
+
+
+@torch.jit.script
+def _make_td_scripted(x: torch.Tensor) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
+    keys = ["x", "y"]
+    vals = [x, x + 1]
+    # A bare [] defaults to List[Tensor] in TorchScript, so annotate the empty batch size
+    batch_size: List[int] = []
+    return torch.classes.tensordict.TensorDict.from_pairs(keys, vals, batch_size, x.device)  # type: ignore[attr-defined]
+
+
+def test_scripted_factory_roundtrip():
+    obj = _make_td_scripted(torch.tensor(3))
+    # Wrap back into a Python TensorDict
+    td = TensorDict.from_torchbind(obj)
+    assert set(td.keys()) == {"x", "y"}
+    assert td.get("x").item() == 3
+    assert td.get("y").item() == 4
+
+
+def test_fork_wait_roundtrip():
+    f1 = torch.jit.fork(_make_td_scripted, torch.tensor(1))
+    f2 = torch.jit.fork(_make_td_scripted, torch.tensor(2))
+    obj1 = torch.jit.wait(f1)
+    obj2 = torch.jit.wait(f2)
+
+    td1 = TensorDict.from_torchbind(obj1)
+    td2 = TensorDict.from_torchbind(obj2)
+    assert td1.get("x").item() == 1 and td1.get("y").item() == 2
+    assert td2.get("x").item() == 2 and td2.get("y").item() == 3
+
+
+def test_original_example_tensordict():
+    """Test the motivating example: build TensorDicts under fork/wait."""
+
+    @torch.jit.script
+    def make_td(x: torch.Tensor) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
+        keys = ["x", "y"]
+        vals = [x, x + 1]
+        batch_size: List[int] = []
+        return torch.classes.tensordict.TensorDict.from_pairs(keys, vals, batch_size, x.device)  # type: ignore[attr-defined]
+
+    def parallel():
+        fut1 = torch.jit.fork(make_td, torch.tensor(1))
+        fut2 = torch.jit.fork(make_td, torch.tensor(2))
+        obj1 = torch.jit.wait(fut1)
+        obj2 = torch.jit.wait(fut2)
+        # Convert back to Python TensorDicts
+        td1 = TensorDict.from_torchbind(obj1)
+        td2 = TensorDict.from_torchbind(obj2)
+        return td1, td2
+
+    td1, td2 = parallel()
+    assert td1.get("x").item() == 1
+    assert td1.get("y").item() == 2
+    assert td2.get("x").item() == 2
+    assert td2.get("y").item() == 3
+
+
+def test_torchbind_methods_in_scripted():
+    """TorchBind methods work inside scripted functions."""
+
+    @torch.jit.script
+    def process_td(td: torch.classes.tensordict.TensorDict) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
+        x = td.get("x")
+        td.set("y", x + 10)
+        return td
+
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("x", torch.tensor(5))
+    result = process_td(tb)
+    assert result.get("y").item() == 15
+
+
+def test_to_torchbind_requires_device():
+    td = TensorDict({"x": torch.tensor(1), "y": torch.tensor(2)}, batch_size=[])
+    # The device is None by default when not set at construction
+    assert td.device is None
+    with pytest.raises(RuntimeError, match="requires a non-None device"):
+        td.to_torchbind()
+
+
+def test_to_torchbind_with_device():
+    """Conversion to TorchBind succeeds when a device is set."""
+    td = TensorDict({"x": torch.tensor(1), "y": torch.tensor(2)}, batch_size=[], device="cpu")
+    tb = td.to_torchbind()
+    # Attribute access is delegated to the wrapped TorchBind object
+    assert tb.get("x").item() == 1
+    assert tb.get("y").item() == 2
+    assert tb.device() == torch.device("cpu")
+
+
+def test_batch_size_validation():
+    # An empty batch_size should accept 0-d tensors
+    obj = _make_td_scripted(torch.tensor(0))
+    td = TensorDict.from_torchbind(obj)
+    assert tuple(td.batch_size) == ()
+
+
+def test_torchbind_keys():
+    """keys() on the TorchBind class."""
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("a", torch.ones(1))
+    tb.set("b", torch.ones(2))
+    assert set(tb.keys()) == {"a", "b"}
+
+
+def test_torchbind_clone():
+    """clone() on the TorchBind class."""
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("x", torch.ones(1))
+    tb_clone = tb.clone()
+    assert tb_clone.get("x").item() == 1
+    # Modifying the clone must not affect the original
+    tb_clone.set("x", torch.tensor(2.0))
+    assert tb.get("x").item() == 1
+    assert tb_clone.get("x").item() == 2
+
+
+def test_torchbind_to_device():
+    """to() returns a new TorchBind object on the target device."""
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("x", torch.ones(1))
+    tb2 = tb.to(torch.device("cpu"))
+    assert tb.device() == torch.device("cpu")
+    assert tb2.device() == torch.device("cpu")
+    assert tb2.get("x").item() == 1
+    # A cross-device check is covered by the CUDA test below
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_device_validation_cuda():
+    # Create a CPU-bound instance, then try to set a CUDA tensor -> should error
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("x", torch.ones(1))  # OK
+    cuda_t = torch.ones(1, device="cuda")
+    with pytest.raises(RuntimeError, match="All tensors must be on device"):
+        tb.set("y", cuda_t)
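+
+
+def test_from_pairs_batch_size_mismatch():
+    """Extra coverage sketch (not in the original test set): from_pairs should
+    reject tensors whose leading dims do not match the declared batch_size."""
+    with pytest.raises(RuntimeError, match="batch dim"):
+        torch.classes.tensordict.TensorDict.from_pairs(  # type: ignore[attr-defined]
+            ["x"], [torch.ones(2, 3)], [4], torch.device("cpu")
+        )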
+
+
+def test_jit_script_end_to_end():
+    """End-to-end test demonstrating JIT scripting with TensorDict.
+
+    The complete workflow:
+    1. Create a Python TensorDict
+    2. Define a scripted function over the TorchBind class
+    3. Convert to TorchBind, run the scripted function, convert back
+    4. Verify the results
+    """
+    # Step 1: Create a Python TensorDict with some data
+    td = TensorDict(
+        {
+            "input": torch.tensor([1.0, 2.0, 3.0]),
+            "weight": torch.tensor([0.5, 1.5, 2.5]),
+        },
+        batch_size=[],
+        device="cpu",
+    )
+
+    # Step 2: Define a scripted function that processes the TensorDict
+    @torch.jit.script
+    def process_tensordict(
+        td: torch.classes.tensordict.TensorDict,  # type: ignore[attr-defined]
+        scale: float,
+    ) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
+        """Multiply input by weight, add scale, store the result."""
+        input_val = td.get("input")
+        weight_val = td.get("weight")
+        result = input_val * weight_val + scale
+        td.set("output", result)
+        return td
+
+    # Step 3: Convert to TorchBind and run the scripted function. Scripted
+    # functions need the raw ScriptObject, which the context manager yields.
+    with td.to_torchbind() as tb:
+        assert tb.device() == torch.device("cpu")
+        assert set(tb.keys()) == {"input", "weight"}
+        result_tb = process_tensordict(tb, 10.0)
+
+    # Step 4: Convert back to a Python TensorDict and verify
+    result_td = TensorDict.from_torchbind(result_tb)
+    assert set(result_td.keys()) == {"input", "weight", "output"}
+    assert torch.allclose(result_td.get("input"), torch.tensor([1.0, 2.0, 3.0]))
+    assert torch.allclose(result_td.get("weight"), torch.tensor([0.5, 1.5, 2.5]))
+    # output = input * weight + scale = [0.5, 3.0, 7.5] + 10 = [10.5, 13.0, 17.5]
+    expected_output = torch.tensor([10.5, 13.0, 17.5])
+    assert torch.allclose(result_td.get("output"), expected_output)
+
+
+def test_to_torchbind_context_manager():
+    """to_torchbind() as a context manager converts back automatically."""
+    td = TensorDict(
+        {"x": torch.tensor(1.0), "y": torch.tensor(2.0)}, batch_size=[], device="cpu"
+    )
+
+    # A scripted function that modifies the TorchBind TensorDict in place
+    @torch.jit.script
+    def modify_tb(tb: torch.classes.tensordict.TensorDict) -> None:  # type: ignore[attr-defined]
+        x = tb.get("x")
+        tb.set("x", x + 10)
+        tb.set("z", torch.tensor(30.0))
+
+    with td.to_torchbind() as tb:
+        # Inside the context we hold the raw TorchBind object
+        assert tb.device() == torch.device("cpu")
+        modify_tb(tb)
+        # The modifications are visible on the TorchBind object
+        assert tb.get("x").item() == 11.0
+        assert tb.get("y").item() == 2.0
+        assert tb.get("z").item() == 30.0
+
+    # After exiting the context, td is updated from the TorchBind object
+    assert isinstance(td, TensorDict)
+    assert td.get("x").item() == 11.0
+    assert td.get("y").item() == 2.0
+    assert td.get("z").item() == 30.0
+    # Python TensorDict methods still work
+    assert set(td.keys()) == {"x", "y", "z"}
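+
+
+def test_scripted_helper_get():
+    """Extra coverage sketch (not in the original test set): exercise the
+    scripted wrappers from tensordict._jit_ops when they were compiled."""
+    from tensordict import _jit_ops
+
+    if getattr(_jit_ops, "_td_get_scripted", None) is None:
+        pytest.skip("scripted helpers unavailable")
+    tb = torch.classes.tensordict.TensorDict(torch.device("cpu"))  # type: ignore[attr-defined]
+    tb.set("x", torch.tensor(7))
+    assert _jit_ops._td_get_scripted(tb, "x").item() == 7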