graphnetz.training

Lightweight training loops shared by the example notebooks.

Each function returns a plain dict of per-epoch metrics, ready to feed into graphnetz.plotting.plot_history().

All trainers accept device='auto' (the default), which selects CUDA when available, falls back to Apple-silicon MPS, and finally to CPU. Pass an explicit torch.device or device string (e.g. 'cuda:1' or 'cpu') to pin placement.
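The fallback order can be sketched as a small pure function. This is a minimal sketch, not the library's resolver: `resolve_device` is a hypothetical name, availability is passed in as booleans so the logic runs without torch, and treating None like 'auto' is an assumption the documentation does not state.

```python
from typing import Optional


def resolve_device(
    requested: Optional[str] = "auto",
    cuda_available: bool = False,
    mps_available: bool = False,
) -> str:
    """Pick a device string in the documented CUDA -> MPS -> CPU order.

    Hypothetical helper. Treating None like "auto" is an assumption,
    not documented graphnetz behavior.
    """
    if requested is None or requested == "auto":
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    # An explicit device pins placement; no availability probing is done.
    return str(requested)
```

With torch installed, the flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the result can be wrapped in `torch.device(...)`.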

graphnetz.training.train_dgi(model: _DGILike, data: Data, epochs: int = 100, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') → dict[str, list[float]]

Train a Deep Graph Infomax model (unsupervised).
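Deep Graph Infomax trains a discriminator to score real node embeddings (paired with a graph summary) as 1 and embeddings from a corrupted graph as 0. The sketch below shows that binary cross-entropy objective on precomputed discriminator logits, pooled over positives and negatives as in the original DGI formulation. It is an illustration only: the wrapped model's own loss (whatever `_DGILike` exposes) may weight the two terms differently, and `dgi_loss` is a hypothetical name.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable form of -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def dgi_loss(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    """Mean BCE over real (label 1) and corrupted (label 0) discriminator logits."""
    losses = [bce_with_logits(s, 1.0) for s in pos_scores]
    losses += [bce_with_logits(s, 0.0) for s in neg_scores]
    return sum(losses) / len(losses)
```

An uninformative discriminator (all logits 0) gives a loss of log 2; confident, correct logits drive it toward 0.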

graphnetz.training.train_graph_classification(model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 30, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') → dict[str, list[float]]

Train a graph-level classifier.

Handles single-label and multi-label datasets transparently: when batch.y has shape [B, C] and a float dtype (e.g. LRGB Peptides-func, OGB molhiv variants), the loss switches to binary cross-entropy with logits, and the reported metric is the mean fraction of correctly classified labels.
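The multi-label branch can be sketched in plain Python: detect the [B, C] float case, apply BCE with logits to every (graph, label) pair, and count a label as correct when the thresholded prediction matches the target. The 0.5 probability threshold (logit > 0) and all function names are assumptions for illustration; labels marked missing (NaN) would need masking, which this sketch does not do.

```python
import math
from typing import List, Sequence, Tuple


def is_multilabel(y_shape: Sequence[int], y_is_float: bool) -> bool:
    """Documented switch: y shaped [B, C] with a floating-point dtype."""
    return len(y_shape) == 2 and y_is_float


def bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable form of -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_loss_and_accuracy(
    logits: List[List[float]], targets: List[List[float]]
) -> Tuple[float, float]:
    """Mean BCE over all (graph, label) pairs and the fraction of labels predicted correctly."""
    total_loss, correct, count = 0.0, 0, 0
    for row_logits, row_targets in zip(logits, targets):
        for x, t in zip(row_logits, row_targets):
            total_loss += bce_with_logits(x, t)
            # logit > 0 is equivalent to sigmoid(logit) > 0.5 (assumed threshold).
            pred = 1.0 if x > 0.0 else 0.0
            correct += int(pred == t)
            count += 1
    return total_loss / count, correct / count
```

For logits [[2.0, -1.0], [0.5, 3.0]] against targets [[1, 0], [0, 1]], three of the four labels are predicted correctly, so the metric is 0.75.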

graphnetz.training.train_graph_regression(model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 30, lr: float = 0.001, verbose: bool = False, device: torch.device | str | None = 'auto') → dict[str, list[float]]

Train a graph-level regressor, minimizing mean squared error (MSE) and reporting mean absolute error (MAE) on the validation set.
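The two quantities differ only in how each residual is penalized: MSE squares it (smooth, convenient for gradients), while MAE takes its absolute value (same units as the target, less dominated by a few large errors). A minimal sketch over flat lists of per-graph predictions; `mse_and_mae` is a hypothetical helper, and the trainer may average per batch rather than over the whole split.

```python
from typing import Sequence, Tuple


def mse_and_mae(preds: Sequence[float], targets: Sequence[float]) -> Tuple[float, float]:
    """Return (mean squared error, mean absolute error) over paired values."""
    if len(preds) != len(targets) or not preds:
        raise ValueError("preds and targets must be non-empty and of equal length")
    errors = [p - t for p, t in zip(preds, targets)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    return mse, mae
```

For predictions [1, 2, 3] and targets [1, 1, 5], the residuals are [0, 1, -2], giving MSE 5/3 and MAE 1.0.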

Train a link predictor with binary cross-entropy on the labeled candidate edges produced by RandomLinkSplit.

The model is expected to expose encode(data) returning per-node embeddings and decode(z, edge_label_index) returning per-edge scores (see graphnetz.models._adapters.LinkPredWrapper).
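RandomLinkSplit stores candidate edges in edge_label_index (a 2 × E pair of source and destination rows) and their 0/1 targets in edge_label. The sketch below scores each candidate with a dot product of its endpoint embeddings and averages BCE with logits. The dot-product decoder is an assumption for illustration (LinkPredWrapper's actual decode may differ), and `decode_dot` / `link_bce` are hypothetical names.

```python
import math
from typing import List, Sequence


def decode_dot(z: List[List[float]], edge_label_index: Sequence[Sequence[int]]) -> List[float]:
    """Score candidate edge k as <z[src[k]], z[dst[k]]> (assumed dot-product decoder)."""
    src, dst = edge_label_index
    return [sum(a * b for a, b in zip(z[s], z[d])) for s, d in zip(src, dst)]


def bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable form of -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def link_bce(scores: Sequence[float], edge_label: Sequence[float]) -> float:
    """Mean BCE, with edge_label = 1 for true edges and 0 for sampled negatives."""
    return sum(bce_with_logits(s, y) for s, y in zip(scores, edge_label)) / len(scores)
```

With embeddings z = [[1, 0], [1, 0], [0, 1]], the candidates (0, 1) and (0, 2) score 1.0 and 0.0, so a positive label on the first and a negative on the second yields a loss of about 0.503.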

graphnetz.training.train_node_classification(model: torch.nn.Module, data: Data, epochs: int = 100, lr: float = 0.01, weight_decay: float = 0.0005, verbose: bool = False, device: torch.device | str | None = 'auto') → dict[str, list[float]]

Train a node classifier with Planetoid-style train/val/test masks.
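In the Planetoid setup, the full graph is processed in one forward pass and boolean masks (train_mask, val_mask, test_mask) select which nodes each loss or metric is computed on. A minimal sketch of the per-split accuracy, assuming predictions are already argmax class indices; `masked_accuracy` is a hypothetical helper.

```python
from typing import Sequence


def masked_accuracy(pred: Sequence[int], y: Sequence[int], mask: Sequence[bool]) -> float:
    """Accuracy restricted to the nodes selected by a boolean mask."""
    idx = [i for i, m in enumerate(mask) if m]
    if not idx:
        raise ValueError("mask selects no nodes")
    return sum(pred[i] == y[i] for i in idx) / len(idx)
```

Because every split shares one forward pass, only the mask changes between computing train, validation, and test accuracy.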

graphnetz.training.train_node_degree_regression(model: torch.nn.Module, data: Data, epochs: int = 100, lr: float = 0.01, verbose: bool = False, device: torch.device | str | None = 'auto') → dict[str, list[float]]

Self-supervised node-level regression: predict log node degree.
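The regression targets come from the graph itself, so no labels are needed. The sketch below counts each node's occurrences in the source row of edge_index (which equals the degree for undirected graphs stored with both edge directions) and applies log(1 + d). The log1p transform, chosen so isolated nodes map to 0 rather than -inf, and the choice of the source row are assumptions, not confirmed details of this trainer.

```python
import math
from typing import List, Sequence


def log_degree_targets(edge_index: Sequence[Sequence[int]], num_nodes: int) -> List[float]:
    """Hypothetical target builder: log(1 + degree) per node from a 2 x E edge list."""
    src, _dst = edge_index
    deg = [0] * num_nodes
    for s in src:
        deg[s] += 1
    # log1p keeps isolated nodes (degree 0) finite.
    return [math.log1p(d) for d in deg]
```

For the undirected path 0-1-2 stored as [[0, 1, 1, 2], [1, 0, 2, 1]], the degrees are [1, 2, 1], so the targets are [log 2, log 3, log 2].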

Train a relational link predictor (DistMult) on knowledge graph triples.

The model is expected to expose encode(data) returning per-node embeddings and decode(z, edge_index, edge_type) returning per-edge scores (see graphnetz.models._adapters.RelationalLinkPredWrapper).
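DistMult scores a triple (h, r, t) as the trilinear product of the head embedding, a per-relation diagonal vector, and the tail embedding: score = Σᵢ z[h][i] · w_r[i] · z[t][i]. A minimal sketch in plain Python; the relation table `rel_emb` and the function name are hypothetical stand-ins for whatever RelationalLinkPredWrapper stores internally.

```python
from typing import List, Sequence


def distmult_scores(
    z: List[List[float]],
    rel_emb: List[List[float]],
    edge_index: Sequence[Sequence[int]],
    edge_type: Sequence[int],
) -> List[float]:
    """DistMult score for each (head, relation, tail) triple."""
    src, dst = edge_index
    return [
        # Elementwise product of head, relation diagonal, and tail, summed.
        sum(a * r * b for a, r, b in zip(z[h], rel_emb[k], z[t]))
        for h, t, k in zip(src, dst, edge_type)
    ]
```

Because the product is symmetric in head and tail, DistMult assigns (h, r, t) and (t, r, h) the same score, so it cannot model antisymmetric relations.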