graphnetz.training
Lightweight training loops shared by the example notebooks.
Each function returns a plain dict of per-epoch metrics, ready to feed into
graphnetz.plotting.plot_history().
All trainers accept device='auto' (the default), which dispatches to
CUDA when available, then Apple-silicon MPS, then CPU. Pass an explicit
torch.device or string to pin placement.
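The preference order described above can be sketched without importing PyTorch. The helper name `resolve_device_name` and its boolean flags are hypothetical and are not part of graphnetz; in real code the flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. How the library treats `device=None` is not stated here, so this sketch leaves that case out.

```python
def resolve_device_name(
    requested: str = "auto",
    cuda_available: bool = False,
    mps_available: bool = False,
) -> str:
    """Return a device string using the CUDA -> MPS -> CPU preference.

    Hypothetical helper for illustration only; the availability flags
    stand in for the corresponding PyTorch runtime checks.
    """
    # An explicit request pins placement and bypasses auto-detection.
    if requested != "auto":
        return requested
    # Auto mode: prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

For example, `resolve_device_name("auto", cuda_available=False, mps_available=True)` selects `"mps"`, while `resolve_device_name("cpu", cuda_available=True)` stays on `"cpu"`.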
- graphnetz.training.train_dgi(model: _DGILike, data: Data, epochs: int = 100, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a Deep Graph Infomax model (unsupervised).
- graphnetz.training.train_graph_classification(model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 30, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a graph-level classifier.
Handles single-label and multi-label datasets transparently: when
batch.y is shaped [B, C] with float dtype (e.g. LRGB Peptides-func, OGB molhiv variants), the loss switches to binary cross-entropy with logits, and the reported metric is the fraction of individual labels classified correctly, averaged over all labels.
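The multi-label path described above can be illustrated with a dependency-free sketch. The function names are hypothetical, and two details are assumptions rather than documented behaviour: predictions are thresholded at a logit of 0 (a sigmoid probability of 0.5), and the metric is averaged over every label of every graph. The loss uses the standard numerically stable form of binary cross-entropy with logits.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_accuracy(logits: list[list[float]], targets: list[list[float]]) -> float:
    """Fraction of individual labels predicted correctly.

    Assumes a threshold at logit 0, i.e. sigmoid(logit) > 0.5 means "positive".
    """
    correct = 0
    total = 0
    for row_logits, row_targets in zip(logits, targets):
        for logit, target in zip(row_logits, row_targets):
            predicted = 1.0 if logit > 0.0 else 0.0
            correct += int(predicted == target)
            total += 1
    return correct / total


# Two graphs (B = 2) with three labels each (C = 3).
example_logits = [[2.0, -1.0, 0.5], [-3.0, 0.2, -0.7]]
example_targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
example_accuracy = multilabel_accuracy(example_logits, example_targets)
```

In this example four of the six labels are predicted correctly, so `example_accuracy` is 4/6.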
- graphnetz.training.train_graph_regression(model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 30, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a graph-level regressor (MSE loss, MAE on val).
- graphnetz.training.train_link_prediction(model: _LinkPredLike, train_data: Data, val_data: Data, test_data: Data, epochs: int = 100, lr: float = 0.01, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a link predictor with binary cross-entropy on splits produced by RandomLinkSplit.
The model is expected to expose
encode(data), returning per-node embeddings, and decode(z, edge_label_index), returning per-edge scores (see graphnetz.models._adapters.LinkPredWrapper).
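As a rough illustration of the encode/decode contract above, the sketch below uses plain Python lists in place of tensors. `LinkPredLike` is a hypothetical stand-in for the private `_LinkPredLike` type, and `DotProductLinkPredictor` with its dot-product decoder is an assumption made for the example, not the behaviour of `LinkPredWrapper`. `edge_label_index` is taken to hold a row of source nodes and a row of target nodes.

```python
from typing import Any, Protocol


class LinkPredLike(Protocol):
    """Hypothetical structural type mirroring the expected interface."""

    def encode(self, data: Any) -> list[list[float]]: ...

    def decode(
        self, z: list[list[float]], edge_label_index: tuple[list[int], list[int]]
    ) -> list[float]: ...


class DotProductLinkPredictor:
    """Toy model: fixed embeddings, dot-product edge scores."""

    def __init__(self, embeddings: list[list[float]]) -> None:
        self._embeddings = embeddings

    def encode(self, data: Any) -> list[list[float]]:
        # A real encoder would run a GNN over `data`; this toy ignores it.
        return self._embeddings

    def decode(
        self, z: list[list[float]], edge_label_index: tuple[list[int], list[int]]
    ) -> list[float]:
        sources, targets = edge_label_index
        # One raw score (logit) per candidate edge.
        return [
            sum(a * b for a, b in zip(z[src], z[dst]))
            for src, dst in zip(sources, targets)
        ]


toy_model: LinkPredLike = DotProductLinkPredictor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_z = toy_model.encode(data=None)
toy_scores = toy_model.decode(toy_z, ([0, 0, 1], [1, 2, 2]))
```

Here the candidate edges (0, 1), (0, 2) and (1, 2) receive the scores 0.0, 1.0 and 1.0; a trainer would pass such scores to a binary cross-entropy loss against the edge labels.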
- graphnetz.training.train_node_classification(model: torch.nn.Module, data: Data, epochs: int = 100, lr: float = 0.01, weight_decay: float = 0.0005, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a node classifier with Planetoid-style train/val/test masks.
- graphnetz.training.train_node_degree_regression(model: torch.nn.Module, data: Data, epochs: int = 100, lr: float = 0.01, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Self-supervised node-level regression: predict log node degree.
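The regression targets for this task can be derived from the graph structure alone, which is what makes it self-supervised. The sketch below is one plausible construction, not necessarily the one the trainer uses: the helper name `log_degree_targets` is hypothetical, degree is counted from the target row of `edge_index`, and `log(1 + degree)` is assumed so that isolated nodes do not produce `log(0)`.

```python
import math


def log_degree_targets(
    edge_index: tuple[list[int], list[int]], num_nodes: int
) -> list[float]:
    """Per-node targets log(1 + degree), counting incoming edges.

    For an undirected graph stored with both edge directions, the
    in-degree counted here equals the ordinary node degree.
    """
    _sources, targets = edge_index
    degree = [0] * num_nodes
    for dst in targets:
        degree[dst] += 1
    # log1p is an assumption of this sketch; it keeps isolated nodes finite.
    return [math.log1p(d) for d in degree]


# Path graph 0 - 1 - 2 stored in both directions, plus isolated node 3.
example_targets = log_degree_targets(([0, 1, 1, 2], [1, 0, 2, 1]), num_nodes=4)
```

In the example, node 1 has degree 2 and node 3 has degree 0, so their targets are `log(3)` and `0.0` respectively.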
- graphnetz.training.train_relational_link_prediction(model: _RelationalLinkPredLike, train_data: Data, val_data: Data, test_data: Data, epochs: int = 100, lr: float = 0.01, verbose: bool = False, device: torch.device | str | None = 'auto') -> dict[str, list[float]]
Train a relational link predictor (DistMult) on knowledge graph triples.
The model is expected to expose
encode(data), returning per-node embeddings, and decode(z, edge_index, edge_type), returning per-edge scores (see graphnetz.models._adapters.RelationalLinkPredWrapper).
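DistMult scores a triple (head, relation, tail) as the sum over dimensions of the element-wise product of the head embedding, a per-relation vector, and the tail embedding. The sketch below shows that scoring rule with plain lists. The function name `distmult_decode` is hypothetical, and the relation vectors are passed explicitly only for illustration; in a model matching the decode(z, edge_index, edge_type) interface they would presumably be learned parameters held by the model.

```python
def distmult_decode(
    z: list[list[float]],
    relation_embeddings: list[list[float]],
    edge_index: tuple[list[int], list[int]],
    edge_type: list[int],
) -> list[float]:
    """DistMult score per triple: sum_k z[head][k] * rel[type][k] * z[tail][k]."""
    heads, tails = edge_index
    scores = []
    for head, tail, rel in zip(heads, tails, edge_type):
        score = sum(
            zh * wr * zt
            for zh, wr, zt in zip(z[head], relation_embeddings[rel], z[tail])
        )
        scores.append(score)
    return scores


# Two entities, two relation types, two triples: (0, r0, 1) and (1, r1, 0).
node_embeddings = [[1.0, 2.0], [3.0, 4.0]]
relation_vectors = [[1.0, 1.0], [0.5, -1.0]]
triple_scores = distmult_decode(
    node_embeddings, relation_vectors, ([0, 1], [1, 0]), [0, 1]
)
```

Because the product is symmetric in head and tail, DistMult assigns the same score to (h, r, t) and (t, r, h), which is worth keeping in mind for datasets with strongly directional relations.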