Source code for graphnetz.models._gat
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
[docs]
class GAT(torch.nn.Module):
"""Two-layer Graph Attention Network.
References
----------
.. [Velickovic2018] Veličković, P. et al. (2018). "Graph Attention Networks." ICLR.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
heads: int = 8,
dropout: float = 0.6,
) -> None:
super().__init__()
self.dropout = dropout
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout)
[docs]
def forward(self, data: Data) -> torch.Tensor:
x, edge_index = data.x, data.edge_index
x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
x = torch.nn.functional.elu(self.conv1(x, edge_index))
x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
return self.conv2(x, edge_index)