# Source code for graphnetz.models._dgi

import torch
from torch_geometric.data import Data
from torch_geometric.nn import DeepGraphInfomax, GCNConv


def _corruption(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Corrupt a graph by randomly shuffling the rows of ``x``.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose rows (dimension 0) are permuted.
    edge_index : torch.Tensor
        Passed through unchanged.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The row-permuted copy of ``x`` and the very same ``edge_index`` object.
    """
    # Draw the permutation on the device that already holds ``x`` so that
    # indexing never needs a host-to-device copy of the index tensor.
    perm = torch.randperm(x.size(0), device=x.device)
    return x[perm], edge_index


class _Encoder(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int) -> None:
        super().__init__()
        self.conv = GCNConv(in_channels, hidden_channels)
        self.prelu = torch.nn.PReLU(hidden_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        return self.prelu(self.conv(x, edge_index))


def _summary(z: torch.Tensor, *_args: object, **_kwargs: object) -> torch.Tensor:
    """Collapse ``z`` into one vector: the sigmoid of its mean over dimension 0.

    Any further positional or keyword arguments are accepted and ignored.
    This is a named module-level function instead of a lambda because a
    lambda stored on the module cannot be pickled, which would break
    ``torch.save(model)`` and sending the model to worker processes.
    """
    return torch.sigmoid(z.mean(dim=0))


class DGI(torch.nn.Module):
    """Deep Graph Infomax for unsupervised node representation learning.

    Thin wrapper around ``torch_geometric.nn.DeepGraphInfomax`` configured
    with a ``GCNConv`` + ``PReLU`` encoder, a sigmoid-of-mean summary and a
    row-shuffling corruption function.

    Parameters
    ----------
    in_channels : int
        Input size of the encoder.
    hidden_channels : int, default 512
        Output size of the encoder; also passed to ``DeepGraphInfomax`` as
        ``hidden_channels``.

    References
    ----------
    .. [1] Veličković, P. et al. (2019). "Deep Graph Infomax." ICLR.
    """

    def __init__(self, in_channels: int, hidden_channels: int = 512) -> None:
        super().__init__()
        self.model = DeepGraphInfomax(
            hidden_channels=hidden_channels,
            encoder=_Encoder(in_channels, hidden_channels),
            summary=_summary,
            corruption=_corruption,
        )

    def forward(self, data: Data) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the wrapped model on ``data.x`` and ``data.edge_index``.

        Returns whatever ``DeepGraphInfomax`` returns for those inputs; the
        result can be unpacked straight into :meth:`loss`.
        """
        return self.model(data.x, data.edge_index)

    def loss(self, pos_z: torch.Tensor, neg_z: torch.Tensor, summary: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped model's ``loss`` with the same three tensors."""
        return self.model.loss(pos_z, neg_z, summary)