Source code for graphnetz.datasets.biology

"""Health and biology datasets.

Coverage:

- Molecules and proteins: PyG ``TUDataset`` (MUTAG molecules; PROTEINS and
  ENZYMES protein graphs).
- Long-range peptides: PyG ``LRGBDataset`` (Peptides-func graph
  classification, Peptides-struct graph regression — Dwivedi et al., NeurIPS
  2022 long-range graph benchmark).
- Protein-protein interaction: PyG ``PPI`` (inductive multi-graph).
- Metabolic: Netzschleuder ``celegans_metabolic``.
- Brain connectomes: Netzschleuder ``budapest_connectome``.
- Epidemiology: Netzschleuder ``sp_hospital`` and ``sp_high_school`` contact graphs.
- Open Graph Benchmark (optional ``ogb`` extra): ``ogbg_molhiv`` (~41K
  molecules, binary HIV inhibition) and ``ogbg_molpcba`` (~438K
  molecules, 128 binary bioassay tasks). Both also need the ``chem``
  extra for RDKit featurisation.

Patient-disease-treatment knowledge graphs have no canonical free dataset and
are intentionally omitted.
"""

from typing import Any

from torch_geometric.datasets import PPI, LRGBDataset, TUDataset

from graphnetz.datasets._netz import Netz
from graphnetz.datasets._ogb import load_ogb_graph


def mutag(root: str, **kwargs: Any) -> TUDataset:
    """Load MUTAG: 188 molecules with binary mutagenicity labels.

    Note this is the TU ``"MUTAG"`` dataset, not the larger TU dataset
    named ``"Mutagenicity"``.

    Args:
        root: Directory in which ``TUDataset`` stores the dataset.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``TUDataset`` (e.g. ``transform``, ``pre_transform``,
            ``use_node_attr``). Omitting them reproduces the previous
            behaviour exactly.

    Returns:
        The PyG ``TUDataset`` named ``"MUTAG"``.
    """
    return TUDataset(root=root, name="MUTAG", **kwargs)


def proteins(root: str, **kwargs: Any) -> TUDataset:
    """Load PROTEINS: 1113 protein graphs with binary class labels.

    Args:
        root: Directory in which ``TUDataset`` stores the dataset.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``TUDataset`` (e.g. ``transform``, ``pre_transform``,
            ``use_node_attr``). Omitting them reproduces the previous
            behaviour exactly.

    Returns:
        The PyG ``TUDataset`` named ``"PROTEINS"``.
    """
    return TUDataset(root=root, name="PROTEINS", **kwargs)


def enzymes(root: str, **kwargs: Any) -> TUDataset:
    """Load ENZYMES: 600 protein graphs labelled with one of 6 classes.

    Args:
        root: Directory in which ``TUDataset`` stores the dataset.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``TUDataset`` (e.g. ``transform``, ``pre_transform``,
            ``use_node_attr``). Omitting them reproduces the previous
            behaviour exactly.

    Returns:
        The PyG ``TUDataset`` named ``"ENZYMES"``.
    """
    return TUDataset(root=root, name="ENZYMES", **kwargs)


def ppi(root: str, split: str = "train", **kwargs: Any) -> PPI:
    """Load one split of the protein-protein interaction (PPI) dataset.

    The task is inductive, multi-label node classification spread across
    several graphs.

    Args:
        root: Directory in which ``PPI`` stores the dataset.
        split: Split to load, forwarded to ``PPI`` (PyG accepts
            ``"train"``, ``"val"`` or ``"test"``).
        **kwargs: Extra keyword arguments forwarded unchanged to ``PPI``
            (e.g. ``transform``, ``pre_transform``, ``pre_filter``).
            Omitting them reproduces the previous behaviour exactly.

    Returns:
        The requested PyG ``PPI`` split.
    """
    return PPI(root=root, split=split, **kwargs)


def celegans(root: str) -> Netz:
    """Return the *C. elegans* metabolic network hosted on Netzschleuder.

    The Netzschleuder entry and the single network loaded from it share
    the name ``celegans_metabolic``.
    """
    entry = "celegans_metabolic"
    return Netz(root=root, dataset_name=entry, network_name=entry)


def budapest_connectome(root: str, network_name: str = "100m_avg") -> Netz:
    """Return a network from the Budapest reference connectome (Netzschleuder).

    Args:
        root: Directory forwarded to ``Netz``.
        network_name: Network within the ``budapest_connectome`` entry.
            The default ``"100m_avg"`` is the mean connectivity across
            100 subjects.
    """
    return Netz(
        root=root,
        dataset_name="budapest_connectome",
        network_name=network_name,
    )


def hospital_contacts(root: str, network_name: str = "sp_hospital") -> Netz:
    """Load the Sociopatterns hospital-ward face-to-face contact network.

    Args:
        root: Directory forwarded to ``Netz``.
        network_name: Network within the ``sp_hospital`` Netzschleuder
            entry. The default is the previously hard-coded
            ``"sp_hospital"``, so existing calls are unchanged. The
            parameter mirrors ``budapest_connectome`` so callers can
            select another network from the entry if one exists.

    Returns:
        The ``Netz`` dataset for the selected network.
    """
    return Netz(
        root=root,
        dataset_name="sp_hospital",
        network_name=network_name,
    )


def high_school_contacts(
    root: str, network_name: str = "sp_high_school"
) -> Netz:
    """Load the Sociopatterns high-school contact network from Netzschleuder.

    Args:
        root: Directory forwarded to ``Netz``.
        network_name: Network within the ``sp_high_school`` Netzschleuder
            entry. The default is the previously hard-coded
            ``"sp_high_school"``, so existing calls are unchanged.
            NOTE(review): the upstream entry may publish its networks under
            other names (e.g. per data-collection method). Confirm against
            the Netzschleuder listing; this parameter lets callers choose
            one without editing this module.

    Returns:
        The ``Netz`` dataset for the selected network.
    """
    return Netz(
        root=root,
        dataset_name="sp_high_school",
        network_name=network_name,
    )


def peptides_func(
    root: str, split: str = "train", **kwargs: Any
) -> LRGBDataset:
    """Load Peptides-func: long-range graph classification (10-way multilabel).

    Args:
        root: Directory in which ``LRGBDataset`` stores the dataset.
        split: Split to load, forwarded to ``LRGBDataset`` (PyG accepts
            ``"train"``, ``"val"`` or ``"test"``).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``LRGBDataset`` (e.g. ``transform``, ``pre_transform``).
            Omitting them reproduces the previous behaviour exactly.

    Returns:
        The requested ``"Peptides-func"`` split.
    """
    return LRGBDataset(root=root, name="Peptides-func", split=split, **kwargs)


def peptides_struct(
    root: str, split: str = "train", **kwargs: Any
) -> LRGBDataset:
    """Load Peptides-struct: long-range graph regression (11 structural targets).

    Args:
        root: Directory in which ``LRGBDataset`` stores the dataset.
        split: Split to load, forwarded to ``LRGBDataset`` (PyG accepts
            ``"train"``, ``"val"`` or ``"test"``).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``LRGBDataset`` (e.g. ``transform``, ``pre_transform``).
            Omitting them reproduces the previous behaviour exactly.

    Returns:
        The requested ``"Peptides-struct"`` split.
    """
    return LRGBDataset(
        root=root, name="Peptides-struct", split=split, **kwargs
    )


def ogbg_molhiv(root: str) -> Any:
    """Return OGB MolHIV: ~41K molecules with binary HIV-inhibition labels.

    Loading is delegated to ``load_ogb_graph``. The optional ``ogb`` and
    ``chem`` extras are required.
    """
    ogb_name = "ogbg-molhiv"
    return load_ogb_graph(ogb_name, root)


def ogbg_molpcba(root: str) -> Any:
    """Return OGB MolPCBA: ~438K molecules with 128 binary bioassay labels.

    Loading is delegated to ``load_ogb_graph``. The optional ``ogb`` and
    ``chem`` extras are required.
    """
    ogb_name = "ogbg-molpcba"
    return load_ogb_graph(ogb_name, root)


# Public loaders exported by this module, kept in alphabetical order.
__all__ = [
    "budapest_connectome",
    "celegans",
    "enzymes",
    "high_school_contacts",
    "hospital_contacts",
    "mutag",
    "ogbg_molhiv",
    "ogbg_molpcba",
    "peptides_func",
    "peptides_struct",
    "ppi",
    "proteins",
]