Description
🐛 Describe the bug
When attempting to use torch.utils.tensorboard.SummaryWriter.add_graph to visualize a heterogeneous GNN model (the HeteroGNN defined below, built from torch_geometric.nn.HeteroConv layers), the operation fails. add_graph internally relies on torch.jit.trace, which cannot handle dictionaries with tuple keys. The HeteroConv layers require an edge_index_dict whose keys are (str, str, str) triples identifying the different edge types of the heterogeneous graph, and this incompatibility causes a RuntimeError during tracing.
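The limitation can be reproduced without PyG at all. Below is a minimal sketch (the DictModel module, the toy tensors, and the use of strict=False are illustrative and not part of the original report; exact behavior may vary across PyTorch versions): torch.jit.trace accepts a string-keyed dictionary of tensors as an example input, but rejects the same dictionary once its keys are tuples.

import torch

class DictModel(torch.nn.Module):
    def forward(self, d):
        # Sum all tensor values of the input dictionary
        return sum(d.values())

toy = DictModel()

# String keys: tracing succeeds (possibly with a TracerWarning about dict inputs)
torch.jit.trace(toy, ({'a': torch.ones(2), 'b': torch.ones(2)},), strict=False)

# Tuple keys: fails with the RuntimeError shown below
torch.jit.trace(toy, ({('paper', 'cites', 'paper'): torch.ones(2)},), strict=False)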
Minimal example to reproduce the error:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear
from torch.utils.tensorboard import SummaryWriter

# Load a heterogeneous graph dataset
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

# Define the HeteroGNN model
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

# Instantiate the model and initialize lazy modules
model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes, num_layers=2)
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

# Attempt to add the graph to TensorBoard, which fails
writer = SummaryWriter()
writer.add_graph(model, (data.x_dict, data.edge_index_dict,))
writer.close()
Observed error:
File ~<Path>/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1282, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1280 else:
1281 example_inputs = make_tuple(example_inputs)
1282 module._c._create_method_from_trace(
1283 method_name,
1284 func,
1285 example_inputs,
1286 var_lookup_fn,
1287 strict,
1288 _force_outplace,
1289 argument_names,
1290 _store_inputs,
1291 )
1293 check_trace_method = module._c._get_method(method_name)
1295 # Check the trace against new traces created from user-specified inputs
RuntimeError: Cannot create dict for key type '(str, str, str)', only int, float, complex, Tensor, device and string keys are supported
A workaround is to create a wrapper class that takes the dictionary inputs as lists of tensors and reconstructs the dictionaries inside the wrapper's forward method. This allows jit.trace to work, since it can handle lists of tensors.
class HeteroGNNWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, x_dict_keys, edge_index_dict_keys):
        super().__init__()
        self.model = model
        self.x_dict_keys = x_dict_keys
        self.edge_index_dict_keys = edge_index_dict_keys

    def forward(self, x_list, edge_index_list):
        # Reconstruct the dicts from the lists of tensors and keys
        x_dict = {key: tensor for key, tensor in zip(self.x_dict_keys, x_list)}
        edge_index_dict = {key: tensor for key, tensor in zip(self.edge_index_dict_keys, edge_index_list)}
        # Call the original model's forward method
        return self.model(x_dict, edge_index_dict)

# Prepare the data for the wrapper
x_list = list(data.x_dict.values())
x_dict_keys = list(data.x_dict.keys())
edge_index_list = list(data.edge_index_dict.values())
edge_index_dict_keys = list(data.edge_index_dict.keys())

# Instantiate the wrapper and add the graph to TensorBoard
wrapped_model = HeteroGNNWrapper(model, x_dict_keys, edge_index_dict_keys)
writer = SummaryWriter()
writer.add_graph(wrapped_model, (x_list, edge_index_list,))
writer.close()
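As a quick sanity check (not part of the original report, and assuming the model and data prepared in the snippets above), the wrapped model can also be traced directly, which suggests the failure lies in torch.jit.trace rather than in TensorBoard itself:

# Illustrative verification: trace the wrapper directly (strict=False relaxes
# TorchScript's container strictness; TracerWarnings may still be emitted)
traced = torch.jit.trace(wrapped_model, (x_list, edge_index_list), strict=False)
print(traced.graph)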
Versions
PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 7530U with Radeon Graphics
CPU family: 25
Model: 80
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
CPU(s) scaling MHz: 72%
CPU max MHz: 4546.0000
CPU min MHz: 400.0000
BogoMIPS: 3992.46
Versions of relevant libraries:
[pip3] numpy==2.3.2
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.8.0
[pip3] torch-geometric==2.6.1
[pip3] triton==3.4.0
[conda] blas 1.0 mkl
[conda] intel-openmp 2022.0.1 h06a4308_3633
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-service 2.4.0 py312h5eee18b_1
[conda] mkl_fft 1.3.8 py312h5eee18b_0
[conda] mkl_random 1.2.4 py312hdb19cb5_0
[conda] numpy 1.26.4 py312hc5e2394_0
[conda] numpy-base 1.26.4 py312h0da6c21_0
[conda] tbb 2021.8.0 hdb19cb5_0