77
88import re
99import warnings
10+ from collections .abc import MutableSequence
1011
1112from textwrap import indent
12- from typing import Any , Dict , List , Optional , overload , OrderedDict
13+ from typing import Any , Dict , List , Optional , OrderedDict , overload
1314
1415import torch
1516
@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621622 log(p(z | x, y))
622623
623624 Args:
624- *modules (sequence of TensorDictModules ): An ordered sequence of
625- :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625+ *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule ): An ordered sequence of
626+ :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626627 to be run sequentially.
628+ The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
629+ Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
630+ and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627631
628632 Keyword Args:
629633 partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -794,14 +798,13 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
794798 @overload
795799 def __init__ (
796800 self ,
797- modules : OrderedDict ,
801+ modules : OrderedDict [ str , TensorDictModuleBase | ProbabilisticTensorDictModule ] ,
798802 partial_tolerant : bool = False ,
799803 return_composite : bool | None = None ,
800804 aggregate_probabilities : bool | None = None ,
801805 include_sum : bool | None = None ,
802806 inplace : bool | None = None ,
803- ) -> None :
804- ...
807+ ) -> None : ...
805808
806809 @overload
807810 def __init__ (
@@ -812,8 +815,7 @@ def __init__(
812815 aggregate_probabilities : bool | None = None ,
813816 include_sum : bool | None = None ,
814817 inplace : bool | None = None ,
815- ) -> None :
816- ...
818+ ) -> None : ...
817819
818820 def __init__ (
819821 self ,
@@ -829,7 +831,14 @@ def __init__(
829831 "ProbabilisticTensorDictSequential must consist of zero or more "
830832 "TensorDictModules followed by a ProbabilisticTensorDictModule"
831833 )
832- if not return_composite and not isinstance (
834+ self ._ordered_dict = False
835+ if len (modules ) == 1 and isinstance (modules [0 ], (OrderedDict , MutableSequence )):
836+ if isinstance (modules [0 ], OrderedDict ):
837+ modules_list = list (modules [0 ].values ())
838+ self ._ordered_dict = True
839+ else :
840+ modules = modules_list = list (modules [0 ])
841+ elif not return_composite and not isinstance (
833842 modules [- 1 ],
834843 (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential ),
835844 ):
@@ -838,13 +847,22 @@ def __init__(
838847 "an instance of ProbabilisticTensorDictModule or another "
839848 "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
840849 )
850+ else :
851+ modules_list = list (modules )
852+
841853 # if the modules not including the final probabilistic module return the sampled
842854 # key we won't be sampling it again, in that case
843855 # ProbabilisticTensorDictSequential is presumably used to return the
844856 # distribution using `get_dist` or to sample log_probabilities
845- _ , out_keys = self ._compute_in_and_out_keys (modules [:- 1 ])
846- self ._requires_sample = modules [- 1 ].out_keys [0 ] not in set (out_keys )
847- self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
857+ _ , out_keys = self ._compute_in_and_out_keys (modules_list [:- 1 ])
858+ self ._requires_sample = modules_list [- 1 ].out_keys [0 ] not in set (out_keys )
859+ if self ._ordered_dict :
860+ self .__dict__ ["_det_part" ] = TensorDictSequential (
861+ OrderedDict (list (modules [0 ].items ())[:- 1 ])
862+ )
863+ else :
864+ self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
865+
848866 super ().__init__ (* modules , partial_tolerant = partial_tolerant )
849867 self .return_composite = return_composite
850868 self .aggregate_probabilities = aggregate_probabilities
@@ -885,7 +903,7 @@ def get_dist_params(
885903 tds = self .det_part
886904 type = interaction_type ()
887905 if type is None :
888- for m in reversed (self .module ):
906+ for m in reversed (list ( self ._module_iter ()) ):
889907 if hasattr (m , "default_interaction_type" ):
890908 type = m .default_interaction_type
891909 break
@@ -897,7 +915,7 @@ def get_dist_params(
897915 @property
898916 def num_samples (self ):
899917 num_samples = ()
900- for tdm in self .module :
918+ for tdm in self ._module_iter () :
901919 if isinstance (
902920 tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
903921 ):
@@ -941,7 +959,7 @@ def get_dist(
941959
942960 td_copy = tensordict .copy ()
943961 dists = {}
944- for i , tdm in enumerate (self .module ):
962+ for i , tdm in enumerate (self ._module_iter () ):
945963 if isinstance (
946964 tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
947965 ):
@@ -981,12 +999,21 @@ def default_interaction_type(self):
981999 encountered is returned. If no such value is found, a default `interaction_type()` is returned.
9821000
9831001 """
984- for m in reversed (self .module ):
1002+ for m in reversed (list ( self ._module_iter ()) ):
9851003 interaction = getattr (m , "default_interaction_type" , None )
9861004 if interaction is not None :
9871005 return interaction
9881006 return interaction_type ()
9891007
1008+ @property
1009+ def _last_module (self ):
1010+ if not self ._ordered_dict :
1011+ return self .module [- 1 ]
1012+ mod = None
1013+ for mod in self ._module_iter (): # noqa: B007
1014+ continue
1015+ return mod
1016+
9901017 def log_prob (
9911018 self ,
9921019 tensordict ,
@@ -1103,7 +1130,7 @@ def log_prob(
11031130 include_sum = include_sum ,
11041131 ** kwargs ,
11051132 )
1106- last_module : ProbabilisticTensorDictModule = self .module [ - 1 ]
1133+ last_module : ProbabilisticTensorDictModule = self ._last_module
11071134 out = last_module .log_prob (tensordict_inp , dist = dist , ** kwargs )
11081135 if is_tensor_collection (out ):
11091136 if tensordict_out is not None :
@@ -1162,7 +1189,7 @@ def forward(
11621189 else :
11631190 tensordict_exec = tensordict
11641191 if self .return_composite :
1165- for m in self .module :
1192+ for m in self ._module_iter () :
11661193 if isinstance (
11671194 m , (ProbabilisticTensorDictModule , ProbabilisticTensorDictModule )
11681195 ):
@@ -1173,7 +1200,7 @@ def forward(
11731200 tensordict_exec = m (tensordict_exec , ** kwargs )
11741201 else :
11751202 tensordict_exec = self .get_dist_params (tensordict_exec , ** kwargs )
1176- tensordict_exec = self .module [ - 1 ] (
1203+ tensordict_exec = self ._last_module (
11771204 tensordict_exec , _requires_sample = self ._requires_sample
11781205 )
11791206 if tensordict_out is not None :
0 commit comments