Source code for graphnetz.models._graphsage

import torch
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE for node-level prediction.

    The network stacks two ``SAGEConv`` layers, applying a ReLU and then
    dropout to the output of the first layer before it reaches the second.

    Parameters
    ----------
    in_channels
        Input size passed to the first ``SAGEConv`` layer.
    hidden_channels
        Output size of the first layer and input size of the second.
    out_channels
        Output size of the second ``SAGEConv`` layer.
    aggr
        Aggregation scheme forwarded unchanged to both ``SAGEConv`` layers.
    dropout
        Dropout probability used between the two layers.

    References
    ----------
    .. [Hamilton2017] Hamilton, W. L., Ying, R., & Leskovec, J. (2017).
       "Inductive Representation Learning on Large Graphs." NeurIPS.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        aggr: str = "mean",
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # Layers are built first-to-last so that submodule registration
        # (and therefore parameter ordering) follows the data flow.
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr=aggr)
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr=aggr)
        # Stored as a plain float; it is consumed functionally in forward().
        self.dropout = dropout

    def forward(self, data: Data) -> torch.Tensor:
        """Run both convolutions over the graph held in ``data``.

        Parameters
        ----------
        data
            Graph object; its ``x`` and ``edge_index`` attributes are read.

        Returns
        -------
        torch.Tensor
            The output of the second convolution, returned as-is (no
            activation is applied after it).
        """
        node_features = data.x
        edges = data.edge_index

        hidden = self.conv1(node_features, edges)
        hidden = torch.relu(hidden)
        # `training=self.training` ties dropout to the module's train/eval mode.
        hidden = torch.nn.functional.dropout(
            hidden,
            p=self.dropout,
            training=self.training,
        )
        return self.conv2(hidden, edges)