graphnetz.models
- class graphnetz.models.DGI(*args: Any, **kwargs: Any)
Bases: Module
Deep Graph Infomax for unsupervised node representation learning.
References
- forward(data: torch_geometric.data.Data) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- loss(pos_z: torch.Tensor, neg_z: torch.Tensor, summary: torch.Tensor) → torch.Tensor
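As a framework-free illustration of what loss(pos_z, neg_z, summary) optimises, the sketch below assumes the objective from the original Deep Graph Infomax paper: a bilinear discriminator scores each node embedding h against the graph summary s as sigmoid(hᵀ W s), and binary cross-entropy pushes real-node scores towards 1 and corrupted-node scores towards 0. The function names and the weight matrix W are hypothetical, plain lists stand in for tensors, and graphnetz's own implementation may differ in its discriminator or reduction.

```python
import math

def bilinear(h, W, s):
    # Computes h^T W s for vectors h, s and a square matrix W (lists of lists).
    return sum(h[i] * W[i][j] * s[j] for i in range(len(h)) for j in range(len(s)))

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dgi_loss(pos_z, neg_z, summary, W):
    """Hypothetical sketch of the DGI binary cross-entropy objective.

    pos_z: embeddings of the real graph's nodes (list of vectors)
    neg_z: embeddings of the corrupted graph's nodes
    summary: graph summary vector s
    W: bilinear discriminator weight (assumed; learnable in practice)
    """
    # Real nodes carry label 1: penalise -log sigmoid(score).
    pos_term = -sum(log_sigmoid(bilinear(h, W, summary)) for h in pos_z)
    # Corrupted nodes carry label 0: -log(1 - sigmoid(x)) == -log_sigmoid(-x).
    neg_term = -sum(log_sigmoid(-bilinear(h, W, summary)) for h in neg_z)
    # Average over all positive and negative samples.
    return (pos_term + neg_term) / (len(pos_z) + len(neg_z))
```

An uninformative discriminator (W = 0) yields exactly ln 2 ≈ 0.693 per sample; embeddings that the discriminator separates well drive the loss below that value.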
- class graphnetz.models.GAT(*args: Any, **kwargs: Any)
Bases: Module
Two-layer Graph Attention Network.
References
[Velickovic2018] Veličković, P. et al. (2018). “Graph Attention Networks.” ICLR.
- forward(data: torch_geometric.data.Data) → torch.Tensor
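The core of a GAT layer [Velickovic2018] is the per-edge attention weight α_ij = softmax_j(LeakyReLU(aᵀ[W h_i ‖ W h_j])), normalised over node i's neighbourhood plus a self-loop. The plain-Python sketch below computes these weights for a single node and a single head; the function names are hypothetical, and the head count, dropout and other settings of graphnetz.models.GAT are not documented on this page and are not reproduced.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def matvec(W, h):
    # W has shape [out, in]; returns W @ h.
    return [sum(w * v for w, v in zip(row, h)) for row in W]

def gat_attention(h, neighbors, i, W, a):
    """Single-head GAT attention weights of node i over its neighbourhood.

    h: list of node feature vectors; neighbors: dict node -> list of neighbours.
    W: shared linear projection; a: attention vector of length 2 * out_dim.
    """
    Wh = [matvec(W, x) for x in h]
    nbrs = [i] + [j for j in neighbors.get(i, []) if j != i]  # include a self-loop
    # e_ij = LeakyReLU(a^T [W h_i || W h_j]); list "+" concatenates the two vectors.
    scores = [leaky_relu(sum(ak * v for ak, v in zip(a, Wh[i] + Wh[j]))) for j in nbrs]
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return {j: e / total for j, e in zip(nbrs, exps)}
```

With an all-zero attention vector every neighbour (and the node itself) receives the same weight, which is a quick sanity check that the softmax is normalised over the neighbourhood.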
- class graphnetz.models.GCN(*args: Any, **kwargs: Any)
Bases: Module
Two-layer Graph Convolutional Network.
- Parameters:
References
[Kipf2017] Kipf, T. N., & Welling, M. (2017). “Semi-Supervised Classification with Graph Convolutional Networks.” ICLR. arXiv:1609.02907.
- forward(data: torch_geometric.data.Data) → torch.Tensor
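Each GCN layer [Kipf2017] applies the symmetrically normalised propagation rule H' = σ(D̂^(-1/2) (A + I) D̂^(-1/2) H W), where D̂ is the degree matrix of A + I. The sketch below implements only the normalised neighbourhood sum on plain lists, leaving out the learnable weight W and the nonlinearity; the function name is hypothetical and not part of graphnetz.

```python
import math

def gcn_propagate(x, edges):
    """One GCN propagation step, D^-1/2 (A + I) D^-1/2 X, without the weight matrix.

    x: list of node feature vectors; edges: undirected (u, v) pairs.
    """
    n = len(x)
    nbrs = [{i} for i in range(n)]  # A + I: every node is its own neighbour
    for u, v in edges:
        nbrs[u].add(v)
        nbrs[v].add(u)
    deg = [len(s) for s in nbrs]  # degrees in A + I
    out = []
    for i in range(n):
        row = [0.0] * len(x[0])
        for j in nbrs[i]:
            coef = 1.0 / math.sqrt(deg[i] * deg[j])  # symmetric normalisation
            row = [r + coef * f for r, f in zip(row, x[j])]
        out.append(row)
    return out
```

On a star with centre 0 and leaves 1 and 2, a signal placed on the centre reaches each leaf scaled by 1/√(3·2), so high-degree nodes contribute less per edge than low-degree ones.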
- class graphnetz.models.GIN(*args: Any, **kwargs: Any)
Bases: Module
Graph Isomorphism Network for graph-level prediction.
References
[Xu2019] Xu, K. et al. (2019). “How Powerful are Graph Neural Networks?” ICLR.
- forward(data: torch_geometric.data.Data) → torch.Tensor
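The GIN update from [Xu2019] is h_v' = MLP((1 + ε) · h_v + Σ_{u ∈ N(v)} h_u), followed for graph-level prediction by a readout over all nodes. The sketch below shows one layer and a sum readout on plain lists, with the MLP passed in as an ordinary function; the names are hypothetical, and the readout that graphnetz.models.GIN actually uses is not stated on this page.

```python
def gin_layer(x, edges, mlp, eps=0.0):
    """h_v' = mlp((1 + eps) * h_v + sum of neighbour features)."""
    n = len(x)
    agg = [[(1.0 + eps) * f for f in x[v]] for v in range(n)]
    for u, v in edges:  # undirected: each edge contributes in both directions
        agg[v] = [a + f for a, f in zip(agg[v], x[u])]
        agg[u] = [a + f for a, f in zip(agg[u], x[v])]
    return [mlp(h) for h in agg]

def sum_readout(x):
    # Graph-level representation: element-wise sum of node features.
    return [sum(col) for col in zip(*x)]

def relu(h):
    # Stand-in for the learnable MLP.
    return [max(0.0, v) for v in h]
```

Sum aggregation is the design choice that matters here: a node with two identical neighbours receives a different update from a node with only one such neighbour, which a mean or max aggregator would not distinguish.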
- class graphnetz.models.GraphSAGE(*args: Any, **kwargs: Any)
Bases: Module
Two-layer GraphSAGE for node-level prediction.
References
[Hamilton2017] Hamilton, W. L., Ying, R., & Leskovec, J. (2017). “Inductive Representation Learning on Large Graphs.” NeurIPS.
- forward(data: torch_geometric.data.Data) → torch.Tensor
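GraphSAGE with the mean aggregator [Hamilton2017] combines each node's own features with the mean of its neighbours' features, h_v' = σ(W · [h_v ‖ mean_{u ∈ N(v)} h_u]). Because the update depends only on features and local structure, it also applies to nodes unseen during training. The sketch below builds the concatenated input to W on plain lists; it omits neighbour sampling, the weight matrix and the nonlinearity, and its name is hypothetical.

```python
def sage_mean_concat(x, neighbors):
    """For each node, return [h_v || mean_{u in N(v)} h_u] (input to the weight matrix).

    x: list of node feature vectors; neighbors: dict node -> list of neighbours.
    """
    out = []
    for v, h in enumerate(x):
        nbrs = neighbors.get(v, [])
        if nbrs:
            mean = [sum(x[u][k] for u in nbrs) / len(nbrs) for k in range(len(h))]
        else:
            mean = [0.0] * len(h)  # isolated node: zero neighbourhood summary
        out.append(list(h) + mean)
    return out
```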
- class graphnetz.models.GraphTransformer(*args: Any, **kwargs: Any)
Bases: Module
Two-layer graph transformer based on TransformerConv.
References
[Shi2021] Shi, Y. et al. (2021). “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification.” IJCAI.
- forward(data: torch_geometric.data.Data) → torch.Tensor
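TransformerConv follows the attention form of [Shi2021]: node i attends to its neighbours with scaled dot-product weights α_ij = softmax_j(q_iᵀ k_j / √d) and adds the weighted sum of neighbour values to a skip (root) projection of itself. The sketch below performs that update for one node and one head, assuming the query, key, value and skip projections have already been computed; the function name is hypothetical and the multi-head and edge-feature options of the real layer are omitted.

```python
import math

def transformer_conv_node(q_i, keys, values, skip_i):
    """Single-head attention update for one node i.

    q_i: query of node i; keys/values: keys and values of i's neighbours;
    skip_i: root/skip projection of node i, added to the aggregated message.
    """
    d = len(q_i)
    scores = [sum(a * b for a, b in zip(q_i, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # stable softmax over the neighbourhood
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Attention-weighted sum of neighbour values, one output channel at a time.
    msg = [sum(a * v[c] for a, v in zip(alpha, values)) for c in range(len(values[0]))]
    return [s + mc for s, mc in zip(skip_i, msg)]
```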
Task-type adapters
Task-type adapters that turn any node-level encoder into a graph-level classifier/regressor, a link predictor (dot-product or relational DistMult), or a Deep Graph Infomax model.
This is the glue that lets GCN, GAT, GraphSAGE, and the Graph Transformer plug into every benchmark task in the library, not just node classification.
- class graphnetz.models._adapters.DGIWrapper(*args: Any, **kwargs: Any)
Bases: Module
Wrap any node-level encoder as a Deep Graph Infomax model.
Mirrors the graphnetz.models.DGI interface (forward(data) returning the (pos_z, neg_z, summary) triple, plus a loss(...) helper) so the benchmark trainer does not need to special-case it.
- forward(data: torch_geometric.data.Data) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- loss(pos_z: torch.Tensor, neg_z: torch.Tensor, summary: torch.Tensor) → torch.Tensor
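The docstring fixes only the interface. Under the standard DGI recipe, an arbitrary encoder yields the (pos_z, neg_z, summary) triple by encoding the real graph, encoding a corrupted copy whose node feature rows are shuffled while the edges stay fixed, and reading out a summary as the element-wise sigmoid of the mean embedding. The sketch below follows that recipe on plain lists; the corruption and readout are assumptions about what DGIWrapper does internally, and dgi_forward is a hypothetical name.

```python
import math
import random

def dgi_forward(encoder, x, edges, rng=random):
    """Hypothetical sketch of a DGI forward pass producing (pos_z, neg_z, summary).

    encoder: any callable (x, edges) -> list of node embeddings.
    """
    pos_z = encoder(x, edges)
    # Corruption: shuffle node feature rows while keeping the graph structure.
    perm = list(range(len(x)))
    rng.shuffle(perm)
    neg_z = encoder([x[p] for p in perm], edges)
    # Readout: element-wise sigmoid of the mean node embedding.
    n = len(pos_z)
    summary = [1.0 / (1.0 + math.exp(-sum(col) / n)) for col in zip(*pos_z)]
    return pos_z, neg_z, summary
```

Because the encoder is just a callable, the same function works for any node-level model, which is the property that lets a trainer treat the wrapper and a dedicated DGI model alike.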
- class graphnetz.models._adapters.GraphLevelWrapper(*args: Any, **kwargs: Any)
Bases: Module
Wrap a node-level encoder for graph-level prediction.
The encoder is expected to map a PyG Data batch to per-node features of shape [N, hidden_channels]. The wrapper adds a global mean pool over the batch index and a linear classification/regression head.
- forward(data: torch_geometric.data.Data) → torch.Tensor
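The pooling step averages node features per graph using the batch vector, where batch[i] is the index of the graph that node i belongs to; this is what torch_geometric.nn.global_mean_pool computes on tensors. The plain-list sketch below reproduces that step and the linear head; the function names are hypothetical, and the wrapper's actual head may differ in details.

```python
def global_mean_pool(x, batch):
    """Average node features per graph; batch[i] is the graph index of node i.

    Assumes every graph index in 0..max(batch) has at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for feat, g in zip(x, batch):
        sums[g] = [s + f for s, f in zip(sums[g], feat)]
        counts[g] += 1
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

def linear(h, weight, bias):
    # weight has shape [out, in], the same layout as torch.nn.Linear.
    return [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weight, bias)]
```

For a batch of two graphs, the first with nodes 0 and 1 and the second with node 2, the pool returns one row per graph, and the head maps each row to the task's output dimension.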
- class graphnetz.models._adapters.LinkPredWrapper(*args: Any, **kwargs: Any)
Bases: Module
Wrap any node-level encoder as a link predictor with a dot-product decoder.
The wrapper exposes encode(data), returning per-node embeddings of shape [N, hidden_channels], and decode(z, edge_label_index), returning a tensor of shape [E] holding edge logits.
- encode(data: torch_geometric.data.Data) → torch.Tensor
- static decode(z: torch.Tensor, edge_label_index: torch.Tensor) → torch.Tensor
- forward(data: torch_geometric.data.Data) → torch.Tensor
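A dot-product decoder scores each candidate edge (s, d) in edge_label_index, a [2, E] index whose first row holds sources and second row targets, as z[s] · z[d]. On tensors this is commonly written (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1). The sketch below does the same on plain lists and adds a numerically stable sigmoid to turn logits into probabilities; both function names are hypothetical. Since decode returns logits, a logits-based loss such as torch.nn.BCEWithLogitsLoss would be the natural pairing, though the trainer's actual choice is not documented here.

```python
import math

def dot_product_decode(z, edge_label_index):
    """Dot-product decoder: one logit per candidate edge.

    edge_label_index is [2, E]: row 0 holds source nodes, row 1 target nodes.
    """
    src, dst = edge_label_index
    return [sum(a * b for a, b in zip(z[s], z[d])) for s, d in zip(src, dst)]

def sigmoid(x):
    # Stable for large |x|: never exponentiates a large positive number.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def edge_probabilities(logits):
    return [sigmoid(l) for l in logits]
```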
- class graphnetz.models._adapters.RelationalLinkPredWrapper(*args: Any, **kwargs: Any)
Bases: Module
Wrap any node-level encoder as a relational link predictor (DistMult).
Learns a relation embedding matrix and scores triples via (z[h] * r * z[t]).sum() (element-wise product, DistMult).
- encode(data: torch_geometric.data.Data) → torch.Tensor
- decode(z: torch.Tensor, edge_label_index: torch.Tensor, edge_type: torch.Tensor) → torch.Tensor
- forward(data: torch_geometric.data.Data) → torch.Tensor
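The DistMult score described above can be checked by hand on plain lists: for each column (h, t) of edge_label_index, the relation embedding selected by edge_type is multiplied element-wise with both node embeddings and the result is summed. In the sketch below, rel_emb is a hypothetical stand-in for the wrapper's learned relation matrix and distmult_decode is a hypothetical name.

```python
def distmult_decode(z, rel_emb, edge_label_index, edge_type):
    """DistMult scores: sum_k z[h][k] * r[k] * z[t][k] for each triple.

    rel_emb: one embedding per relation type (hypothetical stand-in for the
    learned relation matrix); edge_type[e] selects the relation of triple e.
    """
    heads, tails = edge_label_index
    return [
        sum(a * b * c for a, b, c in zip(z[h], rel_emb[r], z[t]))
        for h, t, r in zip(heads, tails, edge_type)
    ]
```

Because the element-wise product is commutative, DistMult assigns (h, r, t) and (t, r, h) the same score, so it cannot distinguish the direction of asymmetric relations. This is a property of DistMult itself rather than of the wrapper.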