Source code for graphnetz.models._graph_transformer

from typing import Optional

import torch
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv


class GraphTransformer(torch.nn.Module):
    """Two-layer graph transformer built from ``TransformerConv`` layers [Shi2021]_.

    The first layer uses ``heads`` attention heads whose outputs are
    concatenated; the second layer uses a single head without concatenation
    so that it emits ``out_channels`` features per node.

    Parameters
    ----------
    in_channels : int
        Size of each input node feature vector.
    hidden_channels : int
        Per-head output size of the first layer. Because the first layer's
        heads are concatenated, the second layer receives
        ``hidden_channels * heads`` features per node.
    out_channels : int
        Size of each output node feature vector.
    heads : int, default 4
        Number of attention heads in the first layer.
    dropout : float, default 0.1
        Dropout probability. It is passed to both ``TransformerConv`` layers
        (attention-coefficient dropout) and is also applied to the hidden
        node features between the two layers while in training mode.
    edge_dim : int, optional
        Size of edge feature vectors. When given, both layers are built with
        this ``edge_dim`` and :meth:`forward` feeds ``data.edge_attr`` to
        them. When ``None`` (the default), edge features are not used.

    References
    ----------
    .. [Shi2021] Shi, Y. et al. (2021). "Masked Label Prediction: Unified
       Message Passing Model for Semi-Supervised Classification." IJCAI.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        heads: int = 4,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.edge_dim = edge_dim
        # TransformerConv concatenates heads by default (concat=True), so the
        # output width of conv1 is hidden_channels * heads.
        self.conv1 = TransformerConv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            edge_dim=edge_dim,
        )
        # A single, non-concatenated head yields exactly out_channels per node.
        self.conv2 = TransformerConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            concat=False,
            dropout=dropout,
            edge_dim=edge_dim,
        )

    def forward(self, data: Data) -> torch.Tensor:
        """Run both transformer layers over the graph stored in ``data``.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Graph whose ``x`` holds the node features (``in_channels``
            columns) and whose ``edge_index`` holds the connectivity.
            ``edge_attr`` is read only when the model was built with
            ``edge_dim``.

        Returns
        -------
        torch.Tensor
            Output of the second layer, with ``out_channels`` features per
            node.

        Raises
        ------
        ValueError
            If the model was built with ``edge_dim`` but ``data`` carries no
            ``edge_attr``.
        """
        x, edge_index = data.x, data.edge_index

        edge_attr = None
        if self.edge_dim is not None:
            edge_attr = getattr(data, "edge_attr", None)
            if edge_attr is None:
                # Fail with a clear message instead of an internal assertion
                # deep inside TransformerConv's message passing.
                raise ValueError(
                    "GraphTransformer was built with edge_dim, "
                    "but data has no edge_attr"
                )

        x = torch.relu(self.conv1(x, edge_index, edge_attr=edge_attr))
        # Feature dropout between the layers; a no-op in eval mode.
        x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index, edge_attr=edge_attr)