diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 0f91e4544..179939b01 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -186,6 +186,76 @@ to build distributions from network outputs and get summary statistics or sample device=None, is_shared=False) +Type-Safe TensorClass Modules +------------------------------ + +The :class:`~.TensorClassModuleBase` provides a type-safe way to define modules that work with +:class:`~tensordict.tensorclass.TensorClass` inputs and outputs. This offers compile-time type +checking and improved code clarity compared to working with string-based keys. + +A :class:`~.TensorClassModuleBase` subclass specifies its input and output types through generic +type parameters. The module can be converted to work with :class:`~.TensorDict` objects using the +:meth:`~.TensorClassModuleBase.as_td_module` method, which returns a :class:`~.TensorClassModuleWrapper`: + +.. code-block:: + + >>> import torch + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> from tensordict import TensorDict + >>> + >>> # Define input and output TensorClass types + >>> class InputTC(TensorClass): + ... a: torch.Tensor + ... b: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... sum: torch.Tensor + ... difference: torch.Tensor + ... + >>> # Create a type-safe module + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, x: InputTC) -> OutputTC: + ... return OutputTC( + ... sum=x.a + x.b, + ... difference=x.a - x.b, + ... batch_size=x.batch_size + ... ) + ... + >>> # Use with TensorClass + >>> module = MyModule() + >>> input_tc = InputTC(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0, 4.0]), batch_size=[2]) + >>> output = module(input_tc) + >>> print(output.sum) + tensor([4., 6.]) + >>> print(output.difference) + tensor([-2., -2.]) + >>> + >>> # Convert to TensorDictModule for use in TensorDict workflows + >>> td_module = module.as_td_module() + >>> td = TensorDict({"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, batch_size=[2]) + >>> result = td_module(td) + >>> print(result) + TensorDict( + fields={ + a: Tensor(torch.Size([2]), dtype=torch.float32), + b: Tensor(torch.Size([2]), dtype=torch.float32), + difference: Tensor(torch.Size([2]), dtype=torch.float32), + sum: Tensor(torch.Size([2]), dtype=torch.float32)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + +The type-safe approach offers several benefits: + +* **Type checking**: IDEs and type checkers can verify correct usage at development time +* **Self-documenting**: The input and output structure is clear from the type signature +* **Refactoring**: Renaming fields in TensorClass definitions is caught by type checkers +* **Nested structures**: Support for nested TensorClass types with automatic key extraction + +:class:`~.TensorClassModuleBase` modules can be composed and used in :class:`~.TensorDictSequential` +after conversion via :meth:`~.TensorClassModuleBase.as_td_module`. + .. autosummary:: :toctree: generated/ @@ -193,6 +263,8 @@ to build distributions from network outputs and get summary statistics or sample TensorDictModuleBase TensorDictModule + TensorClassModuleBase + TensorClassModuleWrapper ProbabilisticTensorDictModule ProbabilisticTensorDictSequential TensorDictSequential diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 162cfb26d..91478555f 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -73,7 +73,12 @@ unravel_key_list, ) from tensordict._pytree import * -from tensordict.nn import as_tensordict_module, TensorDictParams +from tensordict.nn import ( + as_tensordict_module, + TensorClassModuleBase, + TensorClassModuleWrapper, + TensorDictParams, +) try: from tensordict._version import __version__ # @manual=//pytorch/tensordict:version @@ -149,6 +154,8 @@ "NonTensorStack", # NN imports "as_tensordict_module", + "TensorClassModuleBase", + "TensorClassModuleWrapper", "TensorDictParams", # Version "__version__", diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 6e3e68fdd..08bf63616 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -34,6 +34,10 @@ set_interaction_type, ) from tensordict.nn.sequence import TensorDictSequential +from tensordict.nn.tensorclass_module import ( + TensorClassModuleBase, + TensorClassModuleWrapper, +) from tensordict.nn.utils import ( add_custom_mapping, biased_softplus, @@ -57,6 +61,8 @@ "TensorDictSequential", "EnsembleModule", "CudaGraphModule", + "TensorClassModuleBase", + "TensorClassModuleWrapper", # Probabilistic modules "ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential", diff --git a/tensordict/nn/tensorclass_module.py b/tensordict/nn/tensorclass_module.py new file mode 100644 index 000000000..87f3a35ce --- /dev/null +++ b/tensordict/nn/tensorclass_module.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import Field +from typing import Any, cast, Generic, get_args, get_origin, TypeVar + +from tensordict._td import TensorDict +from tensordict.nn.common import dispatch, TensorDictModuleBase +from tensordict.tensorclass import TensorClass +from torch import nn, Tensor + +__all__ = ["TensorClassModuleBase", "TensorClassModuleWrapper"] + + +def _tensor_class_keys(tensorclass_type: type[TensorClass]) -> list[tuple[str, ...]]: + """Extract all keys from a TensorClass type, including nested keys. + + Args: + tensorclass_type (type[TensorClass]): The TensorClass type to extract keys from. + + Returns: + list[tuple[str, ...]]: A list of key tuples representing all fields in the TensorClass. + + """ + fields = cast("Iterable[Field[Any]]", tensorclass_type.fields()) + keys: list[tuple[str, ...]] = [] + for field in fields: + key = field.name + if issubclass(field.type, TensorClass): + subkeys = _tensor_class_keys(cast(type[TensorClass], field.type)) + for subkey in subkeys: + keys.append((key,) + subkey) + else: + keys.append((key,)) + return keys + + +InputTensorClass = TypeVar("InputTensorClass", bound=TensorClass) +OutputTensorClass = TypeVar("OutputTensorClass", bound=TensorClass) + + +class TensorClassModuleWrapper(TensorDictModuleBase): + """Wrapper class for TensorClassModuleBase objects. + + This wrapper allows TensorClassModuleBase instances to be used in TensorDict-based + workflows by handling the conversion between TensorDict and TensorClass representations. + When called with a TensorDict, the wrapper converts it to a TensorClass, passes it through + the wrapped module, and converts the output back to a TensorDict. + + Args: + module (TensorClassModuleBase): The TensorClassModuleBase instance to wrap. + + Examples: + >>> from tensordict import TensorDict + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... x: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... y: torch.Tensor + ... + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, input: InputTC) -> OutputTC: + ... return OutputTC(y=input.x + 1, batch_size=input.batch_size) + ... + >>> module = MyModule() + >>> td_module = module.as_td_module() + >>> td = TensorDict({"x": torch.zeros(3)}, batch_size=[3]) + >>> result = td_module(td) + >>> assert "y" in result + + """ + + def __init__( + self, module: TensorClassModuleBase[InputTensorClass, OutputTensorClass] + ) -> None: + super().__init__() + self.tc_module = module + self.in_keys = _tensor_class_keys(cast(type[TensorClass], module.input_type)) + self.out_keys = _tensor_class_keys(cast(type[TensorClass], module.output_type)) + + @dispatch(auto_batch_size=False) + def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict: + """Forward pass converting TensorDict to TensorClass and back. + + Args: + tensordict (TensorDict): Input tensordict. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + TensorDict: Output tensordict. + + """ + return self.tc_module( + self.tc_module.input_type.from_tensordict(tensordict) + ).to_tensordict() + + +InputClass = TypeVar("InputClass", bound=(TensorClass | Tensor)) +OutputClass = TypeVar("OutputClass", bound=(TensorClass | Tensor)) + + +class TensorClassModuleBase(Generic[InputClass, OutputClass], ABC, nn.Module): + """A TensorClassModuleBase is a base class for modules that operate on TensorClass instances. + + TensorClassModuleBase subclasses provide a type-safe way to define modules that work with TensorClass + inputs and outputs. The class automatically extracts input and output type information from the + generic type parameters. + + The module can be converted to a TensorDictModule using the :meth:`as_td_module` + method, allowing it to be used in TensorDict-based workflows. + + Type Parameters: + InputClass: The input type, must be a TensorClass or Tensor. + OutputClass: The output type, must be a TensorClass or Tensor. + + Attributes: + input_type (type[InputClass]): The input type class. + output_type (type[OutputClass]): The output type class. + + Examples: + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... a: torch.Tensor + ... b: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... result: torch.Tensor + ... + >>> class AddModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, x: InputTC) -> OutputTC: + ... return OutputTC( + ... result=x.a + x.b, + ... batch_size=x.batch_size + ... ) + ... + >>> module = AddModule() + >>> input_tc = InputTC(a=torch.tensor([1.0]), b=torch.tensor([2.0]), batch_size=[1]) + >>> output = module(input_tc) + >>> assert output.result == torch.tensor([3.0]) + + """ + + input_type: type[InputClass] + output_type: type[OutputClass] + + def __init_subclass__(cls) -> None: + """Initialize subclass by extracting type information from generic parameters.""" + super().__init_subclass__() + for base in cls.__orig_bases__: # type:ignore[attr-defined] + origin = get_origin(base) + if origin is TensorClassModuleBase: + generic_args = get_args(base) + if generic_args: + cls.input_type, cls.output_type = generic_args + else: + raise ValueError( + "Generic input/output types not set in TensorClassModuleBase" + ) + + @abstractmethod + def forward(self, x: InputClass) -> OutputClass: + """Forward pass of the module. + + Args: + x (InputClass): Input instance. + + Returns: + OutputClass: Output instance. + + """ + ... + + def __call__(self, x: InputClass) -> OutputClass: + """Call the module's forward method. + + Args: + x (InputClass): Input instance. + + Returns: + OutputClass: Output instance. + + """ + return cast("OutputClass", super().__call__(x)) + + def as_td_module(self) -> TensorClassModuleWrapper: + """Convert this module to a TensorDictModule. + + This method wraps the TensorClassModuleBase in a TensorClassModuleWrapper, + allowing it to be used with TensorDict inputs and outputs. + + Returns: + TensorClassModuleWrapper: A wrapper that converts between TensorDict + and TensorClass representations. + + Raises: + ValueError: If either input_type or output_type is not a TensorClass. + + Examples: + >>> from tensordict import TensorDict + >>> from tensordict.tensorclass import TensorClass + >>> from tensordict.nn import TensorClassModuleBase + >>> import torch + >>> + >>> class InputTC(TensorClass): + ... x: torch.Tensor + ... + >>> class OutputTC(TensorClass): + ... y: torch.Tensor + ... + >>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]): + ... def forward(self, input: InputTC) -> OutputTC: + ... return OutputTC(y=input.x * 2, batch_size=input.batch_size) + ... + >>> module = MyModule() + >>> td_module = module.as_td_module() + >>> td = TensorDict({"x": torch.ones(3)}, batch_size=[3]) + >>> result = td_module(td) + >>> assert (result["y"] == 2).all() + + """ + if not ( + issubclass(self.input_type, TensorClass) + and issubclass(self.output_type, TensorClass) + ): + raise ValueError( + "Only TensorClassModuleBase implementations with both input and " + "output type as TensorClass can be converted to TensorDictModule" + ) + return TensorClassModuleWrapper(self) # type:ignore[arg-type,type-var] diff --git a/test/test_nn.py b/test/test_nn.py index 253514324..f8099d5b6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7,7 +7,9 @@ import contextlib import copy import functools +import importlib import os +import pathlib import pickle import sys import unittest @@ -32,6 +34,7 @@ dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, + TensorClassModuleBase, TensorDictModuleBase, TensorDictParams, TensorDictSequential, @@ -81,6 +84,7 @@ except ImportError: from tensordict.utils import Buffer +_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None IS_FB = os.getenv("PYTORCH_TEST_FBCODE") @@ -4106,6 +4110,184 @@ def func(c): assert func(TensorDict(c=0))["d"] == 1 +class InputTensorClass(TensorClass): + """Test input TensorClass with two tensor fields.""" + + a: torch.Tensor + b: torch.Tensor + + +class AddDiffResult(TensorClass): + """Test output TensorClass for add/diff operations.""" + + added: torch.Tensor + substracted: torch.Tensor + + +class OutputTensorClass(TensorClass): + """Test output TensorClass with nested structure.""" + + input: InputTensorClass + result: AddDiffResult + + +class AddDiffModule(TensorClassModuleBase[InputTensorClass, AddDiffResult]): + """Test module that adds and subtracts two tensors.""" + + def forward(self, x: InputTensorClass) -> AddDiffResult: + return AddDiffResult( + added=(x.a + x.b), substracted=(x.a - x.b), batch_size=x.batch_size + ) + + +class TestTensorClassModule(TensorClassModuleBase[InputTensorClass, OutputTensorClass]): + """Test module with nested TensorClass output.""" + + def __init__(self) -> None: + super().__init__() + self.add_diff = AddDiffModule() + + def forward(self, x: InputTensorClass) -> OutputTensorClass: + return OutputTensorClass( + input=x, result=self.add_diff(x), batch_size=x.batch_size + ) + + +class TestTensorClassModuleForward: + """Tests for TensorClassModule forward pass.""" + + def test_forward(self) -> None: + """Test basic forward pass with TensorClass input.""" + module = TestTensorClassModule() + value = InputTensorClass(a=10, b=5, batch_size=[]) + output = module.forward(value) + assert isinstance(output, OutputTensorClass) + assert output.result.added == 15 + assert output.result.substracted == 5 + + def test_td_forward(self) -> None: + """Test forward pass with TensorDict input via wrapper.""" + td_module = TestTensorClassModule().as_td_module() + value = InputTensorClass(a=10, b=5, batch_size=[]) + td_output = td_module(value.to_tensordict()) + assert td_output["result", "added"] == 15 + assert td_output["result", "substracted"] == 5 + + def test_wrapper_keys(self) -> None: + """Test that wrapper correctly extracts in_keys and out_keys.""" + module = TestTensorClassModule() + td_module = module.as_td_module() + assert set(td_module.in_keys) == {"a", "b"} + assert set(td_module.out_keys) == { + ("input", "a"), + ("input", "b"), + ("result", "added"), + ("result", "substracted"), + } + + +@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") +class TestONNXExport: + """Tests for ONNX export functionality.""" + + def test_onnx_export_module(self, tmp_path: pathlib.Path) -> None: + """Test ONNX export of TensorClassModule.""" + tc_module = TestTensorClassModule() + tc_module.eval() + tc_input = InputTensorClass( + a=torch.tensor([10.0], dtype=torch.float), + b=torch.tensor([5.0], dtype=torch.float), + batch_size=[1], + ) + torch_input = tc_input.to_tensordict().to_dict() + + td_module = tc_module.as_td_module().select_out_keys( + ("result", "added"), ("result", "substracted") + ) + output_names = [ + v if isinstance(v, str) else "_".join(v) for v in td_module.out_keys + ] + + onnx_program = torch.onnx.export( + model=td_module, kwargs=torch_input, output_names=output_names, dynamo=True + ) + + path = tmp_path / "file.onnx" + onnx_program.save(str(path)) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession( + path, providers=["CPUExecutionProvider"] + ) + + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + output_names = [output.name for output in ort_session.get_outputs()] + + onnxruntime_input = {k: to_numpy(v) for k, v in torch_input.items()} + + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + onnxruntime_output_dict = dict(zip(output_names, onnxruntime_outputs)) + + tc_outputs = tc_module(tc_input) + + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_added"]), + tc_outputs.result.added, + ) + torch.testing.assert_close( + torch.as_tensor(onnxruntime_output_dict["result_substracted"]), + tc_outputs.result.substracted, + ) + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_non_tensorclass_conversion_error(self) -> None: + """Test that conversion to TensorDictModule fails for non-TensorClass types.""" + + class BadModule(TensorClassModuleBase[torch.Tensor, torch.Tensor]): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 1 + + module = BadModule() + with pytest.raises( + ValueError, + match="Only TensorClassModuleBase implementations with both input and output type as TensorClass", + ): + module.as_td_module() + + def test_batch_size_preservation(self) -> None: + """Test that batch size is correctly preserved through forward pass.""" + module = AddDiffModule() + batch_sizes = [[], [3], [2, 3], [1, 2, 3]] + + for batch_size in batch_sizes: + if batch_size: + input_tc = InputTensorClass( + a=torch.randn(*batch_size), + b=torch.randn(*batch_size), + batch_size=batch_size, + ) + else: + input_tc = InputTensorClass( + a=torch.randn(()), + b=torch.randn(()), + batch_size=batch_size, + ) + output = module(input_tc) + assert output.batch_size == torch.Size(batch_size) + assert output.added.shape == torch.Size(batch_size) + assert output.substracted.shape == torch.Size(batch_size) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)