# Source code for graphnetz.models._gcn
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# [docs]
class GCN(torch.nn.Module):
    """Two-layer Graph Convolutional Network.

    The model stacks two ``GCNConv`` layers and applies a ReLU
    non-linearity between them. No activation is applied to the output
    of the second layer.

    Parameters
    ----------
    in_channels : int
        The number of input features.
    hidden_channels : int
        The number of hidden features.
    out_channels : int
        The number of output features.

    References
    ----------
    .. [Kipf2017] Kipf, T. N., & Welling, M. (2017).
       "Semi-Supervised Classification with Graph Convolutional Networks."
       arXiv:1609.02907.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int) -> None:
        super().__init__()
        # First convolution: input features -> hidden features.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Second convolution: hidden features -> output features.
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data: Data) -> torch.Tensor:
        """Run both convolutions over the graph stored in ``data``.

        Parameters
        ----------
        data : Data
            Graph object; its ``x`` and ``edge_index`` attributes are read.

        Returns
        -------
        torch.Tensor
            The output of the second convolution, returned without any
            further activation.
        """
        features = data.x
        connectivity = data.edge_index
        # The same connectivity is passed to both layers.
        hidden = torch.relu(self.conv1(features, connectivity))
        return self.conv2(hidden, connectivity)