graphnetz.models

class graphnetz.models.DGI(*args: Any, **kwargs: Any)

Bases: Module

Deep Graph Infomax for unsupervised node representation learning.

forward(data: torch_geometric.data.Data) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

loss(pos_z: torch.Tensor, neg_z: torch.Tensor, summary: torch.Tensor) → torch.Tensor

class graphnetz.models.GAT(*args: Any, **kwargs: Any)

Bases: Module

Two-layer Graph Attention Network.

References

[Velickovic2018]

Veličković, P. et al. (2018). “Graph Attention Networks.” ICLR.

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models.GCN(*args: Any, **kwargs: Any)

Bases: Module

Two-layer Graph Convolutional Network.

Parameters:
  • in_channels (int) – The number of input features.

  • hidden_channels (int) – The number of hidden features.

  • out_channels (int) – The number of output features.

References

[Kipf2017]

Kipf, T. N., & Welling, M. (2017). “Semi-Supervised Classification with Graph Convolutional Networks.” ICLR. arXiv:1609.02907.

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models.GIN(*args: Any, **kwargs: Any)

Bases: Module

Graph Isomorphism Network for graph-level prediction.

References

[Xu2019]

Xu, K. et al. (2019). “How Powerful are Graph Neural Networks?” ICLR.

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models.GraphSAGE(*args: Any, **kwargs: Any)

Bases: Module

Two-layer GraphSAGE for node-level prediction.

References

[Hamilton2017]

Hamilton, W. L., Ying, R., & Leskovec, J. (2017). “Inductive Representation Learning on Large Graphs.” NeurIPS.

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models.GraphTransformer(*args: Any, **kwargs: Any)

Bases: Module

Two-layer graph transformer based on TransformerConv.

References

[Shi2021]

Shi, Y. et al. (2021). “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification.” IJCAI.

forward(data: torch_geometric.data.Data) → torch.Tensor

Task-type adapters

Adapters that turn any node-level encoder into a graph-level classifier/regressor, a link predictor, or a Deep Graph Infomax model.

This is the glue that lets GCN, GAT, GraphSAGE, and the Graph Transformer plug into every benchmark task in the library, not just node classification.

class graphnetz.models._adapters.DGIWrapper(*args: Any, **kwargs: Any)

Bases: Module

Wrap any node-level encoder as a Deep Graph Infomax model.

Mirrors the graphnetz.models.DGI interface (forward(data) returning the (pos_z, neg_z, summary) triple, plus a loss(...) helper) so the benchmark trainer does not need to special-case it.
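The shape of that interface can be sketched without PyTorch. The block below is an illustration only, not the library's implementation: it assumes the objective from the Deep Graph Infomax paper (row-shuffled features as the corruption, a sigmoid-of-mean readout, and a bilinear discriminator trained with binary cross-entropy). The free functions forward and dgi_loss, the weight argument, and the identity encoder are hypothetical stand-ins, and plain lists replace tensors.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def forward(x, encoder, rng):
    """Return the (pos_z, neg_z, summary) triple for node features x."""
    pos_z = encoder(x)
    # Corruption (assumed): permute which node owns which feature row.
    shuffled = list(x)
    rng.shuffle(shuffled)
    neg_z = encoder(shuffled)
    # Readout (assumed): sigmoid of the mean positive embedding.
    summary = [sigmoid(sum(col) / len(pos_z)) for col in zip(*pos_z)]
    return pos_z, neg_z, summary


def dgi_loss(pos_z, neg_z, summary, weight, eps=1e-15):
    """Binary cross-entropy over bilinear scores h^T W s."""
    ws = [sum(w * s for w, s in zip(row, summary)) for row in weight]

    def prob(h):
        return sigmoid(sum(a * b for a, b in zip(h, ws)))

    # Real nodes should score high against the summary, corrupted ones low.
    pos_term = -sum(math.log(prob(h) + eps) for h in pos_z) / len(pos_z)
    neg_term = -sum(math.log(1.0 - prob(h) + eps) for h in neg_z) / len(neg_z)
    return pos_term + neg_term


# Hypothetical stand-in encoder: copies the features unchanged.
identity_encoder = lambda rows: [list(r) for r in rows]
identity_weight = [[1.0, 0.0], [0.0, 1.0]]

x = [[2.0, 2.0], [1.0, 3.0], [-1.0, 0.5]]
pos_z, neg_z, summary = forward(x, identity_encoder, random.Random(0))
value = dgi_loss(pos_z, neg_z, summary, identity_weight)
```

Because both classes return the same triple and accept it back in loss, a training loop can call model.loss(*model(data)) regardless of which one it was handed.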

forward(data: torch_geometric.data.Data) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

loss(pos_z: torch.Tensor, neg_z: torch.Tensor, summary: torch.Tensor) → torch.Tensor

class graphnetz.models._adapters.GraphLevelWrapper(*args: Any, **kwargs: Any)

Bases: Module

Wrap a node-level encoder for graph-level prediction.

The encoder is expected to map a PyG Data batch to per-node features of shape [N, hidden_channels]. The wrapper adds a global mean pool over the batch index and a linear classification/regression head.
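As a rough illustration of the pooling step, the sketch below averages node features per graph using the batch index and then applies a linear head. It uses plain Python lists instead of tensors so it runs without PyTorch; global_mean_pool here is a hand-written stand-in, and linear_head, weight, and bias are hypothetical names rather than attributes of the wrapper.

```python
def global_mean_pool(x, batch):
    """Average node feature rows per graph; batch[i] is the graph id of node i.

    Assumes graph ids are contiguous and every graph has at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, g in zip(x, batch):
        counts[g] += 1
        for j, v in enumerate(row):
            sums[g][j] += v
    return [[s / c for s in row] for row, c in zip(sums, counts)]


def linear_head(pooled, weight, bias):
    """Apply y = W h + b to each pooled graph vector; weight is [out][hidden]."""
    return [
        [sum(w * h for w, h in zip(w_row, g)) + b for w_row, b in zip(weight, bias)]
        for g in pooled
    ]


# Three nodes with hidden_channels = 2; nodes 0 and 1 form graph 0, node 2 forms graph 1.
node_features = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]

pooled = global_mean_pool(node_features, batch)                 # one row per graph
logits = linear_head(pooled, weight=[[1.0, -1.0]], bias=[0.5])  # one output per graph
```

The output has one row per graph in the batch rather than one per node, which is what lets a node-level encoder serve graph classification and regression tasks.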

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models._adapters.LinkPredWrapper(*args: Any, **kwargs: Any)

Bases: Module

Wrap any node-level encoder as a link predictor with a dot-product decoder.

The wrapper exposes encode(data), which returns per-node embeddings of shape [N, hidden_channels], and decode(z, edge_label_index), which returns a tensor of shape [E] holding one logit per candidate edge.
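A dot-product decoder can be sketched in a few lines. The version below is a dependency-free illustration with plain lists, assuming edge_label_index follows the PyG convention of shape [2, E] (a row of source nodes and a row of target nodes); it is not the wrapper's actual code.

```python
def decode(z, edge_label_index):
    """Score each candidate edge (s, d) as the dot product of z[s] and z[d]."""
    src, dst = edge_label_index
    return [sum(a * b for a, b in zip(z[s], z[d])) for s, d in zip(src, dst)]


# Three node embeddings with hidden_channels = 2.
z = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
# Two candidate edges: (0, 1) and (1, 2).
edge_label_index = [[0, 1], [1, 2]]

logits = decode(z, edge_label_index)  # one logit per candidate edge
```

The decoder has no parameters of its own, which is consistent with decode being a static method: all learning happens in the wrapped encoder.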

encode(data: torch_geometric.data.Data) → torch.Tensor

static decode(z: torch.Tensor, edge_label_index: torch.Tensor) → torch.Tensor

forward(data: torch_geometric.data.Data) → torch.Tensor

class graphnetz.models._adapters.RelationalLinkPredWrapper(*args: Any, **kwargs: Any)

Bases: Module

Wrap any node-level encoder as a relational link predictor (DistMult).

Learns a relation embedding matrix and scores each (head, relation, tail) triple as (z[h] * r * z[t]).sum(), where r is the embedding of the triple's relation and * is the element-wise product.
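The scoring rule can be illustrated with plain Python lists. In this sketch rel_emb stands in for the learned relation embedding matrix and distmult_decode is a hypothetical free function; the real decode is a method that presumably reads the embeddings from the module.

```python
def distmult_decode(z, rel_emb, edge_label_index, edge_type):
    """Score each triple (h, r, t) as sum_k z[h][k] * rel_emb[r][k] * z[t][k]."""
    heads, tails = edge_label_index
    return [
        sum(a * b * c for a, b, c in zip(z[h], rel_emb[r], z[t]))
        for h, t, r in zip(heads, tails, edge_type)
    ]


z = [[1.0, 2.0], [3.0, 1.0], [1.0, -1.0]]  # node embeddings, shape [N, hidden]
rel_emb = [[1.0, 1.0], [2.0, 0.0]]         # one row per relation type
edge_label_index = [[0, 1], [1, 2]]        # triples (0, r0, 1) and (1, r1, 2)
edge_type = [0, 1]

scores = distmult_decode(z, rel_emb, edge_label_index, edge_type)
```

Since the product is element-wise, swapping head and tail gives the same score, so DistMult cannot tell a relation from its inverse.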

encode(data: torch_geometric.data.Data) → torch.Tensor

decode(z: torch.Tensor, edge_label_index: torch.Tensor, edge_type: torch.Tensor) → torch.Tensor

forward(data: torch_geometric.data.Data) → torch.Tensor