
tensorboard writer.add_graph fails with HeteroConv due to jit.trace not supporting tuple dictionary keys #10421

@jesseangelis


🐛 Describe the bug

When attempting to use torch.utils.tensorboard.SummaryWriter.add_graph to visualize a heterogeneous GNN built from torch_geometric.nn.HeteroConv layers (the HeteroGNN class below), the call fails. add_graph internally uses torch.jit.trace, which cannot handle dictionaries with tuple keys. The HeteroConv layers require an edge_index_dict whose keys are (str, str, str) tuples identifying the edge types of the heterogeneous graph, so tracing raises a RuntimeError.

Minimal example to reproduce the error:

```python
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear
from torch.utils.tensorboard import SummaryWriter

# Load a heterogeneous graph dataset
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

# Define the HeteroGNN model
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

# Instantiate the model and initialize lazy modules
model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes, num_layers=2)
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

# Attempt to add the graph to TensorBoard, which fails
writer = SummaryWriter()
writer.add_graph(model, (data.x_dict, data.edge_index_dict,))
writer.close()
```

Observed error:

```
File ~<Path>/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1282, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1280 else:
   1281     example_inputs = make_tuple(example_inputs)
   1282     module._c._create_method_from_trace(
   1283         method_name,
   1284         func,
   1285         example_inputs,
   1286         var_lookup_fn,
   1287         strict,
   1288         _force_outplace,
   1289         argument_names,
   1290         _store_inputs,
   1291     )
   1293 check_trace_method = module._c._get_method(method_name)
   1295 # Check the trace against new traces created from user-specified inputs

RuntimeError: Cannot create dict for key type '(str, str, str)', only int, float, complex, Tensor, device and string keys are supported
```
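Per the error message, the tracer only accepts int, float, complex, Tensor, device, and string keys when it converts dict inputs. Below is a minimal pure-Python sketch of a pre-flight check that mirrors this rule before calling add_graph. `unsupported_dict_keys` is a hypothetical helper, not part of PyTorch or PyG. It only inspects top-level dict arguments, and it leaves out Tensor and device keys to avoid a torch dependency, so it would wrongly report those as unsupported.

```python
# Key types the tracer's dict conversion accepts, per the RuntimeError above.
# Tensor and torch.device keys are also accepted by the tracer but are omitted
# here to keep the sketch torch-free; this check would flag them.
SUPPORTED_KEY_TYPES = (str, int, float, complex)


def unsupported_dict_keys(example_inputs):
    """Return (argument_index, key) pairs for dict keys jit.trace would reject.

    Hypothetical pre-flight check: it looks only at top-level dict arguments
    of the example-input tuple that would be passed to writer.add_graph.
    """
    problems = []
    for index, arg in enumerate(example_inputs):
        if isinstance(arg, dict):
            for key in arg:
                if not isinstance(key, SUPPORTED_KEY_TYPES):
                    problems.append((index, key))
    return problems


# Stand-ins for data.x_dict and data.edge_index_dict (values would be tensors).
x_dict = {"paper": "x_paper", "author": "x_author"}
edge_index_dict = {("author", "writes", "paper"): "ei_writes"}

print(unsupported_dict_keys((x_dict, edge_index_dict)))
# [(1, ('author', 'writes', 'paper'))]
```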

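A workaround is to wrap the model in a module whose forward method takes lists of tensors instead of dictionaries. The wrapper stores the dictionary keys as plain Python attributes and rebuilds x_dict and edge_index_dict inside forward before calling the original model. Because the tuple keys never appear in the traced inputs, jit.trace only sees lists of tensors, which it supports.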

```python
class HeteroGNNWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, x_dict_keys, edge_index_dict_keys):
        super().__init__()
        self.model = model
        self.x_dict_keys = x_dict_keys
        self.edge_index_dict_keys = edge_index_dict_keys

    def forward(self, x_list, edge_index_list):
        # Reconstruct the dicts from the lists of tensors and keys
        x_dict = {key: tensor for key, tensor in zip(self.x_dict_keys, x_list)}
        edge_index_dict = {key: tensor for key, tensor in zip(self.edge_index_dict_keys, edge_index_list)}

        # Call the original model's forward method
        return self.model(x_dict, edge_index_dict)

# Prepare the data for the wrapper
x_list = list(data.x_dict.values())
x_dict_keys = list(data.x_dict.keys())
edge_index_list = list(data.edge_index_dict.values())
edge_index_dict_keys = list(data.edge_index_dict.keys())

# Instantiate the wrapper and add the graph to TensorBoard
wrapped_model = HeteroGNNWrapper(model, x_dict_keys, edge_index_dict_keys)
writer = SummaryWriter()
writer.add_graph(wrapped_model, (x_list, edge_index_list,))
writer.close()
```
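Because the error message lists string keys as supported, another possible variant (an assumption, not tested in this issue) is to keep dictionary inputs but encode each (str, str, str) edge type as a single string key, then decode the keys back to tuples inside the wrapper's forward before calling the model. The sketch below covers only the key encoding in pure Python; the helper names and the `"__"` separator are hypothetical choices, and names that would not split back cleanly are rejected.

```python
# Hypothetical separator chosen for this sketch.
SEPARATOR = "__"


def encode_edge_type(edge_type):
    """Turn ('author', 'writes', 'paper') into 'author__writes__paper'."""
    key = SEPARATOR.join(edge_type)
    # Reject names that would not split back exactly (e.g. ones containing
    # "__" or ending in "_"), so decoding is always unambiguous.
    if tuple(key.split(SEPARATOR)) != tuple(edge_type):
        raise ValueError(f"{edge_type!r} cannot be encoded unambiguously")
    return key


def decode_edge_type(key):
    """Invert encode_edge_type, checking that exactly three parts come back."""
    parts = tuple(key.split(SEPARATOR))
    if len(parts) != 3:
        raise ValueError(f"{key!r} does not encode a (src, rel, dst) edge type")
    return parts


def encode_edge_dict(edge_index_dict):
    # String keys are among the key types the tracer accepts.
    return {encode_edge_type(k): v for k, v in edge_index_dict.items()}


def decode_edge_dict(encoded):
    # Restore the tuple keys that HeteroConv expects.
    return {decode_edge_type(k): v for k, v in encoded.items()}


# Stand-in values; in the real model these would be edge_index tensors.
edge_index_dict = {
    ("paper", "cites", "paper"): "ei_cites",
    ("author", "writes", "paper"): "ei_writes",
    ("paper", "rev_writes", "author"): "ei_rev_writes",
}
encoded = encode_edge_dict(edge_index_dict)
print(sorted(encoded))
# ['author__writes__paper', 'paper__cites__paper', 'paper__rev_writes__author']
assert decode_edge_dict(encoded) == edge_index_dict
```

In a wrapper along the lines of HeteroGNNWrapper, forward would then accept (x_dict, encoded_edge_index_dict) and call decode_edge_dict on the second argument before invoking self.model.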

Versions

PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 7530U with Radeon Graphics
CPU family: 25
Model: 80
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
CPU(s) scaling MHz: 72%
CPU max MHz: 4546.0000
CPU min MHz: 400.0000
BogoMIPS: 3992.46

Versions of relevant libraries:
[pip3] numpy==2.3.2
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.8.0
[pip3] torch-geometric==2.6.1
[pip3] triton==3.4.0
[conda] blas 1.0 mkl
[conda] intel-openmp 2022.0.1 h06a4308_3633
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-service 2.4.0 py312h5eee18b_1
[conda] mkl_fft 1.3.8 py312h5eee18b_0
[conda] mkl_random 1.2.4 py312hdb19cb5_0
[conda] numpy 1.26.4 py312hc5e2394_0
[conda] numpy-base 1.26.4 py312h0da6c21_0
[conda] tbb 2021.8.0 hdb19cb5_0
