|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | import collections |
| 9 | +import contextlib |
9 | 10 | import logging |
10 | 11 | from copy import deepcopy |
11 | 12 | from typing import Any, Callable, Iterable, List, OrderedDict, overload |
@@ -578,3 +579,64 @@ def __setitem__( |
578 | 579 |
|
579 | 580 | def __delitem__(self, index: int | slice | str) -> None: |
580 | 581 | self.module.__delitem__(idx=index) |
| 582 | + |
| 583 | + def plot(self, example_input: TensorDictBase | None = None, **kwargs): |
| 584 | + import pydot |
| 585 | + |
| 586 | + graph = pydot.Dot( |
| 587 | + "my_graph", graph_type="digraph", bgcolor="yellow", splines="curved" |
| 588 | + ) |
| 589 | + graph.set_bgcolor("white") |
| 590 | + |
| 591 | + if example_input is not None: |
| 592 | + from torch._subclasses.fake_tensor import FakeTensorMode |
| 593 | + |
| 594 | + fake_mode = FakeTensorMode() |
| 595 | + converter = fake_mode.fake_tensor_converter |
| 596 | + fake_td = example_input.apply( |
| 597 | + lambda x: converter.from_real_tensor(fake_mode, x) |
| 598 | + ) |
| 599 | + else: |
| 600 | + fake_td = None |
| 601 | + fake_mode = contextlib.nullcontext() |
| 602 | + |
| 603 | + with fake_mode: |
| 604 | + iterator = ( |
| 605 | + enumerate(self._module_iter()) |
| 606 | + if not isinstance(self.module, nn.ModuleDict) |
| 607 | + else self.module.items() |
| 608 | + ) |
| 609 | + for name, module in iterator: |
| 610 | + graph.add_node( |
| 611 | + pydot.Node(str(name), shape="box") |
| 612 | + ) # label=str(node.module))) |
| 613 | + |
| 614 | + # Check if in_keys are there already |
| 615 | + in_keys = module.in_keys |
| 616 | + for in_key in in_keys: |
| 617 | + if in_key not in graph.obj_dict["nodes"]: |
| 618 | + in_key_node = pydot.Node( |
| 619 | + in_key, label=in_key, shape="plaintext" |
| 620 | + ) |
| 621 | + graph.add_node(in_key_node) |
| 622 | + in_key_edge = pydot.Edge( |
| 623 | + in_key, str(name), color="blue", style="arrow" |
| 624 | + ) |
| 625 | + graph.add_edge(in_key_edge) |
| 626 | + |
| 627 | + if not isinstance(module, TensorDictModule): |
| 628 | + fake_td = self._run_module(module, fake_td, **kwargs) |
| 629 | + |
| 630 | + out_keys = module.out_keys |
| 631 | + for out_key in out_keys: |
| 632 | + if out_key not in graph.obj_dict["nodes"]: |
| 633 | + out_key_node = pydot.Node( |
| 634 | + out_key, label=out_key, shape="plaintext" |
| 635 | + ) |
| 636 | + graph.add_node(out_key_node) |
| 637 | + out_key_edge = pydot.Edge( |
| 638 | + str(name), out_key, color="blue", style="arrow" |
| 639 | + ) |
| 640 | + graph.add_edge(out_key_edge) |
| 641 | + |
| 642 | + graph.write_png("/Users/vmoens/Downloads/my_graph.png") |
0 commit comments