Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions docs/source/reference/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,85 @@ to build distributions from network outputs and get summary statistics or sample
device=None,
is_shared=False)

Type-Safe TensorClass Modules
------------------------------

The :class:`~.TensorClassModuleBase` provides a type-safe way to define modules that work with
:class:`~tensordict.tensorclass.TensorClass` inputs and outputs. This enables static type
checking and improves code clarity compared to working with string-based keys.

A :class:`~.TensorClassModuleBase` subclass specifies its input and output types through generic
type parameters. The module can be converted to work with :class:`~.TensorDict` objects using the
:meth:`~.TensorClassModuleBase.as_td_module` method, which returns a :class:`~.TensorClassModuleWrapper`:

.. code-block::

>>> import torch
>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> from tensordict import TensorDict
>>>
>>> # Define input and output TensorClass types
>>> class InputTC(TensorClass):
... a: torch.Tensor
... b: torch.Tensor
...
>>> class OutputTC(TensorClass):
... sum: torch.Tensor
... difference: torch.Tensor
...
>>> # Create a type-safe module
>>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
... def forward(self, x: InputTC) -> OutputTC:
... return OutputTC(
... sum=x.a + x.b,
... difference=x.a - x.b,
... batch_size=x.batch_size
... )
...
>>> # Use with TensorClass
>>> module = MyModule()
>>> input_tc = InputTC(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0, 4.0]), batch_size=[2])
>>> output = module(input_tc)
>>> print(output.sum)
tensor([4., 6.])
>>> print(output.difference)
tensor([-2., -2.])
>>>
>>> # Convert to TensorDictModule for use in TensorDict workflows
>>> td_module = module.as_td_module()
>>> td = TensorDict({"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, batch_size=[2])
>>> result = td_module(td)
>>> print(result)
TensorDict(
fields={
difference: Tensor(torch.Size([2]), dtype=torch.float32),
sum: Tensor(torch.Size([2]), dtype=torch.float32)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)

The type-safe approach offers several benefits:

* **Type checking**: IDEs and type checkers can verify correct usage at development time
* **Self-documenting**: The input and output structure is clear from the type signature
* **Refactoring**: Renaming fields in TensorClass definitions is caught by type checkers
* **Nested structures**: Support for nested TensorClass types with automatic key extraction

:class:`~.TensorClassModuleBase` modules can be composed and used in :class:`~.TensorDictSequential`
after conversion via :meth:`~.TensorClassModuleBase.as_td_module`.


.. autosummary::
:toctree: generated/
:template: td_template_noinherit.rst

TensorDictModuleBase
TensorDictModule
TensorClassModuleBase
TensorClassModuleWrapper
ProbabilisticTensorDictModule
ProbabilisticTensorDictSequential
TensorDictSequential
Expand Down
9 changes: 8 additions & 1 deletion tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@
unravel_key_list,
)
from tensordict._pytree import *
from tensordict.nn import as_tensordict_module, TensorDictParams
from tensordict.nn import (
as_tensordict_module,
TensorClassModuleBase,
TensorClassModuleWrapper,
TensorDictParams,
)

try:
from tensordict._version import __version__ # @manual=//pytorch/tensordict:version
Expand Down Expand Up @@ -149,6 +154,8 @@
"NonTensorStack",
# NN imports
"as_tensordict_module",
"TensorClassModuleBase",
"TensorClassModuleWrapper",
"TensorDictParams",
# Version
"__version__",
Expand Down
6 changes: 6 additions & 0 deletions tensordict/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
set_interaction_type,
)
from tensordict.nn.sequence import TensorDictSequential
from tensordict.nn.tensorclass_module import (
TensorClassModuleBase,
TensorClassModuleWrapper,
)
from tensordict.nn.utils import (
add_custom_mapping,
biased_softplus,
Expand All @@ -57,6 +61,8 @@
"TensorDictSequential",
"EnsembleModule",
"CudaGraphModule",
"TensorClassModuleBase",
"TensorClassModuleWrapper",
# Probabilistic modules
"ProbabilisticTensorDictModule",
"ProbabilisticTensorDictSequential",
Expand Down
238 changes: 238 additions & 0 deletions tensordict/nn/tensorclass_module.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new classes need to be added to the doc

Copy link
Author

@az0uz az0uz Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to write anything else than adding them to docs/source/reference/nn.rst like above?

I've added a section with description and examples. let me know if that works for you.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's really cool thx

Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import Field
from typing import Any, cast, Generic, get_args, get_origin, TypeVar

from tensordict._td import TensorDict
from tensordict.nn.common import dispatch, TensorDictModuleBase
from tensordict.tensorclass import TensorClass
from torch import nn, Tensor

__all__ = ["TensorClassModuleBase", "TensorClassModuleWrapper"]


def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> list[tuple[str, ...]]:
    """Extract all keys from a TensorClass type, including nested keys.

    A field annotated with a ``TensorClass`` subclass is expanded recursively:
    each of its keys is prefixed with the field's own name. Every other field
    contributes a one-element tuple holding its name.

    Args:
        tensorclass_type (type[TensorClass]): The TensorClass type to extract keys from.

    Returns:
        list[tuple[str, ...]]: A list of key tuples representing all fields in the
            TensorClass, in the order returned by ``tensorclass_type.fields()``.

    """
    fields = cast("Iterable[Field[Any]]", tensorclass_type.fields())
    keys: list[tuple[str, ...]] = []
    for field in fields:
        field_type = field.type
        # ``field.type`` holds the raw annotation, which is not necessarily a
        # class: it may be a string (postponed evaluation of annotations) or a
        # typing construct such as ``Optional[Tensor]``. ``issubclass`` raises
        # ``TypeError`` on those, so only real classes are tested; anything
        # else is treated as a leaf key.
        if isinstance(field_type, type) and issubclass(field_type, TensorClass):
            keys.extend(
                (field.name, *subkey) for subkey in _tensor_class_keys(field_type)
            )
        else:
            keys.append((field.name,))
    return keys


# Type variables used to annotate the module accepted by
# ``TensorClassModuleWrapper.__init__``; both are bounded by ``TensorClass``.
InputTensorClass = TypeVar("InputTensorClass", bound=TensorClass)
OutputTensorClass = TypeVar("OutputTensorClass", bound=TensorClass)


class TensorClassModuleWrapper(TensorDictModuleBase):
    """Adapter exposing a TensorClassModuleBase through a TensorDict interface.

    The wrapper stores the given module as ``tc_module`` and derives its
    ``in_keys`` / ``out_keys`` from the fields of the module's ``input_type``
    and ``output_type`` (nested TensorClass fields become multi-element key
    tuples). On a forward call the incoming TensorDict is turned into the
    module's input TensorClass, the module is run on it, and the resulting
    TensorClass is converted back to a TensorDict.

    Args:
        module (TensorClassModuleBase): The TensorClassModuleBase instance to wrap.

    Examples:
        >>> from tensordict import TensorDict
        >>> from tensordict.tensorclass import TensorClass
        >>> from tensordict.nn import TensorClassModuleBase
        >>> import torch
        >>>
        >>> class InputTC(TensorClass):
        ...     x: torch.Tensor
        ...
        >>> class OutputTC(TensorClass):
        ...     y: torch.Tensor
        ...
        >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
        ...     def forward(self, input: InputTC) -> OutputTC:
        ...         return OutputTC(y=input.x + 1, batch_size=input.batch_size)
        ...
        >>> module = MyModule()
        >>> td_module = module.as_td_module()
        >>> td = TensorDict({"x": torch.zeros(3)}, batch_size=[3])
        >>> result = td_module(td)
        >>> assert "y" in result

    """

    def __init__(
        self, module: TensorClassModuleBase[InputTensorClass, OutputTensorClass]
    ) -> None:
        super().__init__()
        self.tc_module = module

        # Key lists are computed once, from the declared types rather than
        # from any concrete instance.
        input_cls = cast(type[TensorClass], module.input_type)
        self.in_keys = _tensor_class_keys(input_cls)
        output_cls = cast(type[TensorClass], module.output_type)
        self.out_keys = _tensor_class_keys(output_cls)

    @dispatch(auto_batch_size=False)
    def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict:
        """Run the wrapped module on a TensorDict and return a TensorDict.

        Args:
            tensordict (TensorDict): Input tensordict, converted with
                ``input_type.from_tensordict`` before the wrapped module is called.
            *args: Additional positional arguments. Accepted but not forwarded
                to the wrapped module.
            **kwargs: Additional keyword arguments. Accepted but not forwarded
                to the wrapped module.

        Returns:
            TensorDict: The wrapped module's output converted with ``to_tensordict()``.

        """
        wrapped = self.tc_module
        tc_input = wrapped.input_type.from_tensordict(tensordict)
        tc_output = wrapped(tc_input)
        return tc_output.to_tensordict()


# Type variables parameterizing ``TensorClassModuleBase``; each is bounded by
# the union of ``TensorClass`` and ``Tensor``.
InputClass = TypeVar("InputClass", bound=(TensorClass | Tensor))
OutputClass = TypeVar("OutputClass", bound=(TensorClass | Tensor))


class TensorClassModuleBase(Generic[InputClass, OutputClass], ABC, nn.Module):
    """A TensorClassModuleBase is a base class for modules that operate on TensorClass instances.

    TensorClassModuleBase subclasses provide a type-safe way to define modules that work with TensorClass
    inputs and outputs. When a subclass is declared as
    ``TensorClassModuleBase[InputType, OutputType]``, the two generic arguments are
    recorded on the subclass as ``input_type`` and ``output_type``.

    The module can be wrapped for TensorDict-based workflows using the
    :meth:`as_td_module` method.

    Type Parameters:
        InputClass: The input type, must be a TensorClass or Tensor.
        OutputClass: The output type, must be a TensorClass or Tensor.

    Attributes:
        input_type (type[InputClass]): The input type class.
        output_type (type[OutputClass]): The output type class.

    Examples:
        >>> from tensordict.tensorclass import TensorClass
        >>> from tensordict.nn import TensorClassModuleBase
        >>> import torch
        >>>
        >>> class InputTC(TensorClass):
        ...     a: torch.Tensor
        ...     b: torch.Tensor
        ...
        >>> class OutputTC(TensorClass):
        ...     result: torch.Tensor
        ...
        >>> class AddModule(TensorClassModuleBase[InputTC, OutputTC]):
        ...     def forward(self, x: InputTC) -> OutputTC:
        ...         return OutputTC(
        ...             result=x.a + x.b,
        ...             batch_size=x.batch_size
        ...         )
        ...
        >>> module = AddModule()
        >>> input_tc = InputTC(a=torch.tensor([1.0]), b=torch.tensor([2.0]), batch_size=[1])
        >>> output = module(input_tc)
        >>> assert output.result == torch.tensor([3.0])

    """

    input_type: type[InputClass]
    output_type: type[OutputClass]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Initialize subclass by extracting type information from generic parameters.

        Args:
            **kwargs: Class keyword arguments, forwarded unchanged to the next
                ``__init_subclass__`` in the MRO.

        Raises:
            ValueError: If a ``TensorClassModuleBase[...]`` base carries no
                generic arguments.

        """
        super().__init_subclass__(**kwargs)
        # ``__orig_bases__`` keeps the parameterized form of the bases
        # (``TensorClassModuleBase[In, Out]``), which ``__bases__`` erases.
        for base in getattr(cls, "__orig_bases__", ()):
            if get_origin(base) is not TensorClassModuleBase:
                continue
            generic_args = get_args(base)
            if not generic_args:
                raise ValueError(
                    "Generic input/output types not set in TensorClassModuleBase"
                )
            cls.input_type, cls.output_type = generic_args

    @abstractmethod
    def forward(self, x: InputClass) -> OutputClass:
        """Forward pass of the module.

        Args:
            x (InputClass): Input instance.

        Returns:
            OutputClass: Output instance.

        """
        ...

    def __call__(self, x: InputClass) -> OutputClass:
        """Call the module's forward method.

        Delegates to ``nn.Module.__call__``; the override only narrows the
        signature and return type for type checkers.

        Args:
            x (InputClass): Input instance.

        Returns:
            OutputClass: Output instance.

        """
        return cast("OutputClass", super().__call__(x))

    def as_td_module(self) -> TensorClassModuleWrapper:
        """Wrap this module so it can be used with TensorDict inputs and outputs.

        This method wraps the TensorClassModuleBase in a TensorClassModuleWrapper,
        which converts between TensorDict and TensorClass representations.

        Returns:
            TensorClassModuleWrapper: A wrapper that converts between TensorDict
                and TensorClass representations.

        Raises:
            ValueError: If either input_type or output_type is not a TensorClass
                subclass. This includes types that are missing or are not classes
                at all (for example an unresolved ``TypeVar``).

        Examples:
            >>> from tensordict import TensorDict
            >>> from tensordict.tensorclass import TensorClass
            >>> from tensordict.nn import TensorClassModuleBase
            >>> import torch
            >>>
            >>> class InputTC(TensorClass):
            ...     x: torch.Tensor
            ...
            >>> class OutputTC(TensorClass):
            ...     y: torch.Tensor
            ...
            >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
            ...     def forward(self, input: InputTC) -> OutputTC:
            ...         return OutputTC(y=input.x * 2, batch_size=input.batch_size)
            ...
            >>> module = MyModule()
            >>> td_module = module.as_td_module()
            >>> td = TensorDict({"x": torch.ones(3)}, batch_size=[3])
            >>> result = td_module(td)
            >>> assert (result["y"] == 2).all()

        """
        declared_types = (
            getattr(self, "input_type", None),
            getattr(self, "output_type", None),
        )
        # ``issubclass`` raises ``TypeError`` on non-class arguments, so check
        # ``isinstance(..., type)`` first to keep the documented ``ValueError``.
        convertible = all(
            isinstance(tp, type) and issubclass(tp, TensorClass)
            for tp in declared_types
        )
        if not convertible:
            raise ValueError(
                "Only TensorClassModuleBase implementations with both input and "
                "output type as TensorClass can be converted to TensorDictModule"
            )
        return TensorClassModuleWrapper(self)  # type:ignore[arg-type,type-var]
Loading