From 3921222c189884bdc254422f726c7d9aa0ebac2b Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Thu, 30 Oct 2025 16:55:27 +0100 Subject: [PATCH 1/8] feat(TensorClassModuleBase): Implement a type-checked equivalent to TensorDictModuleBase --- pyproject.toml | 1 + tensordict/__init__.py | 9 +- tensordict/nn/__init__.py | 6 + tensordict/nn/tensorclass_module.py | 242 ++++++++++++++++++++++++++++ test/test_compile.py | 18 +-- test/test_tensorclass_module.py | 188 +++++++++++++++++++++ 6 files changed, 450 insertions(+), 14 deletions(-) create mode 100644 tensordict/nn/tensorclass_module.py create mode 100644 test/test_tensorclass_module.py diff --git a/pyproject.toml b/pyproject.toml index a6b6181d2..ba8c2fcee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ tests = [ h5 = ["h5py>=3.8"] dev = ["pybind11", "ninja"] typecheck = ["mypy>=1.0.0"] +onnx = ["onnx", "onnxscript", "onnxruntime"] [tool.setuptools] include-package-data = false diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 162cfb26d..91478555f 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -73,7 +73,12 @@ unravel_key_list, ) from tensordict._pytree import * -from tensordict.nn import as_tensordict_module, TensorDictParams +from tensordict.nn import ( + as_tensordict_module, + TensorClassModuleBase, + TensorClassModuleWrapper, + TensorDictParams, +) try: from tensordict._version import __version__ # @manual=//pytorch/tensordict:version @@ -149,6 +154,8 @@ "NonTensorStack", # NN imports "as_tensordict_module", + "TensorClassModuleBase", + "TensorClassModuleWrapper", "TensorDictParams", # Version "__version__", diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 6e3e68fdd..66a5a7ca8 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -11,6 +11,10 @@ TensorDictModuleWrapper, WrapModule, ) +from tensordict.nn.tensorclass_module import ( + TensorClassModuleBase, + TensorClassModuleWrapper, +) from tensordict.nn.distributions import ( AddStateIndependentNormalScale, CompositeDistribution, @@ -57,6 +61,8 @@ "TensorDictSequential", "EnsembleModule", "CudaGraphModule", + "TensorClassModuleBase", + "TensorClassModuleWrapper", # Probabilistic modules "ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential", diff --git a/tensordict/nn/tensorclass_module.py b/tensordict/nn/tensorclass_module.py new file mode 100644 index 000000000..6d879d00a --- /dev/null +++ b/tensordict/nn/tensorclass_module.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import Field +from typing import Any, cast, Generic, get_args, get_origin, List, Tuple, TypeVar, Union + +from tensordict._td import TensorDict +from tensordict.nn.common import dispatch, TensorDictModuleBase +from tensordict.tensorclass import TensorClass +from torch import nn, Tensor + +__all__ = ["TensorClassModuleBase", "TensorClassModuleWrapper"] + + +def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> List[Tuple[str, ...]]: + """Extract all keys from a TensorClass type, including nested keys. + + Args: + tensorclass_type (type[TensorClass]): The TensorClass type to extract keys from. + + Returns: + list[tuple[str, ...]]: A list of key tuples representing all fields in the TensorClass. + + """ + fields = cast("Iterable[Field[Any]]", tensorclass_type.fields()) + keys: List[Tuple[str, ...]] = [] + for field in fields: + key = field.name + if issubclass(field.type, TensorClass): + subkeys = _tensor_class_keys(cast(type[TensorClass], field.type)) + for subkey in subkeys: + keys.append((key,) + subkey) + else: + keys.append((key,)) + return keys + + +InputTensorClass = TypeVar("InputTensorClass", bound=TensorClass) +OutputTensorClass = TypeVar("OutputTensorClass", bound=TensorClass) + + +class TensorClassModuleWrapper(TensorDictModuleBase): + """Wrapper class for TensorClassModuleBase objects. + + This wrapper allows TensorClassModuleBase instances to be used in TensorDict-based + workflows by handling the conversion between TensorDict and TensorClass representations. + When called with a TensorDict, the wrapper converts it to a TensorClass, passes it through + the wrapped module, and converts the output back to a TensorDict. + + Args: + module (TensorClassModuleBase): The TensorClassModuleBase instance to wrap. + + Examples: + >>> from tensordict import TensorDict + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... x: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... y: torch.Tensor + ... + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, input: InputTC) -> OutputTC: + ... return OutputTC(y=input.x + 1, batch_size=input.batch_size) + ... + >>> module = MyModule() + >>> td_module = module.as_td_module() + >>> td = TensorDict({"x": torch.zeros(3)}, batch_size=[3]) + >>> result = td_module(td) + >>> assert "y" in result + + """ + + def __init__( + self, module: TensorClassModuleBase[InputTensorClass, OutputTensorClass] + ) -> None: + super().__init__() + self.tc_module = module + self.in_keys = _tensor_class_keys(cast(type[TensorClass], module.input_type)) + self.out_keys = _tensor_class_keys(cast(type[TensorClass], module.output_type)) + + @dispatch(auto_batch_size=False) + def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict: + """Forward pass converting TensorDict to TensorClass and back. + + Args: + tensordict (TensorDict): Input tensordict. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + TensorDict: Output tensordict. + + """ + return self.tc_module( + self.tc_module.input_type.from_dict( + input_dict=tensordict.to_dict(), + batch_size=tensordict.batch_size, + device=tensordict.device, + ) + ).to_tensordict() + + +InputClass = TypeVar("InputClass", bound=Union[TensorClass, Tensor]) +OutputClass = TypeVar("OutputClass", bound=Union[TensorClass, Tensor]) + + +class TensorClassModuleBase(Generic[InputClass, OutputClass], ABC, nn.Module): + """A TensorClassModuleBase is a base class for modules that operate on TensorClass instances. + + TensorClassModuleBase subclasses provide a type-safe way to define modules that work with TensorClass + inputs and outputs. The class automatically extracts input and output type information from the + generic type parameters. + + The module can be converted to a TensorDictModule using the :meth:`as_td_module` + method, allowing it to be used in TensorDict-based workflows. + + Type Parameters: + InputClass: The input type, must be a TensorClass or Tensor. + OutputClass: The output type, must be a TensorClass or Tensor. + + Attributes: + input_type (type[InputClass]): The input type class. + output_type (type[OutputClass]): The output type class. + + Examples: + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... a: torch.Tensor + ... b: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... result: torch.Tensor + ... + >>> class AddModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, x: InputTC) -> OutputTC: + ... return OutputTC( + ... result=x.a + x.b, + ... batch_size=x.batch_size + ... ) + ... + >>> module = AddModule() + >>> input_tc = InputTC(a=torch.tensor([1.0]), b=torch.tensor([2.0]), batch_size=[1]) + >>> output = module(input_tc) + >>> assert output.result == torch.tensor([3.0]) + + """ + + input_type: type[InputClass] + output_type: type[OutputClass] + + def __init_subclass__(cls) -> None: + """Initialize subclass by extracting type information from generic parameters.""" + super().__init_subclass__() + for base in cls.__orig_bases__: # type:ignore[attr-defined] + origin = get_origin(base) + if origin is TensorClassModuleBase: + generic_args = get_args(base) + if generic_args: + cls.input_type, cls.output_type = generic_args + else: + raise ValueError( + "Generic input/output types not set in TensorClassModuleBase" + ) + + @abstractmethod + def forward(self, x: InputClass) -> OutputClass: + """Forward pass of the module. + + Args: + x (InputClass): Input instance. + + Returns: + OutputClass: Output instance. + + """ + ... + + def __call__(self, x: InputClass) -> OutputClass: + """Call the module's forward method. + + Args: + x (InputClass): Input instance. + + Returns: + OutputClass: Output instance. + + """ + return cast("OutputClass", super().__call__(x)) + + def as_td_module(self) -> TensorClassModuleWrapper: + """Convert this module to a TensorDictModule. + + This method wraps the TensorClassModuleBase in a TensorClassModuleWrapper, + allowing it to be used with TensorDict inputs and outputs. + + Returns: + TensorClassModuleWrapper: A wrapper that converts between TensorDict + and TensorClass representations. + + Raises: + ValueError: If either input_type or output_type is not a TensorClass. + + Examples: + >>> from tensordict import TensorDict + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... x: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... y: torch.Tensor + ... + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, input: InputTC) -> OutputTC: + ... return OutputTC(y=input.x * 2, batch_size=input.batch_size) + ... + >>> module = MyModule() + >>> td_module = module.as_td_module() + >>> td = TensorDict({"x": torch.ones(3)}, batch_size=[3]) + >>> result = td_module(td) + >>> assert (result["y"] == 2).all() + + """ + if not ( + issubclass(self.input_type, TensorClass) + and issubclass(self.output_type, TensorClass) + ): + raise ValueError( + "Only TensorClassModuleBase implementations with both input and " + "output type as TensorClass can be converted to TensorDictModule" + ) + return TensorClassModuleWrapper(self) # type:ignore[arg-type,type-var] diff --git a/test/test_compile.py b/test/test_compile.py index 27cbcacbc..4a116b880 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -950,9 +950,7 @@ def test_onnx_export_module(self, tmpdir): x = torch.randn(3) y = torch.randn(3) torch_input = {"x": x, "y": y} - onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) - - onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) + onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) path = Path(tmpdir) / "file.onnx" onnx_program.save(str(path)) @@ -969,9 +967,7 @@ def to_numpy(tensor): else tensor.cpu().numpy() ) - onnxruntime_input = { - k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) - } + onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} onnxruntime_outputs = ort_session.run(None, onnxruntime_input) torch.testing.assert_close( @@ -986,10 +982,8 @@ def test_onnx_export_seq(self, tmpdir): x = torch.randn(3) y = torch.randn(3) torch_input = {"x": x, "y": y} - torch.onnx.dynamo_export(tdm, x=x, y=y) - onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) - - onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) + torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) + onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) path = Path(tmpdir) / "file.onnx" onnx_program.save(str(path)) @@ -1006,9 +1000,7 @@ def to_numpy(tensor): else tensor.cpu().numpy() ) - onnxruntime_input = { - k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) - } + onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} onnxruntime_outputs = ort_session.run(None, onnxruntime_input) torch.testing.assert_close( diff --git a/test/test_tensorclass_module.py b/test/test_tensorclass_module.py new file mode 100644 index 000000000..28accd15e --- /dev/null +++ b/test/test_tensorclass_module.py @@ -0,0 +1,188 @@ +import importlib +from pathlib import Path + +import pytest +import torch +from tensordict.nn import TensorClassModuleBase +from tensordict.tensorclass import TensorClass +from torch import Tensor + +_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None + + +class InputTensorClass(TensorClass): + """Test input TensorClass with two tensor fields.""" + + a: Tensor + b: Tensor + + +class AddDiffResult(TensorClass): + """Test output TensorClass for add/diff operations.""" + + added: Tensor + substracted: Tensor + + +class OutputTensorClass(TensorClass): + """Test output TensorClass with nested structure.""" + + input: InputTensorClass + result: AddDiffResult + + +class AddDiffModule(TensorClassModuleBase[InputTensorClass, AddDiffResult]): + """Test module that adds and subtracts two tensors.""" + + def forward(self, x: InputTensorClass) -> AddDiffResult: + return AddDiffResult( + added=(x.a + x.b), substracted=(x.a - x.b), batch_size=x.batch_size + ) + + +class TestTensorClassModule(TensorClassModuleBase[InputTensorClass, OutputTensorClass]): + """Test module with nested TensorClass output.""" + + def __init__(self) -> None: + super().__init__() + self.add_diff = AddDiffModule() + + def forward(self, x: InputTensorClass) -> OutputTensorClass: + return OutputTensorClass( + input=x, result=self.add_diff(x), batch_size=x.batch_size + ) + + +class TestTensorClassModuleForward: + """Tests for TensorClassModule forward pass.""" + + def test_forward(self) -> None: + """Test basic forward pass with TensorClass input.""" + module = TestTensorClassModule() + value = InputTensorClass(a=10, b=5, batch_size=[]) + output = module.forward(value) + assert isinstance(output, OutputTensorClass) + assert output.result.added == 15 + assert output.result.substracted == 5 + + def test_td_forward(self) -> None: + """Test forward pass with TensorDict input via wrapper.""" + td_module = TestTensorClassModule().as_td_module() + value = InputTensorClass(a=10, b=5, batch_size=[]) + td_output = td_module(value.to_tensordict()) + assert td_output["result", "added"] == 15 + assert td_output["result", "substracted"] == 5 + + def test_wrapper_keys(self) -> None: + """Test that wrapper correctly extracts in_keys and out_keys.""" + module = TestTensorClassModule() + td_module = module.as_td_module() + assert set(td_module.in_keys) == {"a", "b"} + assert set(td_module.out_keys) == { + ("input", "a"), + ("input", "b"), + ("result", "added"), + ("result", "substracted"), + } + + +@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") +class TestONNXExport: + """Tests for ONNX export functionality.""" + + def test_onnx_export_module(self, tmp_path: Path) -> None: + """Test ONNX export of TensorClassModule.""" + tc_module = TestTensorClassModule() + tc_module.eval() + tc_input = InputTensorClass( + a=torch.tensor([10.0], dtype=torch.float), + b=torch.tensor([5.0], dtype=torch.float), + batch_size=[1], + ) + torch_input = tc_input.to_tensordict().to_dict() + + td_module = tc_module.as_td_module().select_out_keys( + ("result", "added"), ("result", "substracted") + ) + output_names = [ + v if isinstance(v, str) else "_".join(v) for v in td_module.out_keys + ] + + onnx_program = torch.onnx.export( + model=td_module, kwargs=torch_input, output_names=output_names, dynamo=True + ) + + path = tmp_path / "file.onnx" + onnx_program.save(str(path)) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession( + path, providers=["CPUExecutionProvider"] + ) + + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + output_names = [output.name for output in ort_session.get_outputs()] + + onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} + + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + onnxruntime_output_dict = dict(zip(output_names, onnxruntime_outputs)) + + tc_outputs = tc_module(tc_input) + + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_added"]), + tc_outputs.result.added, + ) + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_substracted"]), + tc_outputs.result.substracted, + ) + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_non_tensorclass_conversion_error(self) -> None: + """Test that conversion to TensorDictModule fails for non-TensorClass types.""" + + class BadModule(TensorClassModuleBase[Tensor, Tensor]): + def forward(self, x: Tensor) -> Tensor: + return x + 1 + + module = BadModule() + with pytest.raises( + ValueError, + match="Only TensorClassModuleBase implementations with both input and output type as TensorClass", + ): + module.as_td_module() + + def test_batch_size_preservation(self) -> None: + """Test that batch size is correctly preserved through forward pass.""" + module = AddDiffModule() + batch_sizes = [[], [3], [2, 3], [1, 2, 3]] + + for batch_size in batch_sizes: + if batch_size: + input_tc = InputTensorClass( + a=torch.randn(*batch_size), + b=torch.randn(*batch_size), + batch_size=batch_size, + ) + else: + input_tc = InputTensorClass( + a=torch.randn(()), + b=torch.randn(()), + batch_size=batch_size, + ) + output = module(input_tc) + assert output.batch_size == torch.Size(batch_size) + assert output.added.shape == torch.Size(batch_size) + assert output.substracted.shape == torch.Size(batch_size) From 1f8310d3d4268dc4636f2a4cd0c1e701d4c24dfb Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Thu, 30 Oct 2025 17:07:01 +0100 Subject: [PATCH 2/8] format --- tensordict/nn/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 66a5a7ca8..08bf63616 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -11,10 +11,6 @@ TensorDictModuleWrapper, WrapModule, ) -from tensordict.nn.tensorclass_module import ( - TensorClassModuleBase, - TensorClassModuleWrapper, -) from tensordict.nn.distributions import ( AddStateIndependentNormalScale, CompositeDistribution, @@ -38,6 +34,10 @@ set_interaction_type, ) from tensordict.nn.sequence import TensorDictSequential +from tensordict.nn.tensorclass_module import ( + TensorClassModuleBase, + TensorClassModuleWrapper, +) from tensordict.nn.utils import ( add_custom_mapping, biased_softplus, From 04a3db0c5281a97021db176d719d8fc57aa7b98b Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:01:58 +0100 Subject: [PATCH 3/8] Revert changes fixing onnx export, moved to a different PR --- pyproject.toml | 1 - test/test_compile.py | 18 +++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba8c2fcee..a6b6181d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ tests = [ h5 = ["h5py>=3.8"] dev = ["pybind11", "ninja"] typecheck = ["mypy>=1.0.0"] -onnx = ["onnx", "onnxscript", "onnxruntime"] [tool.setuptools] include-package-data = false diff --git a/test/test_compile.py b/test/test_compile.py index 4a116b880..27cbcacbc 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -950,7 +950,9 @@ def test_onnx_export_module(self, tmpdir): x = torch.randn(3) y = torch.randn(3) torch_input = {"x": x, "y": y} - onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) + onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) + + onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) path = Path(tmpdir) / "file.onnx" onnx_program.save(str(path)) @@ -967,7 +969,9 @@ def to_numpy(tensor): else tensor.cpu().numpy() ) - onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} + onnxruntime_input = { + k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) + } onnxruntime_outputs = ort_session.run(None, onnxruntime_input) torch.testing.assert_close( @@ -982,8 +986,10 @@ def test_onnx_export_seq(self, tmpdir): x = torch.randn(3) y = torch.randn(3) torch_input = {"x": x, "y": y} - torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) - onnx_program = torch.onnx.export(tdm, kwargs=torch_input, dynamo=True) + torch.onnx.dynamo_export(tdm, x=x, y=y) + onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) + + onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) path = Path(tmpdir) / "file.onnx" onnx_program.save(str(path)) @@ -1000,7 +1006,9 @@ def to_numpy(tensor): else tensor.cpu().numpy() ) - onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} + onnxruntime_input = { + k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) + } onnxruntime_outputs = ort_session.run(None, onnxruntime_input) torch.testing.assert_close( From 808d1db395664917eee314d1330ff36e1f00fe38 Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:13:54 +0100 Subject: [PATCH 4/8] use modern type annotation --- tensordict/nn/tensorclass_module.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensordict/nn/tensorclass_module.py b/tensordict/nn/tensorclass_module.py index 6d879d00a..da1f42559 100644 --- a/tensordict/nn/tensorclass_module.py +++ b/tensordict/nn/tensorclass_module.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import Field -from typing import Any, cast, Generic, get_args, get_origin, List, Tuple, TypeVar, Union +from typing import Any, cast, Generic, get_args, get_origin, TypeVar from tensordict._td import TensorDict from tensordict.nn.common import dispatch, TensorDictModuleBase @@ -13,7 +13,7 @@ __all__ = ["TensorClassModuleBase", "TensorClassModuleWrapper"] -def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> List[Tuple[str, ...]]: +def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> list[tuple[str, ...]]: """Extract all keys from a TensorClass type, including nested keys. Args: @@ -24,7 +24,7 @@ def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> List[Tuple[str, . """ fields = cast("Iterable[Field[Any]]", tensorclass_type.fields()) - keys: List[Tuple[str, ...]] = [] + keys: list[tuple[str, ...]] = [] for field in fields: key = field.name if issubclass(field.type, TensorClass): @@ -105,8 +105,8 @@ def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict: ).to_tensordict() -InputClass = TypeVar("InputClass", bound=Union[TensorClass, Tensor]) -OutputClass = TypeVar("OutputClass", bound=Union[TensorClass, Tensor]) +InputClass = TypeVar("InputClass", bound=(TensorClass | Tensor)) +OutputClass = TypeVar("OutputClass", bound=(TensorClass | Tensor)) class TensorClassModuleBase(Generic[InputClass, OutputClass], ABC, nn.Module): From 57a85f97795333e2ce5b70e10b4261e5cb56ac5b Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:16:49 +0100 Subject: [PATCH 5/8] use `TensorClass.from_tensordict` --- tensordict/nn/tensorclass_module.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tensordict/nn/tensorclass_module.py b/tensordict/nn/tensorclass_module.py index da1f42559..87f3a35ce 100644 --- a/tensordict/nn/tensorclass_module.py +++ b/tensordict/nn/tensorclass_module.py @@ -97,11 +97,7 @@ def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict: """ return self.tc_module( - self.tc_module.input_type.from_dict( - input_dict=tensordict.to_dict(), - batch_size=tensordict.batch_size, - device=tensordict.device, - ) + self.tc_module.input_type.from_tensordict(tensordict) ).to_tensordict() From ba39ec6951ffd8075c2d918e7a26f81985f0a1d1 Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:27:51 +0100 Subject: [PATCH 6/8] moved tests to test_nn.py --- test/test_nn.py | 182 +++++++++++++++++++++++++++++++ test/test_tensorclass_module.py | 188 -------------------------------- 2 files changed, 182 insertions(+), 188 deletions(-) delete mode 100644 test/test_tensorclass_module.py diff --git a/test/test_nn.py b/test/test_nn.py index 253514324..f8099d5b6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7,7 +7,9 @@ import contextlib import copy import functools +import importlib import os +import pathlib import pickle import sys import unittest @@ -32,6 +34,7 @@ dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, + TensorClassModuleBase, TensorDictModuleBase, TensorDictParams, TensorDictSequential, @@ -81,6 +84,7 @@ except ImportError: from tensordict.utils import Buffer +_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None IS_FB = os.getenv("PYTORCH_TEST_FBCODE") @@ -4106,6 +4110,184 @@ def func(c): assert func(TensorDict(c=0))["d"] == 1 +class InputTensorClass(TensorClass): + """Test input TensorClass with two tensor fields.""" + + a: torch.Tensor + b: torch.Tensor + + +class AddDiffResult(TensorClass): + """Test output TensorClass for add/diff operations.""" + + added: torch.Tensor + substracted: torch.Tensor + + +class OutputTensorClass(TensorClass): + """Test output TensorClass with nested structure.""" + + input: InputTensorClass + result: AddDiffResult + + +class AddDiffModule(TensorClassModuleBase[InputTensorClass, AddDiffResult]): + """Test module that adds and subtracts two tensors.""" + + def forward(self, x: InputTensorClass) -> AddDiffResult: + return AddDiffResult( + added=(x.a + x.b), substracted=(x.a - x.b), batch_size=x.batch_size + ) + + +class TestTensorClassModule(TensorClassModuleBase[InputTensorClass, OutputTensorClass]): + """Test module with nested TensorClass output.""" + + def __init__(self) -> None: + super().__init__() + self.add_diff = AddDiffModule() + + def forward(self, x: InputTensorClass) -> OutputTensorClass: + return OutputTensorClass( + input=x, result=self.add_diff(x), batch_size=x.batch_size + ) + + +class TestTensorClassModuleForward: + """Tests for TensorClassModule forward pass.""" + + def test_forward(self) -> None: + """Test basic forward pass with TensorClass input.""" + module = TestTensorClassModule() + value = InputTensorClass(a=10, b=5, batch_size=[]) + output = module.forward(value) + assert isinstance(output, OutputTensorClass) + assert output.result.added == 15 + assert output.result.substracted == 5 + + def test_td_forward(self) -> None: + """Test forward pass with TensorDict input via wrapper.""" + td_module = TestTensorClassModule().as_td_module() + value = InputTensorClass(a=10, b=5, batch_size=[]) + td_output = td_module(value.to_tensordict()) + assert td_output["result", "added"] == 15 + assert td_output["result", "substracted"] == 5 + + def test_wrapper_keys(self) -> None: + """Test that wrapper correctly extracts in_keys and out_keys.""" + module = TestTensorClassModule() + td_module = module.as_td_module() + assert set(td_module.in_keys) == {"a", "b"} + assert set(td_module.out_keys) == { + ("input", "a"), + ("input", "b"), + ("result", "added"), + ("result", "substracted"), + } + + +@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") +class TestONNXExport: + """Tests for ONNX export functionality.""" + + def test_onnx_export_module(self, tmp_path: pathlib.Path) -> None: + """Test ONNX export of TensorClassModule.""" + tc_module = TestTensorClassModule() + tc_module.eval() + tc_input = InputTensorClass( + a=torch.tensor([10.0], dtype=torch.float), + b=torch.tensor([5.0], dtype=torch.float), + batch_size=[1], + ) + torch_input = tc_input.to_tensordict().to_dict() + + td_module = tc_module.as_td_module().select_out_keys( + ("result", "added"), ("result", "substracted") + ) + output_names = [ + v if isinstance(v, str) else "_".join(v) for v in td_module.out_keys + ] + + onnx_program = torch.onnx.export( + model=td_module, kwargs=torch_input, output_names=output_names, dynamo=True + ) + + path = tmp_path / "file.onnx" + onnx_program.save(str(path)) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession( + path, providers=["CPUExecutionProvider"] + ) + + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + output_names = [output.name for output in ort_session.get_outputs()] + + onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} + + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + onnxruntime_output_dict = dict(zip(output_names, onnxruntime_outputs)) + + tc_outputs = tc_module(tc_input) + + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_added"]), + tc_outputs.result.added, + ) + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_substracted"]), + tc_outputs.result.substracted, + ) + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_non_tensorclass_conversion_error(self) -> None: + """Test that conversion to TensorDictModule fails for non-TensorClass types.""" + + class BadModule(TensorClassModuleBase[torch.Tensor, torch.Tensor]): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 1 + + module = BadModule() + with pytest.raises( + ValueError, + match="Only TensorClassModuleBase implementations with both input and output type as TensorClass", + ): + module.as_td_module() + + def test_batch_size_preservation(self) -> None: + """Test that batch size is correctly preserved through forward pass.""" + module = AddDiffModule() + batch_sizes = [[], [3], [2, 3], [1, 2, 3]] + + for batch_size in batch_sizes: + if batch_size: + input_tc = InputTensorClass( + a=torch.randn(*batch_size), + b=torch.randn(*batch_size), + batch_size=batch_size, + ) + else: + input_tc = InputTensorClass( + a=torch.randn(()), + b=torch.randn(()), + batch_size=batch_size, + ) + output = module(input_tc) + assert output.batch_size == torch.Size(batch_size) + assert output.added.shape == torch.Size(batch_size) + assert output.substracted.shape == torch.Size(batch_size) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensorclass_module.py b/test/test_tensorclass_module.py deleted file mode 100644 index 28accd15e..000000000 --- a/test/test_tensorclass_module.py +++ /dev/null @@ -1,188 +0,0 @@ -import importlib -from pathlib import Path - -import pytest -import torch -from tensordict.nn import TensorClassModuleBase -from tensordict.tensorclass import TensorClass -from torch import Tensor - -_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None - - -class InputTensorClass(TensorClass): - """Test input TensorClass with two tensor fields.""" - - a: Tensor - b: Tensor - - -class AddDiffResult(TensorClass): - """Test output TensorClass for add/diff operations.""" - - added: Tensor - substracted: Tensor - - -class OutputTensorClass(TensorClass): - """Test output TensorClass with nested structure.""" - - input: InputTensorClass - result: AddDiffResult - - -class AddDiffModule(TensorClassModuleBase[InputTensorClass, AddDiffResult]): - """Test module that adds and subtracts two tensors.""" - - def forward(self, x: InputTensorClass) -> AddDiffResult: - return AddDiffResult( - added=(x.a + x.b), substracted=(x.a - x.b), batch_size=x.batch_size - ) - - -class TestTensorClassModule(TensorClassModuleBase[InputTensorClass, OutputTensorClass]): - """Test module with nested TensorClass output.""" - - def __init__(self) -> None: - super().__init__() - self.add_diff = AddDiffModule() - - def forward(self, x: InputTensorClass) -> OutputTensorClass: - return OutputTensorClass( - input=x, result=self.add_diff(x), batch_size=x.batch_size - ) - - -class TestTensorClassModuleForward: - """Tests for TensorClassModule forward pass.""" - - def test_forward(self) -> None: - """Test basic forward pass with TensorClass input.""" - module = TestTensorClassModule() - value = InputTensorClass(a=10, b=5, batch_size=[]) - output = module.forward(value) - assert isinstance(output, OutputTensorClass) - assert output.result.added == 15 - assert output.result.substracted == 5 - - def test_td_forward(self) -> None: - """Test forward pass with TensorDict input via wrapper.""" - td_module = TestTensorClassModule().as_td_module() - value = InputTensorClass(a=10, b=5, batch_size=[]) - td_output = td_module(value.to_tensordict()) - assert td_output["result", "added"] == 15 - assert td_output["result", "substracted"] == 5 - - def test_wrapper_keys(self) -> None: - """Test that wrapper correctly extracts in_keys and out_keys.""" - module = TestTensorClassModule() - td_module = module.as_td_module() - assert set(td_module.in_keys) == {"a", "b"} - assert set(td_module.out_keys) == { - ("input", "a"), - ("input", "b"), - ("result", "added"), - ("result", "substracted"), - } - - -@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") -class TestONNXExport: - """Tests for ONNX export functionality.""" - - def test_onnx_export_module(self, tmp_path: Path) -> None: - """Test ONNX export of TensorClassModule.""" - tc_module = TestTensorClassModule() - tc_module.eval() - tc_input = InputTensorClass( - a=torch.tensor([10.0], dtype=torch.float), - b=torch.tensor([5.0], dtype=torch.float), - batch_size=[1], - ) - torch_input = tc_input.to_tensordict().to_dict() - - td_module = tc_module.as_td_module().select_out_keys( - ("result", "added"), ("result", "substracted") - ) - output_names = [ - v if isinstance(v, str) else "_".join(v) for v in td_module.out_keys - ] - - onnx_program = torch.onnx.export( - model=td_module, kwargs=torch_input, output_names=output_names, dynamo=True - ) - - path = tmp_path / "file.onnx" - onnx_program.save(str(path)) - - import onnxruntime - - ort_session = onnxruntime.InferenceSession( - path, providers=["CPUExecutionProvider"] - ) - - def to_numpy(tensor): - return ( - tensor.detach().cpu().numpy() - if tensor.requires_grad - else tensor.cpu().numpy() - ) - - output_names = [output.name for output in ort_session.get_outputs()] - - onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} - - onnxruntime_outputs = ort_session.run(None, onnxruntime_input) - onnxruntime_output_dict = dict(zip(output_names, onnxruntime_outputs)) - - tc_outputs = tc_module(tc_input) - - torch.testing.assert_close( - torch.as_tensor(onnxruntime_output_dict["result_added"]), - tc_outputs.result.added, - ) - torch.testing.assert_close( - torch.as_tensor(onnxruntime_output_dict["result_substracted"]), - tc_outputs.result.substracted, - ) - - -class TestEdgeCases: - """Tests for edge cases and error handling.""" - - def test_non_tensorclass_conversion_error(self) -> None: - """Test that conversion to TensorDictModule fails for non-TensorClass types.""" - - class BadModule(TensorClassModuleBase[Tensor, Tensor]): - def forward(self, x: Tensor) -> Tensor: - return x + 1 - - module = BadModule() - with pytest.raises( - ValueError, - match="Only TensorClassModuleBase implementations with both input and output type as TensorClass", - ): - module.as_td_module() - - def test_batch_size_preservation(self) -> None: - """Test that batch size is correctly preserved through forward pass.""" - module = AddDiffModule() - batch_sizes = [[], [3], [2, 3], [1, 2, 3]] - - for batch_size in batch_sizes: - if batch_size: - input_tc = InputTensorClass( - a=torch.randn(*batch_size), - b=torch.randn(*batch_size), - batch_size=batch_size, - ) - else: - input_tc = InputTensorClass( - a=torch.randn(()), - b=torch.randn(()), - batch_size=batch_size, - ) - output = module(input_tc) - assert output.batch_size == torch.Size(batch_size) - assert output.added.shape == torch.Size(batch_size) - assert output.substracted.shape == torch.Size(batch_size) From c4e3ecdaa0e421d47b99ef3e11d5b47812f6c6c6 Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:32:41 +0100 Subject: [PATCH 7/8] Add TensorClassModuleBase and TensorClassModuleWrapper to nn reference --- docs/source/reference/nn.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 0f91e4544..21e5ad8c3 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -193,6 +193,8 @@ to build distributions from network outputs and get summary statistics or sample TensorDictModuleBase TensorDictModule + TensorClassModuleBase + TensorClassModuleWrapper ProbabilisticTensorDictModule ProbabilisticTensorDictSequential TensorDictSequential From b3489c9db37bd63d94ea0e86ddd52c6b4a3e1308 Mon Sep 17 00:00:00 2001 From: Antoine de Maleprade Date: Mon, 17 Nov 2025 18:41:27 +0100 Subject: [PATCH 8/8] detailed documentation with example --- docs/source/reference/nn.rst | 70 ++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 21e5ad8c3..179939b01 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -186,6 +186,76 @@ to build distributions from network outputs and get summary statistics or sample device=None, is_shared=False) +Type-Safe TensorClass Modules +------------------------------ + +The :class:`~.TensorClassModuleBase` provides a type-safe way to define modules that work with +:class:`~tensordict.tensorclass.TensorClass` inputs and outputs. This offers compile-time type +checking and improved code clarity compared to working with string-based keys. + +A :class:`~.TensorClassModuleBase` subclass specifies its input and output types through generic +type parameters. The module can be converted to work with :class:`~.TensorDict` objects using the +:meth:`~.TensorClassModuleBase.as_td_module` method, which returns a :class:`~.TensorClassModuleWrapper`: + +.. code-block:: + + >>> import torch + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> from tensordict import TensorDict + >>> + >>> # Define input and output TensorClass types + >>> class InputTC(TensorClass): + ... a: torch.Tensor + ... b: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... sum: torch.Tensor + ... difference: torch.Tensor + ... + >>> # Create a type-safe module + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, x: InputTC) -> OutputTC: + ... return OutputTC( + ... sum=x.a + x.b, + ... difference=x.a - x.b, + ... batch_size=x.batch_size + ... ) + ... + >>> # Use with TensorClass + >>> module = MyModule() + >>> input_tc = InputTC(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0, 4.0]), batch_size=[2]) + >>> output = module(input_tc) + >>> print(output.sum) + tensor([4., 6.]) + >>> print(output.difference) + tensor([-2., -2.]) + >>> + >>> # Convert to TensorDictModule for use in TensorDict workflows + >>> td_module = module.as_td_module() + >>> td = TensorDict({"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, batch_size=[2]) + >>> result = td_module(td) + >>> print(result) + TensorDict( + fields={ + a: Tensor(torch.Size([2]), dtype=torch.float32), + b: Tensor(torch.Size([2]), dtype=torch.float32), + difference: Tensor(torch.Size([2]), dtype=torch.float32), + sum: Tensor(torch.Size([2]), dtype=torch.float32)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + +The type-safe approach offers several benefits: + +* **Type checking**: IDEs and type checkers can verify correct usage at development time +* **Self-documenting**: The input and output structure is clear from the type signature +* **Refactoring**: Renaming fields in TensorClass definitions is caught by type checkers +* **Nested structures**: Support for nested TensorClass types with automatic key extraction + +:class:`~.TensorClassModuleBase` modules can be composed and used in :class:`~.TensorDictSequential` +after conversion via :meth:`~.TensorClassModuleBase.as_td_module`. + .. autosummary:: :toctree: generated/