Skip to content

Commit d2d30f6

Browse files
author
Vincent Moens
committed
[BE] Better errors for TensorDictSequential
ghstack-source-id: 2f2098d Pull Request resolved: #1227 (cherry picked from commit 28fbea1)
1 parent 1430c2b commit d2d30f6

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

tensordict/nn/probabilistic.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,14 +1299,21 @@ def forward(
12991299
tensordict_exec = tensordict
13001300
if self.return_composite:
13011301
for m in self._module_iter():
1302-
if isinstance(
1303-
m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
1304-
):
1305-
tensordict_exec = m(
1306-
tensordict_exec, _requires_sample=self._requires_sample
1307-
)
1308-
else:
1309-
tensordict_exec = m(tensordict_exec, **kwargs)
1302+
try:
1303+
if isinstance(
1304+
m,
1305+
(ProbabilisticTensorDictModule, ProbabilisticTensorDictModule),
1306+
):
1307+
tensordict_exec = m(
1308+
tensordict_exec, _requires_sample=self._requires_sample
1309+
)
1310+
else:
1311+
tensordict_exec = m(tensordict_exec, **kwargs)
1312+
except Exception as e:
1313+
module_num_or_key = self._get_module_num_or_key(m)
1314+
raise RuntimeError(
1315+
f"Failed while executing module '{module_num_or_key}'. Scroll up for more info."
1316+
) from e
13101317
else:
13111318
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
13121319
tensordict_exec = self._last_module(

tensordict/nn/sequence.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class TensorDictSequential(TensorDictModule):
5858
These can be instances of TensorDictModuleBase or any other function that matches this signature.
5959
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
6060
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
61+
Regular ``dict`` inputs will be converted to ``OrderedDict`` if necessary.
62+
6163
Keyword Args:
6264
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
6365
If so, the only module that will be executed are those who can be executed given the keys that
@@ -98,7 +100,8 @@ class TensorDictSequential(TensorDictModule):
98100
>>> module(torch.zeros(3))
99101
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
100102
101-
TensorDictSequence supports functional, modular and vmap coding:
103+
TensorDictSequence supports functional, modular and vmap coding.
104+
102105
Examples:
103106
>>> import torch
104107
>>> from tensordict import TensorDict
@@ -219,6 +222,12 @@ def __init__(
219222
super().__init__(
220223
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
221224
)
225+
elif len(modules) == 1 and isinstance(modules[0], dict):
226+
return self.__init__(
227+
collections.OrderedDict(modules[0]),
228+
partial_tolerant=partial_tolerant,
229+
selected_out_keys=selected_out_keys,
230+
)
222231
else:
223232
modules = self._convert_modules(modules)
224233
in_keys, out_keys = self._compute_in_and_out_keys(modules)
@@ -521,6 +530,20 @@ def _module_iter(self):
521530
else:
522531
yield from self.module
523532

533+
def _get_module_num_or_key(self, mod: nn.Module) -> int | str:
534+
if isinstance(self.module, nn.ModuleDict):
535+
for name, m in self.module.named_children():
536+
if m is mod:
537+
return name
538+
else:
539+
raise RuntimeError("module not found.")
540+
else:
541+
for i, m in enumerate(self.module):
542+
if m is mod:
543+
return i
544+
else:
545+
raise RuntimeError("module not found.")
546+
524547
@dispatch(auto_batch_size=False)
525548
@_set_skip_existing_None()
526549
def forward(
@@ -537,7 +560,15 @@ def forward(
537560
tensordict_exec = tensordict
538561
if not len(kwargs):
539562
for module in self._module_iter():
540-
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
563+
try:
564+
tensordict_exec = self._run_module(
565+
module, tensordict_exec, **kwargs
566+
)
567+
except Exception as e:
568+
module_num_or_key = self._get_module_num_or_key(module)
569+
raise RuntimeError(
570+
f"Failed while executing module '{module_num_or_key}'. Scroll up for more info."
571+
) from e
541572
else:
542573
raise RuntimeError(
543574
f"TensorDictSequential does not support keyword arguments other than 'tensordict_out' or in_keys: {self.in_keys}. Got {kwargs.keys()} instead."

0 commit comments

Comments
 (0)