27 changes: 27 additions & 0 deletions docs/source/overview.rst
@@ -466,3 +466,30 @@ output keys by querying the in_keys and out_keys attributes. It is also possible
:class:`~tensordict.nn.TensorDictSequential` with only the modules that are indispensable to satisfy those requirements.
The :class:`~tensordict.nn.TensorDictModule` is also compatible with :func:`~torch.vmap` and other ``torch.func``
capabilities.

TorchScript fork/wait interop (TorchBind)
-------------------------------------------

TorchScript Futures can only carry JIT-visible types. To pass a TensorDict-like payload through
``torch.jit.fork``/``torch.jit.wait``, a lightweight TorchBind class is provided. Use a scripted factory that
returns the TorchBind object, then wrap it back into a :class:`~tensordict.TensorDict` in Python:

.. code-block:: python

    import torch
    from tensordict import TensorDict

    @torch.jit.script
    def make_td_scripted(x: torch.Tensor) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
        keys = ["x", "y"]
        vals = [x, x + 1]
        return torch.classes.tensordict.TensorDict.from_pairs(keys, vals, [], x.device)  # type: ignore[attr-defined]

    f1 = torch.jit.fork(make_td_scripted, torch.tensor(1))
    f2 = torch.jit.fork(make_td_scripted, torch.tensor(2))
    obj1 = torch.jit.wait(f1)
    obj2 = torch.jit.wait(f2)
    td1 = TensorDict.from_torchbind(obj1)
    td2 = TensorDict.from_torchbind(obj2)

Limitations in v1: flat structure (no nesting) and single-device TensorDicts (``td.device`` must be set).
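
Within these limits, the reverse direction also works: an existing :class:`~tensordict.TensorDict` can be handed
to a forked worker through its TorchBind form. A minimal sketch, assuming the Python-side
``to_torchbind``/``from_torchbind`` helpers introduced with this class and a TensorDict whose ``device`` is set
(changes made to the TorchBind object inside the ``with`` block are written back to the original on exit):

.. code-block:: python

    import torch
    from tensordict import TensorDict

    @torch.jit.script
    def add_double(td: torch.classes.tensordict.TensorDict) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
        # store a derived entry next to the existing one
        td.set("x2", td.get("x") * 2)
        return td

    td = TensorDict({"x": torch.arange(4)}, batch_size=[4], device="cpu")
    with td.to_torchbind() as tb:
        fut = torch.jit.fork(add_double, tb)
        torch.jit.wait(fut)
    assert "x2" in td.keys()  # written back when the context manager exits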
17 changes: 17 additions & 0 deletions setup.py
@@ -65,6 +65,21 @@ def build_extension(self, ext):
        else:
            # For regular installs, place the extension in the build directory
            extdir = os.path.abspath(os.path.join(self.build_lib, "tensordict"))

        # Try to import torch to obtain its CMake prefix for find_package(Torch)
        cmake_prefix = os.environ.get("CMAKE_PREFIX_PATH", "")
        try:
            import torch  # noqa: F401
            from torch.utils import cmake_prefix_path as torch_cmake_prefix_path  # type: ignore

            # Prepend Torch's cmake prefix so CMake can find Torch
            cmake_prefix = (
                f"{torch_cmake_prefix_path}:{cmake_prefix}" if cmake_prefix else torch_cmake_prefix_path
            )
        except Exception:
            # Torch not importable at build time; rely on environment CMAKE_PREFIX_PATH
            pass

        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
            f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={extdir}",
@@ -73,6 +88,8 @@ def build_extension(self, ext):
            # for windows
            "-DCMAKE_BUILD_TYPE=Release",
        ]
        if cmake_prefix:
            cmake_args.append(f"-DCMAKE_PREFIX_PATH={cmake_prefix}")

        build_args = []
        if not os.path.exists(self.build_temp):
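For reference, the prefix this block tries to discover can be printed directly from a Python session; a minimal
sketch, assuming torch is installed in the build environment (the exact path depends on the local install):

    # Show the CMake prefix that find_package(Torch) needs. If torch cannot be
    # imported at build time, the same value can be exported manually as
    # CMAKE_PREFIX_PATH before running the build.
    from torch.utils import cmake_prefix_path

    print(cmake_prefix_path)  # typically <site-packages>/torch/share/cmake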
19 changes: 19 additions & 0 deletions tensordict/_contextlib.py
@@ -361,3 +361,22 @@ def _reverse_to(self, args, kwargs, out):


LAST_OP_MAPS["to"] = _reverse_to


def _reverse_to_torchbind(self, args, kwargs, out):
"""Reverse the to_torchbind() operation by converting back to TensorDict.

Uses from_torchbind to convert the TorchBind object back to a Python TensorDict,
then updates the original tensordict with the potentially modified values.
"""
if out is None:
return self
# Convert the TorchBind object back to TensorDict
from tensordict._td import TensorDict

td_from_tb = TensorDict.from_torchbind(out)
# Update self with the values from the TorchBind object
return self.update(td_from_tb, inplace=False)


LAST_OP_MAPS["to_torchbind"] = _reverse_to_torchbind
66 changes: 66 additions & 0 deletions tensordict/_jit_ops.py
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""TorchScript-compatible ops for TensorDict operations."""

from typing import List

import torch

# Ensure C++ extension is loaded
try:
    import tensordict._C  # noqa: F401
except ImportError:
    pass


def _get_tb_class():
    """Get the TorchBind TensorDict class if available."""
    try:
        clsns = getattr(torch, "classes", None)
        tdns = getattr(clsns, "tensordict", None) if clsns is not None else None
        return getattr(tdns, "TensorDict", None) if tdns is not None else None
    except Exception:
        return None


_TB_TensorDict = _get_tb_class()

# Script functions that work on the TorchBind class. The bound methods of the
# custom class are called directly, since only the class (not free
# torch.ops.tensordict.* operators) is registered by the extension.
if _TB_TensorDict is not None:

    @torch.jit.script
    def _td_get_scripted(td: torch.classes.tensordict.TensorDict, key: str) -> torch.Tensor:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.get()."""
        return td.get(key)

    @torch.jit.script
    def _td_set_scripted(td: torch.classes.tensordict.TensorDict, key: str, value: torch.Tensor) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.set(); returns the (mutated) input."""
        td.set(key, value)
        return td

    @torch.jit.script
    def _td_has_scripted(td: torch.classes.tensordict.TensorDict, key: str) -> bool:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.has()."""
        return td.has(key)

    @torch.jit.script
    def _td_keys_scripted(td: torch.classes.tensordict.TensorDict) -> List[str]:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.keys()."""
        return td.keys()

    @torch.jit.script
    def _td_device_scripted(td: torch.classes.tensordict.TensorDict) -> torch.device:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.device."""
        return td.device()

    @torch.jit.script
    def _td_to_scripted(td: torch.classes.tensordict.TensorDict, device: torch.device) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.to()."""
        return td.to(device)

    @torch.jit.script
    def _td_clone_scripted(td: torch.classes.tensordict.TensorDict) -> torch.classes.tensordict.TensorDict:  # type: ignore[attr-defined]
        """Scripted version of TensorDict.clone()."""
        return td.clone()

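A minimal sketch of how these helpers compose with the TorchBind factory, assuming the C++ extension is built and
importable; the underscore-prefixed functions are private, so this is illustrative rather than a public API:

    import torch

    import tensordict._jit_ops as jit_ops  # defines the scripted helpers if the class is registered

    # Build a flat TorchBind TensorDict and exercise the scripted wrappers.
    tb = torch.classes.tensordict.TensorDict.from_pairs(
        ["x"], [torch.zeros(3)], [3], torch.device("cpu")
    )
    tb = jit_ops._td_set_scripted(tb, "y", torch.ones(3))
    print(jit_ops._td_has_scripted(tb, "y"))   # True
    print(jit_ops._td_keys_scripted(tb))       # ['x', 'y'] (order not guaranteed)
    print(jit_ops._td_device_scripted(tb))     # cpu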
5 changes: 3 additions & 2 deletions tensordict/csrc/CMakeLists.txt
@@ -12,6 +12,7 @@ endif()

find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
find_package(pybind11 2.13 REQUIRED)
find_package(Torch REQUIRED)

file(GLOB SOURCES "*.cpp")

@@ -39,10 +40,10 @@ endif()
find_package(Python COMPONENTS Development.Static)
target_link_libraries(_C ${Python_STATIC_LIBRARIES})

target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR} ${TORCH_INCLUDE_DIRS})

#if(APPLE OR WIN32) # Check if the target OS is OSX/macOS
target_link_libraries(_C PRIVATE pybind11::module)
target_link_libraries(_C PRIVATE pybind11::module ${TORCH_LIBRARIES})
#else()
# target_link_libraries(_C PRIVATE Python3::Python pybind11::module)
#endif()
135 changes: 135 additions & 0 deletions tensordict/csrc/tensordict_bind.cpp
@@ -0,0 +1,135 @@
#include <ATen/ATen.h>
#include <c10/core/Device.h>
#include <torch/script.h>

#include <unordered_map>
#include <vector>

namespace tensordict_bind_internal {

using Tensor = at::Tensor;

// Flat, single-device TensorDict TorchBind class
struct TensorDictBind : torch::CustomClassHolder {
TensorDictBind(const c10::Device& device)
: device_(device) {}

// Factory from key/value pairs with explicit batch_size and device
static c10::intrusive_ptr<TensorDictBind> from_pairs(
const std::vector<std::string>& keys,
const std::vector<Tensor>& values,
const std::vector<int64_t>& batch_size,
const c10::Device& device) {
if (keys.size() != values.size()) {
TORCH_CHECK(false, "keys and values must have the same length");
}
auto obj = c10::make_intrusive<TensorDictBind>(device);
obj->batch_size_ = batch_size;
for (size_t i = 0; i < keys.size(); ++i) {
obj->validate_tensor(values[i]);
obj->validate_batch_size_prefix(values[i].sizes());
obj->data_.emplace(keys[i], values[i]);
}
return obj;
}

bool has(const std::string& key) const {
return data_.find(key) != data_.end();
}

Tensor get(const std::string& key) const {
auto it = data_.find(key);
TORCH_CHECK(it != data_.end(), "Key not found: ", key);
return it->second;
}

void set(const std::string& key, const Tensor& value) {
validate_tensor(value);
validate_batch_size_prefix(value.sizes());
data_[key] = value;
}

std::vector<std::string> keys() const {
std::vector<std::string> out;
out.reserve(data_.size());
for (const auto& kv : data_) {
out.push_back(kv.first);
}
return out;
}

std::vector<int64_t> batch_size() const { return batch_size_; }

c10::Device device() const { return device_; }

c10::intrusive_ptr<TensorDictBind> to(const c10::Device& new_device) const {
auto obj = c10::make_intrusive<TensorDictBind>(new_device);
obj->batch_size_ = batch_size_;
for (const auto& kv : data_) {
obj->data_.emplace(kv.first, kv.second.to(new_device));
}
return obj;
}

c10::intrusive_ptr<TensorDictBind> clone() const {
auto obj = c10::make_intrusive<TensorDictBind>(device_);
obj->batch_size_ = batch_size_;
for (const auto& kv : data_) {
obj->data_.emplace(kv.first, kv.second.clone());
}
return obj;
}

private:
void validate_tensor(const Tensor& t) const {
TORCH_CHECK(
t.defined(),
"TensorDictBind: cannot store an undefined tensor");
TORCH_CHECK(
t.device() == device_,
"All tensors must be on device ", device_.str(),
", but got ", t.device().str());
}

void validate_batch_size_prefix(at::IntArrayRef sizes) const {
// If no batch_size specified, accept any.
if (batch_size_.empty()) {
return;
}
TORCH_CHECK(
sizes.size() >= batch_size_.size(),
"Tensor has fewer dims (", sizes.size(), ") than batch_size prefix (",
batch_size_.size(), ")");
for (size_t i = 0; i < batch_size_.size(); ++i) {
TORCH_CHECK(
sizes[i] == batch_size_[i],
"Tensor batch dim ", i, " mismatch: expected ", batch_size_[i],
", got ", sizes[i]);
}
}

std::unordered_map<std::string, Tensor> data_;
std::vector<int64_t> batch_size_;
c10::Device device_;
};

} // namespace tensordict_bind_internal

TORCH_LIBRARY(tensordict, m) {
using tensordict_bind_internal::TensorDictBind;
m.class_<TensorDictBind>("TensorDict")
.def(torch::init<c10::Device>())
.def_static("from_pairs", &TensorDictBind::from_pairs)
.def("has", &TensorDictBind::has)
.def("get", &TensorDictBind::get)
.def("set", &TensorDictBind::set)
.def("keys", &TensorDictBind::keys)
.def("batch_size", &TensorDictBind::batch_size)
.def("device", &TensorDictBind::device)
.def("to", &TensorDictBind::to)
.def("clone", &TensorDictBind::clone);
}




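A short sketch of driving the bound class from eager Python, assuming the extension has been built and loaded
(e.g. by importing tensordict._C):

    import torch
    import tensordict._C  # noqa: F401  (loads the TorchBind registration)

    TDClass = torch.classes.tensordict.TensorDict

    tb = TDClass(torch.device("cpu"))      # empty, flat, CPU-only container
    tb.set("obs", torch.zeros(4, 3))
    print(tb.has("obs"), tb.keys())        # True ['obs']
    copy = tb.clone()                      # deep-copies the stored tensors
    same = tb.to(torch.device("cpu"))      # device-preserving copy here
    print(copy.get("obs").shape, same.device())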
66 changes: 66 additions & 0 deletions tensordict/tensordict.py
@@ -34,3 +34,69 @@
    lock_blocked,
    NestedKey,
)

# TorchBind interop helpers (flat, single-device)
try:
    import torch  # local import to avoid issues if torch unavailable at parse time
    import weakref

    _clsns = getattr(torch, "classes", None)
    _tdns = getattr(_clsns, "tensordict", None) if _clsns is not None else None
    _TB_TensorDict = (
        getattr(_tdns, "TensorDict", None) if _tdns is not None else None
    )
    if _TB_TensorDict is not None:

        class _TorchBindContextManager:
            """Context manager wrapper for to_torchbind() that handles conversion back."""

            def __init__(self, td, tb):
                self._td_ref = weakref.ref(td)
                self._tb = tb

            def __enter__(self):
                return self._tb

            def __exit__(self, exc_type, exc_val, exc_tb):
                if exc_type is not None:
                    return False
                # Convert TorchBind back to TensorDict and update original
                td = self._td_ref()
                if td is not None:
                    from tensordict._td import TensorDict as _TD_class

                    td_from_tb = _TD_class.from_torchbind(self._tb)
                    td.update(td_from_tb, inplace=False)
                return False

        def to_torchbind(td):  # type: ignore[no-redef]
            """Convert TensorDict to TorchBind format.

            Can be used as a context manager to automatically convert back::

                with td.to_torchbind() as tb:
                    ...  # use tb (TorchBind object)
                # td is automatically updated from tb
            """
            if td.device is None:
                raise RuntimeError(
                    "TensorDict.to_torchbind requires a non-None device; call td = td.to(device) first"
                )
            keys = list(td.keys())
            values = [td.get(k) for k in keys]
            bs = list(td.batch_size)
            tb = _TB_TensorDict.from_pairs(keys, values, bs, td.device)
            # Return a context manager that wraps the TorchBind object
            return _TorchBindContextManager(td, tb)

        def _from_torchbind(cls, obj):  # type: ignore[no-redef]
            keys = list(obj.keys())
            d = {k: obj.get(k) for k in keys}
            return cls(d, batch_size=tuple(obj.batch_size()), device=obj.device())

        # Expose on class for convenience
        from tensordict._td import TensorDict as _TD  # local import to avoid cycles

        _TD.to_torchbind = to_torchbind  # type: ignore[attr-defined]
        _TD.from_torchbind = classmethod(_from_torchbind)  # type: ignore[attr-defined]
except Exception:
    # If the extension is not built/available, leave helpers undefined
    pass
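
A minimal sketch of the write-back semantics implemented above, assuming the extension is available: on a clean
exit the original TensorDict is refreshed from the TorchBind object, while an exception inside the block skips
the write-back and propagates:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"x": torch.zeros(3)}, batch_size=[3], device="cpu")
    with td.to_torchbind() as tb:
        tb.set("y", torch.ones(3))  # mutate through the TorchBind handle
    assert "y" in td.keys()         # copied back in __exit__

    try:
        with td.to_torchbind() as tb:
            tb.set("z", torch.ones(3))
            raise ValueError("boom")
    except ValueError:
        pass
    assert "z" not in td.keys()     # write-back skipped when the block raises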