Skip to content

Commit 684dca7

Browse files
author
Vincent Moens
committed
[Feature] Plotting TensorDictSequential graphs
ghstack-source-id: 9f27d6b Pull Request resolved: #1144
1 parent e073cbe commit 684dca7

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tensordict/nn/sequence.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import collections
9+
import contextlib
910
import logging
1011
from copy import deepcopy
1112
from typing import Any, Callable, Iterable, List, OrderedDict, overload
@@ -578,3 +579,64 @@ def __setitem__(
578579

579580
def __delitem__(self, index: int | slice | str) -> None:
580581
self.module.__delitem__(idx=index)
582+
583+
def plot(self, example_input: TensorDictBase | None = None, **kwargs):
584+
import pydot
585+
586+
graph = pydot.Dot(
587+
"my_graph", graph_type="digraph", bgcolor="yellow", splines="curved"
588+
)
589+
graph.set_bgcolor("white")
590+
591+
if example_input is not None:
592+
from torch._subclasses.fake_tensor import FakeTensorMode
593+
594+
fake_mode = FakeTensorMode()
595+
converter = fake_mode.fake_tensor_converter
596+
fake_td = example_input.apply(
597+
lambda x: converter.from_real_tensor(fake_mode, x)
598+
)
599+
else:
600+
fake_td = None
601+
fake_mode = contextlib.nullcontext()
602+
603+
with fake_mode:
604+
iterator = (
605+
enumerate(self._module_iter())
606+
if not isinstance(self.module, nn.ModuleDict)
607+
else self.module.items()
608+
)
609+
for name, module in iterator:
610+
graph.add_node(
611+
pydot.Node(str(name), shape="box")
612+
) # label=str(node.module)))
613+
614+
# Check if in_keys are there already
615+
in_keys = module.in_keys
616+
for in_key in in_keys:
617+
if in_key not in graph.obj_dict["nodes"]:
618+
in_key_node = pydot.Node(
619+
in_key, label=in_key, shape="plaintext"
620+
)
621+
graph.add_node(in_key_node)
622+
in_key_edge = pydot.Edge(
623+
in_key, str(name), color="blue", style="arrow"
624+
)
625+
graph.add_edge(in_key_edge)
626+
627+
if not isinstance(module, TensorDictModule):
628+
fake_td = self._run_module(module, fake_td, **kwargs)
629+
630+
out_keys = module.out_keys
631+
for out_key in out_keys:
632+
if out_key not in graph.obj_dict["nodes"]:
633+
out_key_node = pydot.Node(
634+
out_key, label=out_key, shape="plaintext"
635+
)
636+
graph.add_node(out_key_node)
637+
out_key_edge = pydot.Edge(
638+
str(name), out_key, color="blue", style="arrow"
639+
)
640+
graph.add_edge(out_key_edge)
641+
642+
graph.write_png("/Users/vmoens/Downloads/my_graph.png")

0 commit comments

Comments
 (0)