Source code for graphnetz.models._gin
import torch
from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
from torch_geometric.data import Data
from torch_geometric.nn import GINConv, global_add_pool
def _mlp(in_dim: int, out_dim: int) -> Sequential:
    """Build the two-layer perceptron used inside each GIN convolution.

    The returned network is ``Linear -> BatchNorm1d -> ReLU -> Linear -> ReLU``;
    both linear layers output ``out_dim`` features.

    Parameters
    ----------
    in_dim : int
        Number of input features.
    out_dim : int
        Width of the hidden layer and of the output.

    Returns
    -------
    Sequential
        The assembled feed-forward module.
    """
    # Stage 1 projects to out_dim and normalizes; stage 2 is a same-width
    # projection. Layers are created in the same order as a single flat list,
    # so parameter initialisation consumes the RNG identically.
    project = [Linear(in_dim, out_dim), BatchNorm1d(out_dim), ReLU()]
    refine = [Linear(out_dim, out_dim), ReLU()]
    return Sequential(*project, *refine)
[docs]
class GIN(torch.nn.Module):
    """Graph Isomorphism Network for graph-level prediction.

    Stacks ``num_layers`` :class:`GINConv` layers, each wrapping a two-layer
    MLP and using a trainable epsilon. The final node embeddings are
    sum-pooled per graph, and the pooled vector goes through a linear
    classifier.

    Parameters
    ----------
    in_channels : int
        Number of input node features.
    hidden_channels : int
        Output width of every GIN layer (and input width of the classifier).
    out_channels : int
        Number of outputs produced per graph.
    num_layers : int, optional
        Number of GIN convolutions; must be at least 1. Default: 3.

    Raises
    ------
    ValueError
        If ``num_layers`` is less than 1.

    References
    ----------
    .. [Xu2019] Xu, K. et al. (2019). "How Powerful are Graph Neural Networks?" ICLR.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 3,
    ) -> None:
        # At least one convolution is required; without this check a value
        # <= 0 would silently be treated as a single layer.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()
        # The first layer consumes the raw node features; every later layer
        # maps hidden_channels -> hidden_channels.
        layer_in_dims = [in_channels] + [hidden_channels] * (num_layers - 1)
        self.convs = torch.nn.ModuleList(
            GINConv(_mlp(dim, hidden_channels), train_eps=True) for dim in layer_in_dims
        )
        self.classifier = Linear(hidden_channels, out_channels)

    def forward(self, data: Data) -> torch.Tensor:
        """Compute one output vector per graph.

        Parameters
        ----------
        data : Data
            Graph or mini-batch of graphs; its ``x``, ``edge_index`` and
            ``batch`` attributes are read. ``x`` is assumed to be
            ``(num_nodes, in_channels)`` (TODO confirm). ``batch`` assigns
            nodes to graphs; a ``None`` batch (single graph) is presumably
            handled by ``global_add_pool`` — verify for the pinned PyG version.

        Returns
        -------
        torch.Tensor
            Classifier output for the sum-pooled node embeddings, with
            ``out_channels`` features per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index)
        # Sum readout, as in the GIN paper (sum is injective on multisets).
        pooled = global_add_pool(x, batch)
        return self.classifier(pooled)