File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change @@ -485,6 +485,35 @@ class TensorDictModuleBase(nn.Module):
485485
486486 >>> tensordict_out = module.forward(tensordict_in)
487487
488+ Unlike :class:`~tensordict.nn.TensorDictModule`, `TensorDictModuleBase` is typically used via subclassing:
489+ you can wrap any python function in a `TensorDictModuleBase` subclass, as long as the subclass forward reads and
490+ writes tensordict (or related types) instances.
491+
492+ The `in_keys` and `out_keys` should be properly specified. For example, `out_keys` can be dynamically reduced using
493+ :meth:`~tensordict.nn.TensorDictBase.select_out_keys`.
494+
495+ Examples:
496+ >>> from tensordict import TensorDict
497+ >>> from tensordict.nn import TensorDictModuleBase
498+ >>> class Mod(TensorDictModuleBase):
499+ ... in_keys = ["a"] # can also be specified during __init__
500+ ... out_keys = ["b", "c"]
501+ ... def forward(self, tensordict):
502+ ... b = tensordict["a"].clone()
503+ ... c = b + 1
504+ ... return tensordict.replace({"b": b, "c": c})
505+ >>> mod = Mod()
506+ >>> td = mod(TensorDict(a=0))
507+ >>> td["b"]
508+ tensor(0)
509+ >>> td["c"]
510+ tensor(1)
511+ >>> mod.select_out_keys("c")
512+ >>> td = mod(TensorDict(a=0))
513+ >>> td["c"]
514+ tensor(1)
515+ >>> assert "b" not in td
516+
488517 """
489518
490519 def __new__ (cls , * args , ** kwargs ):
You can’t perform that action at this time.
0 commit comments