-
-
Notifications
You must be signed in to change notification settings - Fork 21
Add GNN-Based Predictor with DAG Preprocessing #430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 34 commits
12fad57
be734dd
61c6824
75875ff
081651c
b82dc01
5ebd202
857cd6f
6081f6b
7c54da6
bb4da24
10bb52c
06be0d6
ce990e3
96ca75b
a64a082
f8c99b5
e4e2742
5784ff7
7e17379
082de05
5ed00a9
e59a941
3a9f16c
c43ee01
92eda99
dc1aa55
6809ccb
8c77598
dc0a824
2419952
5335241
96096a0
4613012
1c728e2
c31cb46
13cf0f4
713343f
17c6575
169b00e
2248081
8f90b12
74ec34b
f99e17b
57b1a29
96232f0
61965d8
93f5414
95f5359
4fb7112
5ea1720
312e6ea
156b7e6
77c9f5c
1ff39e1
5a2d583
820895c
8b0b51f
de738df
794896e
29761dc
3716688
65b409b
ef8e567
a734050
a43dd5a
07e937e
1e29f6b
6dc9d28
f664616
f97e164
2e9cc1e
1a8b1a4
595469e
15d4814
ca5339e
ddb6f50
a2cca5c
af38ff4
fc48550
92897cb
f27358f
3af112a
88256c9
2cf80d1
be5210b
2e28a14
1e4a279
e9f698e
ea825ff
e28dc60
620a1d7
c258fe0
29cf56a
f412af8
3c077b1
8b79177
2a9457d
77cca31
1acc0cb
64a94ff
c4cc835
7acbec7
6ebded4
6207f97
5828e7f
a09b7dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -44,6 +44,8 @@ dependencies = [ | |||
| "numpy>=1.24; python_version >= '3.11'", | ||||
| "numpy>=1.22", | ||||
| "numpy>=1.22,<2; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict numpy v2 for macOS x86 since it is not supported anymore since torch v2.3.0 | ||||
| "optuna>=4.5.0", | ||||
| "torch-geometric>=2.6.1", | ||||
| "torch>=2.7.1,<2.8.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore. | ||||
| "typing-extensions>=4.1", # for `assert_never` | ||||
| ] | ||||
|
|
@@ -164,9 +166,15 @@ implicit_reexport = true | |||
| # recent versions of `gym` are typed, but stable-baselines3 pins a very old version of gym. | ||||
| # qiskit is not yet marked as typed, but is typed mostly. | ||||
| # the other libraries do not have type stubs. | ||||
| module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*", "networkx.*", "stable_baselines3.*"] | ||||
| module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*", "networkx.*", "stable_baselines3.*", "torch", "torch.*", "torch_geometric", "torch_geometric.*", "optuna.*"] | ||||
| ignore_missing_imports = true | ||||
|
|
||||
| [[tool.mypy.overrides]] | ||||
| module = ["mqt.predictor.ml.*"] | ||||
| disallow_subclassing_any = false | ||||
|
|
||||
|
|
||||
|
|
||||
|
||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file should not be tracked and can be removed. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # Copyright (c) 2023 - 2025 Chair for Design Automation, TUM | ||
| # Copyright (c) 2025 Munich Quantum Software Company GmbH | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
| # | ||
| # Licensed under the MIT License | ||
|
|
||
| # file generated by setuptools-scm | ||
| # don't change, don't track in version control | ||
| from __future__ import annotations | ||
|
|
||
# Public API of this generated version module.
__all__ = [
    "__commit_id__",
    "__version__",
    "__version_tuple__",
    "commit_id",
    "version",
    "version_tuple",
]

# Lightweight stand-in for `typing.TYPE_CHECKING` so this generated file
# avoids importing `typing` at runtime; static type checkers treat the
# first branch as live, while at runtime the aliases degrade to `object`.
TYPE_CHECKING = False
if TYPE_CHECKING:
    VERSION_TUPLE = tuple[int | str, ...]
    COMMIT_ID = str | None
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

# Forward declarations with types for the module-level attributes below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

# Concrete values written by setuptools-scm at build time.
__version__ = version = "2.3.1.dev6+g1d835bd4c"
__version_tuple__ = version_tuple = (2, 3, 1, "dev6", "g1d835bd4c")

# NOTE(review): commit id was not embedded for this build.
__commit_id__ = commit_id = None
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to change this file. GNN should never be accessed by a user directly. |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,174 @@ | ||||
| # Copyright (c) 2023 - 2025 Chair for Design Automation, TUM | ||||
| # Copyright (c) 2025 Munich Quantum Software Company GmbH | ||||
| # All rights reserved. | ||||
| # | ||||
| # SPDX-License-Identifier: MIT | ||||
| # | ||||
| # Licensed under the MIT License | ||||
|
|
||||
| """This module contains the GNN module for graph neural networks.""" | ||||
|
|
||||
| from __future__ import annotations | ||||
|
|
||||
| import warnings | ||||
| from typing import TYPE_CHECKING, Any | ||||
|
|
||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as functional | ||||
| from torch_geometric.nn import SAGEConv, global_mean_pool | ||||
|
|
||||
| if TYPE_CHECKING: | ||||
| from collections.abc import ( | ||||
| Callable, # on 3.10+ prefer collections.abc | ||||
| ) | ||||
|
|
||||
| from torch_geometric.data import Data | ||||
|
|
||||
|
|
||||
class GraphConvolutionSage(nn.Module):
    """Graph encoder stacking SAGEConv layers with residual connections.

    The first two convolutions are applied without residuals (the first maps
    the input feature size to the hidden size, the second keeps the hidden
    size); every subsequent convolution is applied in residual configuration
    (``x + conv(x)``). A global mean pool then aggregates the node embeddings
    of each graph into a single graph-level embedding.
    """

    def __init__(
        self,
        in_feats: int,
        hidden_dim: int,
        num_resnet_layers: int,
        *,
        conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        conv_act_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Build the SAGEConv stack.

        Args:
            in_feats: dimensionality of the input node features.
            hidden_dim: output size of every SAGEConv layer.
            num_resnet_layers: number of additional SAGEConv layers stacked
                in residual configuration after the two initial layers.
            conv_activation: activation function applied after each graph layer.
            conv_act_kwargs: extra keyword arguments passed to ``conv_activation``.
        """
        super().__init__()
        self.conv_activation = conv_activation
        self.conv_act_kwargs = conv_act_kwargs or {}

        # --- GRAPH ENCODER ---
        self.convs = nn.ModuleList()
        # 1) Two convolutions not in residual configuration:
        #    in_feats -> hidden_dim, then hidden_dim -> hidden_dim.
        self.convs.append(SAGEConv(in_feats, hidden_dim))
        self.convs.append(SAGEConv(hidden_dim, hidden_dim))

        # Remaining convolutions are applied with residual (skip) connections.
        for _ in range(num_resnet_layers):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))

    def forward(self, data: Data) -> torch.Tensor:
        """Encode a (batched) graph into per-graph embeddings.

        Args:
            data: a ``torch_geometric`` ``Data``/``Batch`` object providing node
                features ``x``, connectivity ``edge_index``, and the
                node-to-graph assignment vector ``batch``.

        Returns:
            A tensor with one pooled embedding per graph
            (presumably shape ``(num_graphs, hidden_dim)`` — follows from
            ``global_mean_pool`` over hidden-size node embeddings).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # 1) Graph stack with residuals. The number 2 is used because the
        #    first two convolutions are not in residual configuration,
        #    while all later ones add their input back as a skip connection.
        for i, conv in enumerate(self.convs):
            x_new = conv(x, edge_index)
            x_new = self.conv_activation(x_new, **self.conv_act_kwargs)
            x = x_new if i < 2 else x + x_new

        # 2) Global mean pooling over the nodes of each graph.
        return global_mean_pool(x, batch)
|
|
||||
| # 3) MLP head | ||||
|
||||
| # 3) MLP head |
Uh oh!
There was an error while loading. Please reload this page.