Skip to content

Commit 1430c2b

Browse files
author
Vincent Moens
committed
[Doc] Better doc for TensorDictModuleBase
ghstack-source-id: 1c5cbcc Pull Request resolved: #1226 (cherry picked from commit b8d3ff9)
1 parent c88acce commit 1430c2b

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tensordict/nn/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,35 @@ class TensorDictModuleBase(nn.Module):
485485
486486
>>> tensordict_out = module.forward(tensordict_in)
487487
488+
Unlike :class:`~tensordict.nn.TensorDictModule`, `TensorDictModuleBase` is typically used via subclassing:
489+
you can wrap any python function in a `TensorDictModuleBase` subclass, as long as the subclass forward reads and
490+
writes tensordict (or related types) instances.
491+
492+
The `in_keys` and `out_keys` should be properly specified. For example, `out_keys` can be dynamically reduced using
493+
:meth:`~tensordict.nn.TensorDictBase.select_out_keys`.
494+
495+
Examples:
496+
>>> from tensordict import TensorDict
497+
>>> from tensordict.nn import TensorDictModuleBase
498+
>>> class Mod(TensorDictModuleBase):
499+
... in_keys = ["a"] # can also be specified during __init__
500+
... out_keys = ["b", "c"]
501+
... def forward(self, tensordict):
502+
... b = tensordict["a"].clone()
503+
... c = b + 1
504+
... return tensordict.replace({"b": b, "c": c})
505+
>>> mod = Mod()
506+
>>> td = mod(TensorDict(a=0))
507+
>>> td["b"]
508+
tensor(0)
509+
>>> td["c"]
510+
tensor(1)
511+
>>> mod.select_out_keys("c")
512+
>>> td = mod(TensorDict(a=0))
513+
>>> td["c"]
514+
tensor(1)
515+
>>> assert "b" not in td
516+
488517
"""
489518

490519
def __new__(cls, *args, **kwargs):

0 commit comments

Comments
 (0)