@@ -58,6 +58,8 @@ class TensorDictSequential(TensorDictModule):
5858 These can be instances of TensorDictModuleBase or any other function that matches this signature.
5959 Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
6060 and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
61+ Regular ``dict`` inputs will be converted to ``OrderedDict`` if necessary.
62+
6163 Keyword Args:
6264 partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
6365 If so, the only module that will be executed are those who can be executed given the keys that
@@ -98,7 +100,8 @@ class TensorDictSequential(TensorDictModule):
98100 >>> module(torch.zeros(3))
99101 (tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
100102
101- TensorDictSequence supports functional, modular and vmap coding:
103+ TensorDictSequence supports functional, modular and vmap coding.
104+
102105 Examples:
103106 >>> import torch
104107 >>> from tensordict import TensorDict
@@ -219,6 +222,12 @@ def __init__(
219222 super ().__init__ (
220223 module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys
221224 )
225+ elif len (modules ) == 1 and isinstance (modules [0 ], dict ):
226+ return self .__init__ (
227+ collections .OrderedDict (modules [0 ]),
228+ partial_tolerant = partial_tolerant ,
229+ selected_out_keys = selected_out_keys ,
230+ )
222231 else :
223232 modules = self ._convert_modules (modules )
224233 in_keys , out_keys = self ._compute_in_and_out_keys (modules )
@@ -521,6 +530,20 @@ def _module_iter(self):
521530 else :
522531 yield from self .module
523532
533+ def _get_module_num_or_key (self , mod : nn .Module ) -> int | str :
534+ if isinstance (self .module , nn .ModuleDict ):
535+ for name , m in self .module .named_children ():
536+ if m is mod :
537+ return name
538+ else :
539+ raise RuntimeError ("module not found." )
540+ else :
541+ for i , m in enumerate (self .module ):
542+ if m is mod :
543+ return i
544+ else :
545+ raise RuntimeError ("module not found." )
546+
524547 @dispatch (auto_batch_size = False )
525548 @_set_skip_existing_None ()
526549 def forward (
@@ -537,7 +560,15 @@ def forward(
537560 tensordict_exec = tensordict
538561 if not len (kwargs ):
539562 for module in self ._module_iter ():
540- tensordict_exec = self ._run_module (module , tensordict_exec , ** kwargs )
563+ try :
564+ tensordict_exec = self ._run_module (
565+ module , tensordict_exec , ** kwargs
566+ )
567+ except Exception as e :
568+ module_num_or_key = self ._get_module_num_or_key (module )
569+ raise RuntimeError (
570+ f"Failed while executing module '{ module_num_or_key } '. Scroll up for more info."
571+ ) from e
541572 else :
542573 raise RuntimeError (
543574 f"TensorDictSequential does not support keyword arguments other than 'tensordict_out' or in_keys: { self .in_keys } . Got { kwargs .keys ()} instead."
0 commit comments