Source code for graphnetz.benchmark

"""Statistically robust benchmarks across a category for one or many models.

The dispatcher trains every compatible (model, task) pair across multiple
seeds and returns a :class:`BenchmarkReport` that exposes mean ± 95 % t-CI,
paired t-tests with Holm-Bonferroni correction, publication-ready LaTeX
tables, and plots.

Custom models are plugged in via the same three paths as before:

1. **Decorator / registry**::

       from graphnetz import register_model


       @register_model(task_type="node_cls")
       class MyGNN(torch.nn.Module):
           def __init__(self, in_channels, hidden_channels, out_channels): ...

2. **Class attribute**::

       class MyGNN(torch.nn.Module):
           task_types = {"node_cls"}

3. **Inline tuple** ``(cls, tasks)`` or ``(cls, tasks, factory)`` in the
   ``models`` mapping::

       run_benchmark("social", {"MyGNN": (MyGNN, "node_cls")})

The default factory calls ``cls(in_channels, hidden_channels, out_channels)``;
DGI-task models receive ``(in_channels, hidden_channels)`` (the third argument
is dropped).
"""

from __future__ import annotations

import importlib.util
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy import stats
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm.auto import tqdm

from graphnetz.datasets import (
    biology,
    combinatorial,
    computing,
    finance,
    infrastructure,
    knowledge,
    physics,
    security,
    social,
    vision,
)
from graphnetz.models import GAT, GCN, GIN, GraphSAGE, GraphTransformer
from graphnetz.plotting import NATURE_COLORS, plot_grouped_bars, save_figure, set_plot_style
from graphnetz.training import (
    train_graph_classification,
    train_graph_regression,
    train_link_prediction,
    train_node_classification,
)

_HAS_OGB = importlib.util.find_spec("ogb") is not None

# DGI is intentionally not a task task_type: it is a self-supervised training
# objective whose "metric" is its own loss, so it cannot serve as a
# held-out evaluation. ``train_dgi`` and the ``DGIWrapper`` adapter remain
# available as utilities for users who want to pre-train an encoder
# unsupervised; the benchmark routes unlabelled graphs through
# ``link_pred`` instead (a real held-out edge split with an AUC metric).
TASK_TYPES: frozenset[str] = frozenset({"node_cls", "graph_cls", "graph_reg", "link_pred"})
_METRIC_KEYS: tuple[str, ...] = (
    "test_acc",
    "test_auc",
    "val_acc",
    "val_auc",
    "val_mae",
)
_LOWER_IS_BETTER: frozenset[str] = frozenset({"val_mae", "train_loss"})


# --------------------------------------------------------------------------- #
# Tasks and model specs
# --------------------------------------------------------------------------- #


[docs] @dataclass(frozen=True) class Task: """A single benchmark task_type: a dataset loader plus its training task.""" name: str task_type: str # ``...`` admits seed-aware loaders ``f(root, *, seed=...)`` alongside # the basic ``f(root)`` shape — the dispatcher inspects the signature # and threads ``seed`` through when present. loader: Callable[..., Any] epochs: int = 30
[docs] @dataclass(frozen=True) class ModelSpec: """How to instantiate a model and which task tasks it supports.""" cls: type task_type: frozenset[str] = field(default_factory=frozenset) factory: Callable[..., torch.nn.Module] | None = None
[docs] def build( self, in_channels: int, hidden_channels: int, out_channels: int, *, task_type: str = "node_cls", ) -> torch.nn.Module: if self.factory is not None: try: return self.factory(in_channels, hidden_channels, out_channels, task_type=task_type) except TypeError: return self.factory(in_channels, hidden_channels, out_channels) if task_type == "dgi": return self.cls(in_channels, hidden_channels) return self.cls(in_channels, hidden_channels, out_channels)
_REGISTRY: dict[type, ModelSpec] = {}
[docs] def register_model( cls: type | None = None, *, task_type: str | Iterable[str], factory: Callable[..., torch.nn.Module] | None = None, ) -> Callable[[type], type] | type: """Register a model with the benchmark dispatcher. Usable as a decorator (``@register_model(task_type="node_cls")``) or as a plain function (``register_model(MyGNN, task_type={"graph_cls", "graph_reg"})``). """ tasks = frozenset({task_type} if isinstance(task_type, str) else task_type) unknown = tasks - TASK_TYPES if unknown: msg = f"Unknown task {sorted(unknown)}; allowed: {sorted(TASK_TYPES)}" raise ValueError(msg) def _register(target: type) -> type: _REGISTRY[target] = ModelSpec(cls=target, task_type=tasks, factory=factory) return target return _register(cls) if cls is not None else _register
def _multi_task_factory(encoder_cls: type) -> Callable[..., torch.nn.Module]: """Adapt a node-level encoder to any of the four tasks. For ``node_cls`` the encoder is built with the dataset's class count as ``out_channels`` and used directly. For ``graph_cls`` and ``graph_reg`` the encoder produces ``hidden_channels`` per node and a :class:`GraphLevelWrapper` adds global mean pooling and a head. For ``dgi`` the encoder is wrapped in :class:`DGIWrapper` so it plugs into the same training loop as :class:`graphnetz.models.DGI`. """ from graphnetz.models._adapters import GraphLevelWrapper, LinkPredWrapper def factory( in_channels: int, hidden_channels: int, out_channels: int, *, task_type: str = "node_cls", ) -> torch.nn.Module: if task_type == "node_cls": return encoder_cls(in_channels, hidden_channels, out_channels) if task_type in ("graph_cls", "graph_reg"): encoder = encoder_cls(in_channels, hidden_channels, hidden_channels) return GraphLevelWrapper(encoder, hidden_channels, out_channels) if task_type == "link_pred": encoder = encoder_cls(in_channels, hidden_channels, hidden_channels) return LinkPredWrapper(encoder) msg = f"Unknown task task_type: {task_type!r}; choices: {sorted(TASK_TYPES)}" raise ValueError(msg) return factory # Pre-register built-ins. Node-level encoders are registered for every # task task via the multi-task factory; GIN keeps its native graph-level # pooling. ``DGI`` is intentionally not registered: it is exposed as a # self-supervised training utility (``train_dgi`` + ``DGIWrapper``) # rather than a benchmark-task model. _ALL_TASKS = frozenset({"node_cls", "graph_cls", "graph_reg", "link_pred"}) register_model(GCN, task_type=_ALL_TASKS, factory=_multi_task_factory(GCN)) register_model(GAT, task_type=_ALL_TASKS, factory=_multi_task_factory(GAT)) register_model(GraphSAGE, task_type=_ALL_TASKS, factory=_multi_task_factory(GraphSAGE)) register_model(GraphTransformer, task_type=_ALL_TASKS, factory=_multi_task_factory(GraphTransformer)) register_model(GIN, task_type={"graph_cls", "graph_reg"}) def _spec_from(value: type | tuple[Any, ...] | ModelSpec) -> ModelSpec: """Resolve a ``models`` dict entry to a :class:`ModelSpec`.""" if isinstance(value, ModelSpec): return value if isinstance(value, tuple): cls = value[0] tasks = value[1] if len(value) >= 2 else None factory = value[2] if len(value) >= 3 else None if tasks is None: base = _spec_from(cls) return ModelSpec(cls=base.cls, task_type=base.task_type, factory=factory or base.factory) ks = frozenset({tasks} if isinstance(tasks, str) else tasks) unknown = ks - TASK_TYPES if unknown: msg = f"Unknown task task_type {sorted(unknown)}; allowed: {sorted(TASK_TYPES)}" raise ValueError(msg) return ModelSpec(cls=cls, task_type=ks, factory=factory) if value in _REGISTRY: return _REGISTRY[value] if hasattr(value, "task_types"): return ModelSpec(cls=value, task_type=frozenset(value.task_types)) if hasattr(value, "task"): return ModelSpec(cls=value, task_type=frozenset({value.task})) return ModelSpec(cls=value, task_type=frozenset()) # --------------------------------------------------------------------------- # # Curated benchmark tasks per category # --------------------------------------------------------------------------- # BENCHMARK_TASKS: dict[str, dict[str, list[Task]]] = { "combinatorial": { "link_pred": [ Task( "random_tsp", "link_pred", lambda root, seed=0: combinatorial.random_tsp(root, num_graphs=1, num_nodes=200, k=4, seed=seed), epochs=80, ), Task( "random_coloring", "link_pred", lambda root, seed=0: combinatorial.random_coloring( root, num_graphs=1, num_nodes=200, edge_prob=0.1, seed=seed ), epochs=80, ), ], }, "biology": { "graph_cls": [ Task("mutag", "graph_cls", biology.mutag, epochs=40), Task("proteins", "graph_cls", biology.proteins, epochs=20), ], "link_pred": [ Task("celegans", "link_pred", biology.celegans, epochs=80), ], }, "social": { "node_cls": [ Task("cora", "node_cls", social.cora, epochs=100), Task("citeseer", "node_cls", social.citeseer, epochs=100), Task("pubmed", "node_cls", social.pubmed, epochs=100), Task("roman_empire", "node_cls", social.roman_empire, epochs=80), Task("minesweeper", "node_cls", social.minesweeper, epochs=80), ], "link_pred": [ Task("cora_link_pred", "link_pred", social.cora, epochs=80), Task("citeseer_link_pred", "link_pred", social.citeseer, epochs=80), ], }, "knowledge": { "link_pred": [ Task("fb15k_237", "link_pred", knowledge.fb15k_237, epochs=20), Task("wordnet18rr", "link_pred", knowledge.wordnet18rr, epochs=20), ], }, "infrastructure": { "link_pred": [ Task("power_grid", "link_pred", infrastructure.power_grid, epochs=80), Task("euroroad", "link_pred", infrastructure.euroroad, epochs=80), ], }, "finance": { "link_pred": [ Task("product_space", "link_pred", finance.product_space, epochs=80), Task("board_directors", "link_pred", finance.board_directors, epochs=40), ], }, "computing": { "link_pred": [ Task("internet_as", "link_pred", lambda root: computing.internet_as(root), epochs=40), Task("topology", "link_pred", computing.topology, epochs=10), ], }, "vision": { "graph_cls": [ Task( "mnist_superpixels", "graph_cls", lambda root: vision.mnist_superpixels(root)[:1500], epochs=4, ), ], }, "physics": { "graph_reg": [ Task( "zinc", "graph_reg", lambda root: ( physics.zinc(root, subset=True, split="train"), physics.zinc(root, subset=True, split="val"), ), epochs=10, ), ], "link_pred": [ Task( "ising_lattice", "link_pred", lambda root, seed=0: physics.ising_lattice(root, num_graphs=1, side=20, seed=seed), epochs=60, ), ], }, "security": { "link_pred": [ Task("terrorists_911", "link_pred", security.terrorists_911, epochs=120), ], }, } if _HAS_OGB: # OGB tasks live in the domain modules; we only register them as # benchmark tasks when the ``ogb`` extra is importable so the # curated catalogue stays runnable without it. BENCHMARK_TASKS["social"]["node_cls"].append( Task("ogbn_arxiv", "node_cls", social.ogbn_arxiv, epochs=50), ) BENCHMARK_TASKS["social"]["link_pred"].append( Task("ogbl_collab", "link_pred", social.ogbl_collab, epochs=20), ) BENCHMARK_TASKS["finance"].setdefault("node_cls", []).append( Task("ogbn_products", "node_cls", finance.ogbn_products, epochs=20), ) BENCHMARK_TASKS["biology"]["graph_cls"].extend( [ Task("ogbg_molhiv", "graph_cls", biology.ogbg_molhiv, epochs=20), Task("ogbg_molpcba", "graph_cls", biology.ogbg_molpcba, epochs=20), ] )
[docs] def iter_benchmark_tasks( category: str | None = None, task_type: str | None = None, ) -> list[Task]: """Flatten ``BENCHMARK_TASKS`` to a list, optionally filtered by category/task. Examples -------- >>> [ ... t.name ... for t in iter_benchmark_tasks(category="biology", task_type="graph_cls") ... ] ['mutag', 'proteins'] """ cats = [category] if category is not None else list(BENCHMARK_TASKS) out: list[Task] = [] for c in cats: per_cat = BENCHMARK_TASKS.get(c, {}) tasks = [task_type] if task_type is not None else list(per_cat) for k in tasks: out.extend(per_cat.get(k, [])) return out
# --------------------------------------------------------------------------- # # Custom-dataset helpers # --------------------------------------------------------------------------- #
[docs] def task_from_dataset( name: str, task_type: str, dataset: Any, *, epochs: int = 30, ) -> Task: """Wrap an already-loaded dataset as a :class:`Task`. The dataset must satisfy the conventions for ``task``: a PyG dataset or any object exposing ``ds[0]`` plus the relevant attributes (``num_features`` / ``num_classes`` / ``num_relations``). The benchmark dispatcher caches the dataset, so the same instance is reused across seeds without reloading. """ if task_type not in TASK_TYPES: msg = f"Unknown task {task_type!r}; choices: {sorted(TASK_TYPES)}" raise ValueError(msg) return Task(name=name, task_type=task_type, loader=lambda _root: dataset, epochs=epochs)
[docs] def register_task(category: str, task_type: Task) -> None: """Register ``task`` under ``category`` in :data:`BENCHMARK_TASKS`. The task becomes visible to ``run_benchmark(category)`` and to :func:`iter_benchmark_tasks`. Use :func:`unregister_task` to remove it (e.g. in ``tearDown`` of a test). """ if not isinstance(task_type, Task): msg = f"task must be a Task, got {type(task_type).__name__}" raise TypeError(msg) if task_type.task_type not in TASK_TYPES: msg = f"Task {task_type.name!r} has unknown task {task_type.task_type!r}; choices: {sorted(TASK_TYPES)}" raise ValueError(msg) per_cat = BENCHMARK_TASKS.setdefault(category, {}) per = per_cat.setdefault(task_type.task_type, []) if any(t.name == task_type.name for t in per): msg = f"Task {task_type.name!r} already registered in category {category!r}/{task_type.task_type!r}" raise ValueError(msg) per.append(task_type)
[docs] def unregister_task(category: str, name: str) -> Task | None: """Remove a previously registered task; returns it, or ``None`` if absent.""" per_cat = BENCHMARK_TASKS.get(category, {}) for task_tasks in per_cat.values(): for i, t in enumerate(task_tasks): if t.name == name: return task_tasks.pop(i) return None
# --------------------------------------------------------------------------- # # Statistical helpers # --------------------------------------------------------------------------- # def _ci_half_width(values: np.ndarray, ci: float = 0.95) -> float: """Half-width of a t-distribution confidence interval for the mean.""" n = values.size if n < 2: return 0.0 sem = stats.sem(values) return float(sem * stats.t.ppf((1 + ci) / 2, n - 1)) def _bootstrap_ci_half_width( values: np.ndarray, ci: float = 0.95, n_resamples: int = 10000, random_state: int = 0, ) -> float: """Half-width of a percentile-bootstrap CI for the mean. Robust for non-Gaussian metrics (e.g. Hits@K, MRR, AUC) where the Student's-t assumption is poor. Returns ``(hi - lo) / 2`` -- the half-width of a symmetric envelope with the same total width as the percentile interval, so callers reporting ``mean ± half`` recover the bootstrap interval's spread without inflating asymmetric tails. """ arr = np.asarray(values, dtype=float).ravel() n = arr.size if n < 2: return 0.0 rng = np.random.default_rng(random_state) idx = rng.integers(0, n, size=(n_resamples, n)) means = arr[idx].mean(axis=1) alpha = (1.0 - ci) / 2.0 lo, hi = np.quantile(means, [alpha, 1.0 - alpha]) return float((hi - lo) / 2.0) def _resolve_ci_half( values: np.ndarray, ci: float, method: str, n_resamples: int, random_state: int, ) -> float: if method == "t": return _ci_half_width(values, ci) if method == "bootstrap": return _bootstrap_ci_half_width(values, ci, n_resamples, random_state) msg = f"Unknown CI method: {method!r}; choices: 't', 'bootstrap'" raise ValueError(msg) def _paired_pvalue(a: np.ndarray, b: np.ndarray, method: str) -> float: """p-value of a paired test between two seed-aligned metric arrays. ``method="t"`` is the paired Student's t-test (parametric). ``method= "wilcoxon"`` is the Wilcoxon signed-rank test on the paired differences -- recommended at small seed counts where the paired t-test's normality assumption is most fragile (Benavoli et al., JMLR 2016). """ if a.size < 2 or b.size < 2 or a.size != b.size: return float("nan") if method == "t": return float(stats.ttest_rel(a, b).pvalue) if method == "wilcoxon": diffs = np.asarray(a, dtype=float) - np.asarray(b, dtype=float) # All-zero paired differences -> the signed-rank statistic has no # ranks to assign; return NaN so the row is reported as undefined # rather than as an artificial 1.0. if not np.any(diffs != 0): return float("nan") try: return float(stats.wilcoxon(diffs, zero_method="wilcox").pvalue) except ValueError: return float("nan") msg = f"Unknown pairwise method: {method!r}; choices: 't', 'wilcoxon'" raise ValueError(msg) def _holm_correction(p_values: np.ndarray) -> np.ndarray: """Holm step-down adjusted p-values (max-monotone). NaN inputs (e.g. tests that were undefined for that pair) are excluded from the rank table and propagated as NaN in the output; they are *not* counted toward the family size, so the remaining valid tests retain their proper power. """ p = np.asarray(p_values, dtype=float) n = p.size if n == 0: return p valid = ~np.isnan(p) n_valid = int(valid.sum()) adjusted = np.full(n, np.nan, dtype=float) if n_valid == 0: return adjusted valid_idx = np.where(valid)[0] p_valid = p[valid_idx] order = np.argsort(p_valid) running = 0.0 out_valid = np.empty(n_valid, dtype=float) for rank, idx in enumerate(order): adj = float(min(p_valid[idx] * (n_valid - rank), 1.0)) running = max(running, adj) out_valid[idx] = running adjusted[valid_idx] = out_valid return adjusted def _auto_metric_key(history: Mapping[str, Any]) -> str: for key in _METRIC_KEYS: if key in history: return key return next(iter(history)) def _final_metric(history: Mapping[str, list[float]]) -> tuple[str, float]: key = _auto_metric_key(history) return key, history[key][-1] # --------------------------------------------------------------------------- # # Task runner # --------------------------------------------------------------------------- # def _run_task( task_type: Task, ds: Any, spec: ModelSpec, hidden: int, epochs: int, verbose: bool = False, device: torch.device | str | None = "auto", ) -> dict[str, list[float]]: if task_type.task_type == "node_cls": data = ds[0] model = spec.build(ds.num_features, hidden, ds.num_classes, task_type="node_cls") return train_node_classification(model, data, epochs=epochs, verbose=verbose, device=device) if task_type.task_type == "graph_cls": shuffled = ds.shuffle() split = int(0.8 * len(shuffled)) train_loader = DataLoader(shuffled[:split], batch_size=32, shuffle=True) val_loader = DataLoader(shuffled[split:], batch_size=32) model = spec.build(shuffled.num_features, hidden, shuffled.num_classes, task_type="graph_cls") return train_graph_classification( model, train_loader, val_loader, epochs=epochs, verbose=verbose, device=device ) if task_type.task_type == "graph_reg": # Loader may return either a single dataset (used for both train and # held-out -- e.g. synthetic tasks with no canonical split) or a # ``(train_ds, val_ds)`` tuple (real benchmarks like ZINC). if isinstance(ds, tuple): train_ds, val_ds = ds else: train_ds = val_ds = ds train_loader = DataLoader(train_ds, batch_size=64, shuffle=True) val_loader = DataLoader(val_ds, batch_size=64) inner = spec.build(hidden, hidden, 1, task_type="graph_reg") class _AtomEmbed(torch.nn.Module): def __init__(self, num_atoms: int = 32) -> None: super().__init__() self.embed = torch.nn.Embedding(num_atoms, hidden) self.inner = inner def forward(self, batch: Any) -> torch.Tensor: # Embed only integer atom-type ids (e.g. ZINC); pass-through # any float feature matrix unchanged so we never silently # truncate continuous features via .long(). if not batch.x.dtype.is_floating_point: batch = batch.clone() batch.x = self.embed(batch.x.view(-1).long()) return self.inner(batch) return train_graph_regression( _AtomEmbed(), train_loader, val_loader, epochs=epochs, verbose=verbose, device=device ) if task_type.task_type == "link_pred": import math from torch_geometric.transforms import RandomLinkSplit from torch_geometric.utils import degree data = ds[0] def _fabricate_log_degree_features(d: Data, edge_index: torch.Tensor) -> Data: """Build a 3-D log-degree feature from `edge_index` only. Used as the fallback when a loader ships no node features. We keep the source of degree restricted to the caller-supplied `edge_index` so val/test edges never bleed into the feature matrix at training time. """ n = int(d.num_nodes) deg = degree(edge_index[0], num_nodes=n, dtype=torch.float) log_deg = torch.log1p(deg) / math.log(max(n, 2)) ones = torch.ones(n) out = d.clone() out.x = torch.stack([log_deg, log_deg.pow(2), ones], dim=1) return out # Relational link prediction (knowledge graphs with edge_type) is # detected on the *raw* data because PyG's RelLinkPredDataset # already restricts edge_index to training edges -- we only # fabricate features when missing, using train_edge_index. if hasattr(data, "edge_type") and hasattr(data, "train_edge_index"): from graphnetz.models._adapters import RelationalLinkPredWrapper from graphnetz.training import train_relational_link_prediction if getattr(data, "x", None) is None: data = _fabricate_log_degree_features(data, data.train_edge_index) num_relations = ds.num_relations if hasattr(ds, "num_relations") else int(data.edge_type.max()) + 1 built = spec.build(data.num_features, hidden, hidden, task_type="link_pred") # spec.build returns a LinkPredWrapper for task_type="link_pred"; unwrap it # so RelationalLinkPredWrapper can drive the bare encoder directly # (otherwise its forward expects data.edge_label_index). from typing import cast encoder = cast( torch.nn.Module, built.encoder if hasattr(built, "encoder") else built, ) model = RelationalLinkPredWrapper(encoder, hidden, num_relations) # Create separate Data objects for train/val/test splits train_split = Data( x=data.x, edge_index=data.train_edge_index, edge_type=data.train_edge_type, num_nodes=data.num_nodes ) val_split = Data( x=data.x, edge_index=data.valid_edge_index, edge_type=data.valid_edge_type, num_nodes=data.num_nodes ) test_split = Data( x=data.x, edge_index=data.test_edge_index, edge_type=data.test_edge_type, num_nodes=data.num_nodes ) return train_relational_link_prediction( model, train_split, val_split, test_split, epochs=epochs, verbose=verbose, device=device, ) # Detect graph direction from the data itself instead of forcing # ``is_undirected=True`` -- on a directed graph the latter silently # de-duplicates reciprocal edges and halves the supervision signal. is_undirected = not bool(data.is_directed()) transform = RandomLinkSplit( num_val=0.05, num_test=0.10, is_undirected=is_undirected, add_negative_train_samples=True, neg_sampling_ratio=1.0, ) train_data, val_data, test_data = transform(data) # Fabricate features *after* the split so val/test edges never # leak into the node features the encoder consumes. Use only the # training message-passing edges (edge_index, not edge_label_index) # for the degree statistic. if getattr(train_data, "x", None) is None: train_data = _fabricate_log_degree_features(train_data, train_data.edge_index) val_data = val_data.clone() val_data.x = train_data.x test_data = test_data.clone() test_data.x = train_data.x # ``spec.build(task_type="link_pred")`` returns a LinkPredWrapper, which # satisfies the ``_LinkPredLike`` protocol of the trainer; mypy # only sees the declared ``Module`` return so we narrow here. from typing import cast as _cast from graphnetz.training import _LinkPredLike lp_model = _cast(_LinkPredLike, spec.build(train_data.num_features, hidden, hidden, task_type="link_pred")) return train_link_prediction( lp_model, train_data, val_data, test_data, epochs=epochs, verbose=verbose, device=device ) msg = f"Unknown task task_type: {task_type.task_type}" raise ValueError(msg) # --------------------------------------------------------------------------- # # Benchmark report # --------------------------------------------------------------------------- #
[docs] @dataclass class BenchmarkReport: """Structured outcome of a multi-seed benchmark run. ``histories[task][model]`` is a list with one history dict per seed (in seed order). The report is also a read-only mapping ``task -> {model: history_seed_0}`` for backward compatibility with single-seed callers. """ seeds: tuple[int, ...] histories: dict[str, dict[str, list[dict[str, list[float]]]]] config: dict[str, Any] = field(default_factory=dict) ci_method: str = "t" bootstrap_n: int = 10000 bootstrap_seed: int = 0 pairwise_method: str = "t" def _ci_half( self, values: np.ndarray, ci: float, method: str | None = None, ) -> float: return _resolve_ci_half( values, ci, method or self.ci_method, self.bootstrap_n, self.bootstrap_seed, ) # ----- Pickle compatibility --------------------------------------------- def __setstate__(self, state: dict[str, Any]) -> None: """Restore from pickle, backfilling fields added since serialisation. Older :class:`BenchmarkReport` pickles predate the ``ci_method`` / ``bootstrap_*`` / ``pairwise_method`` fields. ``__setstate__`` ensures they load cleanly with sensible defaults so the experiment cache (``_cache_*.pkl``) survives library upgrades. """ self.__dict__.update(state) self.__dict__.setdefault("ci_method", "t") self.__dict__.setdefault("bootstrap_n", 10000) self.__dict__.setdefault("bootstrap_seed", 0) self.__dict__.setdefault("pairwise_method", "t") self.__dict__.setdefault("config", {}) # ----- Mapping protocol (backward compat with the legacy dict shape) ----- def __iter__(self): return iter(self.histories) def __len__(self) -> int: return len(self.histories) def __getitem__(self, task_type: str) -> dict[str, dict[str, list[float]]]: per_task = self.histories[task_type] return {model: per_task[model][0] for model in per_task}
[docs] def items(self): for task in self.histories: yield task, self[task]
[docs] def keys(self): return self.histories.keys()
[docs] def values(self): return [self[task] for task in self.histories]
# ----- Statistics --------------------------------------------------------
[docs] def final_metrics(self, key: str | None = None) -> dict[str, dict[str, list[float]]]: """Final metric value per (task, model, seed).""" out: dict[str, dict[str, list[float]]] = {} for task, per_task in self.histories.items(): out[task] = {} for model, seed_histories in per_task.items(): vals: list[float] = [] for h in seed_histories: k = key or _auto_metric_key(h) vals.append(float(h[k][-1])) out[task][model] = vals return out
[docs] def metric_name(self) -> str: for per_task in self.histories.values(): for seed_histories in per_task.values(): if seed_histories: return _auto_metric_key(seed_histories[0]) return "metric"
[docs] def summary(self, ci: float = 0.95, method: str | None = None) -> pd.DataFrame: """Per-(task, model) mean, std, sem, CI half-width and bounds. ``method`` overrides ``self.ci_method`` for this call only; choose ``"t"`` for Student's-t intervals (default) or ``"bootstrap"`` for percentile-bootstrap intervals (better for non-Gaussian metrics such as Hits@K, MRR, or AUC). """ rows = [] for task, per_task in self.final_metrics().items(): for model, values in per_task.items(): arr = np.asarray(values, dtype=float) mean = float(arr.mean()) std = float(arr.std(ddof=1)) if arr.size > 1 else 0.0 sem = float(stats.sem(arr)) if arr.size > 1 else 0.0 half = self._ci_half(arr, ci, method=method) rows.append( { "task": task, "model": model, "n_seeds": arr.size, "mean": mean, "std": std, "sem": sem, "ci_low": mean - half, "ci_high": mean + half, } ) return pd.DataFrame(rows).set_index(["task", "model"]).sort_index()
[docs] def pairwise(self, alpha: float = 0.05, method: str | None = None) -> pd.DataFrame: """Paired pairwise tests between models per task with Holm adjustment. ``method`` overrides ``self.pairwise_method`` for this call only: - ``"t"`` (default) -- paired Student's t-test on per-seed final metrics. - ``"wilcoxon"`` -- non-parametric Wilcoxon signed-rank test on the paired differences. Recommended at small seed counts where the paired t-test's normality assumption is most fragile; see Benavoli et al., *JMLR* 17(5):1-36, 2016. """ finals = self.final_metrics() test = method or self.pairwise_method rows = [] for task, per_task in finals.items(): models = sorted(per_task) pairs: list[tuple[str, str, float, float]] = [] ps: list[float] = [] for i, model_a in enumerate(models): for model_b in models[i + 1 :]: a = np.asarray(per_task[model_a], dtype=float) b = np.asarray(per_task[model_b], dtype=float) p = _paired_pvalue(a, b, test) pairs.append((model_a, model_b, float(a.mean() - b.mean()), p)) ps.append(p) adj = _holm_correction(np.asarray(ps, dtype=float)) for (model_a, model_b, diff, p_raw), p_holm in zip(pairs, adj, strict=False): rows.append( { "task": task, "model_a": model_a, "model_b": model_b, "mean_diff": diff, "p_raw": p_raw, "p_holm": p_holm, "significant": (not np.isnan(p_holm)) and p_holm < alpha, } ) return pd.DataFrame(rows)
[docs] def friedman(self, alpha: float = 0.05) -> dict[str, float | int | bool]: r"""Friedman omnibus test on per-task ranks of seed-mean metrics. Returns a dict with the statistic ``chi2``, the asymptotic $\chi^2_{k-1}$ p-value, the rejection flag at ``alpha``, and the $(k, N)$ shape used. The Nemenyi post-hoc surfaced in :meth:`plot_critical_difference` should only be interpreted when ``rejected`` is true (Demšar, 2006). Only models present in every task are included; per-task ranks use the metric direction (lower-is-better for ``val_mae`` and ``train_loss``). """ finals = self.final_metrics() if not finals: return {"chi2": float("nan"), "p_value": float("nan"), "k": 0, "n": 0, "rejected": False} common: set[str] = set.intersection(*[set(per.keys()) for per in finals.values()]) if len(common) < 2 or len(finals) < 2: return { "chi2": float("nan"), "p_value": float("nan"), "k": len(common), "n": len(finals), "rejected": False, } models = sorted(common) tasks = sorted(finals) means = np.array([[float(np.mean(finals[t][m])) for m in models] for t in tasks]) rows: list[np.ndarray] = [] for i, task in enumerate(tasks): sample = next(iter(self.histories[task].values()))[0] sign = 1.0 if _auto_metric_key(sample) in _LOWER_IS_BETTER else -1.0 rows.append(stats.rankdata(sign * means[i], method="average")) ranks = np.array(rows) k = len(models) n = len(tasks) avg = ranks.mean(axis=0) chi2 = (12.0 * n) / (k * (k + 1)) * (float(np.sum(avg**2)) - k * (k + 1) ** 2 / 4.0) p = float(stats.chi2.sf(chi2, df=k - 1)) return {"chi2": float(chi2), "p_value": p, "k": k, "n": n, "rejected": bool(p < alpha)}
# ----- Reporting helpers ------------------------------------------------- def _best_per_task(self) -> dict[str, str]: finals = self.final_metrics() metric = self.metric_name() lower_is_better = metric in _LOWER_IS_BETTER best: dict[str, str] = {} for task, per_task in finals.items(): scored = [(model, float(np.mean(values))) for model, values in per_task.items()] if lower_is_better: best[task] = min(scored, key=lambda x: x[1])[0] else: best[task] = max(scored, key=lambda x: x[1])[0] return best
[docs] def to_latex( self, path: str | Path, *, ci: float = 0.95, bold_best: bool = True, pretty_tasks: Mapping[str, str] | None = None, caption: str | None = None, label: str | None = None, method: str | None = None, ) -> Path: """Booktabs LaTeX table of mean ± CI half-width with bold-best per task. ``method`` overrides ``self.ci_method`` (``"t"`` or ``"bootstrap"``). """ finals = self.final_metrics() tasks = sorted(finals) models = sorted({m for per in finals.values() for m in per}) best = self._best_per_task() if bold_best else {} pretty = dict(pretty_tasks or {}) lines: list[str] = [] if caption is not None or label is not None: lines.extend([r"\begin{table}[t]", r" \centering"]) if caption is not None: lines.append(rf" \caption{{{caption}}}") if label is not None: lines.append(rf" \label{{{label}}}") lines.append(r"\begin{tabular}{l" + "c" * len(tasks) + "}") lines.append(r"\toprule") header = "Model & " + " & ".join(pretty.get(t, t) for t in tasks) + r" \\" lines.append(header) lines.append(r"\midrule") for model in models: cells = [] for task in tasks: values = np.asarray(finals[task].get(model, []), dtype=float) if values.size == 0: cells.append("--") continue mean = float(values.mean()) half = self._ci_half(values, ci, method=method) if bold_best and best.get(task) == model: cell = rf"$\mathbf{{{mean:.3f} \pm {half:.3f}}}$" else: cell = rf"${mean:.3f} \pm {half:.3f}$" cells.append(cell) lines.append(f"{model} & " + " & ".join(cells) + r" \\") lines.append(r"\bottomrule") lines.append(r"\end{tabular}") if caption is not None or label is not None: lines.append(r"\end{table}") out = Path(path) out.parent.mkdir(parents=True, exist_ok=True) out.write_text("\n".join(lines) + "\n") return out
[docs] def pairwise_to_latex( self, path: str | Path, *, alpha: float = 0.05, caption: str | None = None, label: str | None = None, method: str | None = None, ) -> Path: """LaTeX booktabs table of pairwise Holm-adjusted p-values. ``method`` overrides ``self.pairwise_method`` (``"t"`` or ``"wilcoxon"``) for this call only. """ df = self.pairwise(alpha=alpha, method=method) lines: list[str] = [] if caption is not None or label is not None: lines.extend([r"\begin{table}[t]", r" \centering"]) if caption is not None: lines.append(rf" \caption{{{caption}}}") if label is not None: lines.append(rf" \label{{{label}}}") lines.append(r"\begin{tabular}{llcccl}") lines.append(r"\toprule") lines.append(r"Task & Comparison & $\Delta\mu$ & $p_{\text{raw}}$ & $p_{\text{Holm}}$ & Sig. \\") lines.append(r"\midrule") for _, row in df.iterrows(): sig = r"\textbf{*}" if row["significant"] else "" p_raw = "n/a" if pd.isna(row["p_raw"]) else f"{row['p_raw']:.3g}" p_holm = "n/a" if pd.isna(row["p_holm"]) else f"{row['p_holm']:.3g}" lines.append( f"{row['task']} & {row['model_a']} vs.\\ {row['model_b']} & " f"${row['mean_diff']:+.3f}$ & {p_raw} & {p_holm} & {sig} \\\\" ) lines.append(r"\bottomrule") lines.append(r"\end{tabular}") if caption is not None or label is not None: lines.append(r"\end{table}") out = Path(path) out.parent.mkdir(parents=True, exist_ok=True) out.write_text("\n".join(lines) + "\n") return out
# ----- Plotting ----------------------------------------------------------
[docs] def plot( self, ax: plt.Axes | None = None, *, ci: float = 0.95, ylabel: str | None = None, title: str | None = None, annotate: bool = True, pretty_tasks: Mapping[str, str] | None = None, ) -> tuple[plt.Figure, plt.Axes]: """Grouped bar chart of mean ± CI half-width across seeds.""" finals = self.final_metrics() pretty = dict(pretty_tasks or {}) values: dict[str, dict[str, float]] = {} errors: dict[str, dict[str, float]] = {} for task, per_task in finals.items(): label = pretty.get(task, task) values[label] = {} errors[label] = {} for model, vals in per_task.items(): arr = np.asarray(vals, dtype=float) values[label][model] = float(arr.mean()) errors[label][model] = self._ci_half(arr, ci) return plot_grouped_bars( values, errors=errors, ax=ax, title=title, ylabel=ylabel or self.metric_name(), annotate=annotate, )
[docs] def plot_forest( self, ax: plt.Axes | None = None, *, ci: float = 0.95, pretty_tasks: Mapping[str, str] | None = None, xlabel: str | None = None, height_per_task: float = 0.42, sort_within: bool = False, band: bool = True, ) -> tuple[plt.Figure, plt.Axes]: """Forest plot, one row per task with models jittered within the row. Height scales with the number of *tasks* only -- adding more models widens the within-row jitter rather than adding new rows -- so the figure stays compact for many models. ``sort_within=True`` orders the jittered positions per-task so the best mean lands at the top of the row (helps spot leaders when there are many models). Each model keeps a stable colour across tasks. ``band=True`` shades alternating task rows (banded reading aid). """ from graphnetz.plotting import COLUMN_INCHES set_plot_style() finals = self.final_metrics() tasks = sorted(finals) models = sorted({m for per in finals.values() for m in per}) pretty = dict(pretty_tasks or {}) n_tasks = len(tasks) n_models = len(models) metric = self.metric_name() lower_is_better = metric in _LOWER_IS_BETTER if ax is None: height = max(1.6, height_per_task * n_tasks + 1.0) fig, ax = plt.subplots(figsize=(COLUMN_INCHES["single"] * 1.05, height)) else: fig = ax.figure # type: ignore[assignment] jitter_span = 0.7 slot_positions = ( np.linspace(-jitter_span / 2, jitter_span / 2, n_models) if n_models > 1 else np.zeros(n_models) ) # Precompute per-task offsets (mapping model_index -> within-row offset). per_task_offset: dict[str, dict[str, float]] = {} for task in tasks: present = [m for m in models if m in finals[task]] if sort_within and len(present) > 1: means = np.array([float(np.mean(finals[task][m])) for m in present]) order = np.argsort(means if lower_is_better else -means) ordered = [present[i] for i in order] else: ordered = present row_offsets = ( np.linspace(-jitter_span / 2, jitter_span / 2, len(ordered)) if len(ordered) > 1 else np.zeros(len(ordered)) ) per_task_offset[task] = dict(zip(ordered, row_offsets, strict=False)) if band: for i in range(n_tasks): if i % 2 == 0: ax.axhspan( i - 0.5, i + 0.5, facecolor="0.96", edgecolor="none", zorder=0, ) for j, model in enumerate(models): xs: list[float] = [] ys: list[float] = [] errs: list[float] = [] for i, task in enumerate(tasks): if model not in finals[task]: continue arr = np.asarray(finals[task][model], dtype=float) xs.append(float(arr.mean())) offset = per_task_offset[task].get(model, slot_positions[j]) ys.append(i + offset) errs.append(self._ci_half(arr, ci)) if xs: color = NATURE_COLORS[j % len(NATURE_COLORS)] ax.errorbar( xs, ys, xerr=[errs, errs], fmt="o", color=color, ecolor=color, elinewidth=1.0, capsize=2.0, markersize=3.5, label=model, zorder=3, ) for i in range(n_tasks - 1): ax.axhline(i + 0.5, color="0.85", linewidth=0.3, zorder=1) ax.set_yticks(range(n_tasks)) ax.set_yticklabels([pretty.get(t, t) for t in tasks]) ax.set_ylim(n_tasks - 0.5, -0.5) ax.set_xlabel(xlabel or metric) ax.set_axisbelow(True) ax.xaxis.grid(True, linewidth=0.3, alpha=0.4, zorder=1) ax.legend( loc="lower center", bbox_to_anchor=(0.5, 1.02), ncol=min(n_models, 4), frameon=False, handlelength=1.2, handletextpad=0.4, columnspacing=1.0, ) fig.tight_layout() return fig, ax
[docs] def plot_pairwise( self, ax: plt.Axes | None = None, *, ci: float = 0.95, alpha: float = 0.05, pretty_tasks: Mapping[str, str] | None = None, layout: str = "matrix", max_cols: int = 3, method: str | None = None, ) -> tuple[plt.Figure, Any]: """Pairwise comparison plot, with two layouts that scale differently. ``layout="matrix"`` (default) -- one significance heatmap per task, with the lower triangle holding $-\\log_{10}(p_{\\text{Holm}})$ and the upper triangle holding the signed mean difference. Scales to many models and many tasks (panels arranged in a grid with at most ``max_cols`` columns). ``layout="list"`` -- one row per pairwise comparison with CI whiskers and a significance marker. Best for small numbers of comparisons. ``method`` overrides ``self.pairwise_method`` (``"t"`` or ``"wilcoxon"``) for this call only. """ if layout == "list": return self._plot_pairwise_list(ax=ax, ci=ci, alpha=alpha, pretty_tasks=pretty_tasks, method=method) if layout == "matrix": return self._plot_pairwise_matrix( ci=ci, alpha=alpha, pretty_tasks=pretty_tasks, max_cols=max_cols, method=method ) msg = f"Unknown pairwise layout: {layout!r}; choices: 'matrix', 'list'" raise ValueError(msg)
def _plot_pairwise_matrix( self, *, ci: float = 0.95, alpha: float = 0.05, pretty_tasks: Mapping[str, str] | None = None, max_cols: int = 3, method: str | None = None, ) -> tuple[plt.Figure, np.ndarray]: from matplotlib.colors import TwoSlopeNorm from graphnetz.plotting import COLUMN_INCHES set_plot_style() finals = self.final_metrics() df = self.pairwise(alpha=alpha, method=method) tasks = sorted(finals) pretty = dict(pretty_tasks or {}) n_tasks = len(tasks) # Per-task model lists (intersection used for matrix axes). per_task_models = {t: sorted(finals[t]) for t in tasks} max_models = max((len(per_task_models[t]) for t in tasks), default=0) if max_models < 2: fig, ax = plt.subplots() ax.text(0.5, 0.5, "fewer than two models per task", ha="center", va="center", transform=ax.transAxes) ax.axis("off") return fig, np.array([[ax]]) cols = max(1, min(max_cols, n_tasks)) rows = (n_tasks + cols - 1) // cols fig_w = COLUMN_INCHES["double"] if cols > 1 else COLUMN_INCHES["single"] cell = max(0.42, 1.4 / max_models) fig_h = (cell * max_models + 1.4) * rows fig, axes_obj = plt.subplots(rows, cols, figsize=(fig_w, fig_h), squeeze=False) diff_max = max(1e-9, df["mean_diff"].abs().max() if not df.empty else 1.0) norm = TwoSlopeNorm(vmin=-diff_max, vcenter=0, vmax=diff_max) for k, task in enumerate(tasks): r, c = divmod(k, cols) ax = axes_obj[r, c] models_t = per_task_models[task] n = len(models_t) mat = np.full((n, n), np.nan) # lower: -log10(p), upper: mean diff sub = df[df["task"] == task] if not df.empty else df for _, row in sub.iterrows(): ia = models_t.index(row["model_a"]) ib = models_t.index(row["model_b"]) if ia == ib: continue lo, hi = (ia, ib) if ia < ib else (ib, ia) p_holm = row["p_holm"] if not np.isnan(p_holm): mat[hi, lo] = -np.log10(max(p_holm, 1e-12)) mat[lo, hi] = row["mean_diff"] if ia < ib else -row["mean_diff"] mask_lower = np.tri(n, n, -1, dtype=bool) mask_upper = mask_lower.T # Two passes: lower triangle (significance), upper triangle (effect). lower = np.where(mask_lower, mat, np.nan) upper = np.where(mask_upper, mat, np.nan) ax.imshow(lower, cmap="Greys", vmin=0.0, vmax=3.0, aspect="equal") ax.imshow(upper, cmap="RdBu_r", norm=norm, aspect="equal") # Annotate cells. for i in range(n): for j in range(n): if i == j: continue if mask_lower[i, j]: # significance cell: show p_holm sub_match = sub[ ((sub["model_a"] == models_t[j]) & (sub["model_b"] == models_t[i])) | ((sub["model_a"] == models_t[i]) & (sub["model_b"] == models_t[j])) ] if sub_match.empty: continue p = float(sub_match["p_holm"].iloc[0]) text = "n/a" if np.isnan(p) else f"{p:.2g}" is_sig = (not np.isnan(p)) and p < alpha if is_sig: text += "*" color = "white" if (not np.isnan(p) and -np.log10(max(p, 1e-12)) > 1.5) else "black" weight = "bold" if is_sig else "normal" ax.text(j, i, text, ha="center", va="center", fontsize=6, color=color, fontweight=weight) elif mask_upper[i, j]: d = mat[i, j] if np.isnan(d): continue color = "white" if abs(d) > 0.6 * diff_max else "black" ax.text(j, i, f"{d:+.2f}", ha="center", va="center", fontsize=6, color=color) ax.set_xticks(range(n)) ax.set_yticks(range(n)) ax.set_xticklabels(models_t, rotation=30, ha="right") ax.set_yticklabels(models_t) ax.set_xticks([], minor=True) ax.set_yticks([], minor=True) ax.tick_params(which="both", length=0) for spine in ax.spines.values(): spine.set_visible(False) ax.set_title(pretty.get(task, task)) # Hide unused panels. for k in range(n_tasks, rows * cols): r, c = divmod(k, cols) axes_obj[r, c].axis("off") # Caption-style legend strip. fig.suptitle( r"lower: $-\log_{10}(p_{\mathrm{Holm}})$ (darker = more significant, $*$ = $p<\alpha$);" r" upper: mean difference (row $-$ column, red $>0$, blue $<0$)", y=0.02, fontsize=7, ) fig.tight_layout(rect=(0, 0.05, 1, 1)) return fig, axes_obj def _plot_pairwise_list( self, ax: plt.Axes | None = None, *, ci: float = 0.95, alpha: float = 0.05, pretty_tasks: Mapping[str, str] | None = None, method: str | None = None, ) -> tuple[plt.Figure, plt.Axes]: from matplotlib.lines import Line2D from graphnetz.plotting import COLUMN_INCHES set_plot_style() finals = self.final_metrics() df = self.pairwise(alpha=alpha, method=method) if df.empty: fig, ax = plt.subplots() ax.text(0.5, 0.5, "no pairwise comparisons", ha="center", va="center", transform=ax.transAxes) return fig, ax pretty = dict(pretty_tasks or {}) rows: list[tuple[str, str, str, float, float, bool]] = [] for _, row in df.iterrows(): a = np.asarray(finals[row["task"]][row["model_a"]], dtype=float) b = np.asarray(finals[row["task"]][row["model_b"]], dtype=float) diff_per_seed = a - b mean = float(diff_per_seed.mean()) half = self._ci_half(diff_per_seed, ci) rows.append( ( pretty.get(row["task"], row["task"]), row["model_a"], row["model_b"], mean, half, bool(row["significant"]), ) ) if ax is None: fig, ax = plt.subplots(figsize=(COLUMN_INCHES["single"], 0.34 * len(rows) + 0.6)) else: fig = ax.figure # type: ignore[assignment] ytick_positions: list[float] = [] ytick_labels: list[str] = [] for i, (task_label, ma, mb, mean, half, sig) in enumerate(rows): color = NATURE_COLORS[0] if sig else NATURE_COLORS[3] ax.errorbar( mean, i, xerr=[[half], [half]], fmt="o" if sig else "s", color=color, ecolor=color, elinewidth=1.0, capsize=2.0, markersize=4.0 if sig else 3.0, ) ytick_positions.append(i) ytick_labels.append(f"{task_label}: {ma} - {mb}") ax.axvline(0, color="0.4", linewidth=0.6, linestyle="--") ax.set_yticks(ytick_positions) ax.set_yticklabels(ytick_labels) ax.invert_yaxis() ax.set_xlabel(r"Mean difference (95% CI, paired)") ax.set_axisbelow(True) ax.xaxis.grid(True, linewidth=0.3, alpha=0.4) legend_handles = [ Line2D( [0], [0], marker="o", color=NATURE_COLORS[0], linestyle="", markersize=4.0, label=rf"$p_{{\mathrm{{Holm}}}} < {alpha}$", ), Line2D([0], [0], marker="s", color=NATURE_COLORS[3], linestyle="", markersize=3.0, label="not significant"), ] ax.legend( handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, 1.02), ncol=2, frameon=False, handlelength=1.2, handletextpad=0.4, ) fig.tight_layout() return fig, ax
[docs] def plot_critical_difference( self, *, alpha: float = 0.05, title: str | None = None, ) -> tuple[plt.Figure, plt.Axes]: r"""Demšar critical-difference (CD) diagram. Computes mean ranks of every model across tasks and overlays the Nemenyi critical difference at level ``alpha``. Models within ``CD`` of each other are joined by a thick horizontal "clique" bar (i.e., not significantly different). This is the canonical scalable visualization for multi-method, multi-dataset benchmarks (Demšar, 2006). Only models present in *every* task are included. Requires at least two tasks and at least two such models. """ from scipy.stats import studentized_range from graphnetz.plotting import COLUMN_INCHES set_plot_style() finals = self.final_metrics() common: set[str] = set.intersection(*[set(per.keys()) for per in finals.values()]) if finals else set() if len(common) < 2 or len(finals) < 2: fig, ax = plt.subplots(figsize=(COLUMN_INCHES["single"], 1.6)) ax.text( 0.5, 0.5, "CD diagram needs >= 2 tasks and >= 2 models common to all tasks", ha="center", va="center", transform=ax.transAxes, fontsize=8, ) ax.axis("off") return fig, ax models = sorted(common) tasks = sorted(finals) means = np.array([[float(np.mean(finals[t][m])) for m in models] for t in tasks]) # Direction (lower-is-better) is detected *per task* so the CD # diagram is correct on heterogeneous benchmarks where some tasks # use accuracy (higher better) and others use loss (lower better). rows: list[np.ndarray] = [] for i, task in enumerate(tasks): sample = next(iter(self.histories[task].values()))[0] task_metric = _auto_metric_key(sample) sign = 1.0 if task_metric in _LOWER_IS_BETTER else -1.0 rows.append(stats.rankdata(sign * means[i], method="average")) ranks = np.array(rows) avg_ranks = ranks.mean(axis=0) # Ranks are always lower-is-better by construction. k = len(models) n = len(tasks) # Friedman omnibus: only interpret Nemenyi after the global null is # rejected (Demšar, 2006). We compute it from the same rank table. avg_for_chi2 = ranks.mean(axis=0) chi2 = (12.0 * n) / (k * (k + 1)) * (float(np.sum(avg_for_chi2**2)) - k * (k + 1) ** 2 / 4.0) friedman_p = float(stats.chi2.sf(chi2, df=k - 1)) friedman_rejected = friedman_p < alpha q = float(studentized_range.ppf(1 - alpha, k, np.inf) / np.sqrt(2)) cd = q * float(np.sqrt(k * (k + 1) / (6 * n))) order = np.argsort(avg_ranks) sorted_models = [models[i] for i in order] sorted_ranks = avg_ranks[order] # Maximal cliques: contiguous runs in rank order whose span < CD. cliques_raw: list[tuple[int, int]] = [] i = 0 while i < k: j = i while j + 1 < k and sorted_ranks[j + 1] - sorted_ranks[i] < cd: j += 1 if j > i: cliques_raw.append((i, j)) i += 1 cliques: list[tuple[int, int]] = [] for a, b in sorted(set(cliques_raw)): if any(c <= a and b <= d for c, d in cliques): continue cliques = [(c, d) for c, d in cliques if not (a <= c and d <= b)] cliques.append((a, b)) # Layout coordinates. fig_w = COLUMN_INCHES["double"] fig_h = max(2.2, 1.6 + 0.22 * k) fig, ax = plt.subplots(figsize=(fig_w, fig_h)) rank_y = 0.0 x_min, x_max = 1.0, float(k) ax.plot([x_min, x_max], [rank_y, rank_y], color="black", linewidth=0.8) for r in range(int(x_min), int(x_max) + 1): ax.plot([r, r], [rank_y, rank_y - 0.04], color="black", linewidth=0.6) ax.text(r, rank_y - 0.08, f"{r}", ha="center", va="top", fontsize=8) # Method leaders + side labels (left for top half, right for bottom half). half = (k + 1) // 2 label_y_step = 0.16 label_y_top = 0.32 label_x_left = x_min - 0.5 label_x_right = x_max + 0.5 for idx, (model, r) in enumerate(zip(sorted_models, sorted_ranks, strict=False)): color = NATURE_COLORS[idx % len(NATURE_COLORS)] if idx < half: label_x = label_x_left ha = "right" ly = label_y_top + (half - idx - 1) * label_y_step else: label_x = label_x_right ha = "left" ly = label_y_top + (idx - half) * label_y_step ax.plot([r, r], [rank_y, ly], color="0.55", linewidth=0.5, zorder=1) ax.plot([r, label_x], [ly, ly], color="0.55", linewidth=0.5, zorder=1) ax.plot([r], [rank_y], marker="o", markersize=3.5, color=color, zorder=2) ax.text( label_x + (-0.05 if ha == "right" else 0.05), ly, f"{model} ({r:.2f})", va="center", ha=ha, fontsize=8, color=color, ) # Clique bars below the rank axis (start below the tick labels). bar_y = rank_y - 0.16 for a, b in cliques: ax.plot( [sorted_ranks[a] - 0.06, sorted_ranks[b] + 0.06], [bar_y, bar_y], color="black", linewidth=3.5, solid_capstyle="round", zorder=3, ) bar_y -= 0.06 # CD scale at the top. cd_y = label_y_top + max(half - 1, 0) * label_y_step + 0.22 ax.plot([x_min, x_min + cd], [cd_y, cd_y], color="black", linewidth=1.0) ax.plot([x_min, x_min], [cd_y - 0.025, cd_y + 0.025], color="black", linewidth=1.0) ax.plot( [x_min + cd, x_min + cd], [cd_y - 0.025, cd_y + 0.025], color="black", linewidth=1.0, ) ax.text( x_min + cd / 2, cd_y + 0.04, rf"CD = {cd:.3f} (Nemenyi, $\alpha={alpha}$, $k={k}$, $N={n}$)", ha="center", va="bottom", fontsize=8, ) friedman_color = "0.15" if friedman_rejected else "0.4" ax.text( x_min + cd / 2, cd_y + 0.18, rf"Friedman $\chi^2_{{{k - 1}}} = {chi2:.2f}$, $p = {friedman_p:.3g}$" + (" (reject)" if friedman_rejected else " (do not reject)"), ha="center", va="bottom", fontsize=7, color=friedman_color, ) # Direction caption below all clique bars. caption_y = bar_y - 0.04 ax.text( (x_min + x_max) / 2, caption_y, "Mean rank (lower rank = better)", ha="center", va="top", fontsize=8, color="0.3", ) ax.set_xlim(label_x_left - 1.2, label_x_right + 1.2) ax.set_ylim(caption_y - 0.12, cd_y + 0.2) ax.axis("off") if title is not None: ax.set_title(title) fig.tight_layout() return fig, ax
[docs] def plot_learning_curves( self, *, ci: float = 0.95, metric_key: str | None = None, pretty_tasks: Mapping[str, str] | None = None, ylabel: str = "Test accuracy", legend_loc: str = "lower right", ) -> tuple[plt.Figure, np.ndarray]: """Mean ± t-CI learning curves, one panel per task, sharing y-axis.""" set_plot_style() from graphnetz.plotting import COLUMN_INCHES, panel_label tasks = list(self.histories) ncols = max(len(tasks), 1) width = COLUMN_INCHES["double"] height = width / 2.6 fig, axes_obj = plt.subplots(1, ncols, figsize=(width, height), sharey=True) axes = np.atleast_1d(axes_obj) pretty = dict(pretty_tasks or {}) for idx, task in enumerate(tasks): ax = axes[idx] per_task = self.histories[task] for j, model in enumerate(per_task): seed_histories = per_task[model] if not seed_histories: continue key = metric_key or _auto_metric_key(seed_histories[0]) arr = np.array([h[key] for h in seed_histories], dtype=float) mean = arr.mean(axis=0) n = arr.shape[0] if n > 1: sem = arr.std(axis=0, ddof=1) / np.sqrt(n) half = sem * stats.t.ppf((1 + ci) / 2, n - 1) else: half = np.zeros_like(mean) epochs_axis = np.arange(1, mean.size + 1) color = NATURE_COLORS[j % len(NATURE_COLORS)] ax.plot(epochs_axis, mean, color=color, label=model, linewidth=1.2) ax.fill_between(epochs_axis, mean - half, mean + half, color=color, alpha=0.2, linewidth=0) ax.set_xlabel("Epoch") ax.set_title(pretty.get(task, task)) ax.set_axisbelow(True) ax.yaxis.grid(True, linewidth=0.3, alpha=0.4) if idx == 0: ax.set_ylabel(ylabel) ax.legend(loc=legend_loc, borderaxespad=0.4) else: ax.tick_params(labelleft=False) panel_label(ax, "abcdefgh"[idx]) fig.tight_layout() return fig, axes
# --------------------------------------------------------------------------- # # Driver # --------------------------------------------------------------------------- # def _seed_all(seed: int) -> None: """Seed every RNG that benchmark training touches.""" import random random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def _normalize_seeds( seeds: int | Iterable[int] | None, seed: int | None, ) -> tuple[int, ...]: if seed is not None: return (int(seed),) if seeds is None: return (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) if isinstance(seeds, int): return (int(seeds),) return tuple(int(s) for s in seeds)
[docs] def run_benchmark( category: str | None = None, models: type | tuple[Any, ...] | ModelSpec | dict[str, type | tuple[Any, ...] | ModelSpec] | None = None, root: str = "data/benchmark", hidden_channels: int = 64, epochs: int | None = None, only: list[str] | None = None, verbose: bool = True, seeds: int | Iterable[int] | None = None, seed: int | None = None, task_type: str | None = None, tasks: Iterable[Task] | None = None, device: torch.device | str | None = "auto", ) -> BenchmarkReport: """Run a benchmark across one or more (model, task, seed) combinations. Two ways to choose tasks: 1. **By category** (default) -- tasks come from :data:`BENCHMARK_TASKS` indexed as ``[category][task_type] -> list[Task]``. Pass ``category="social"`` (etc.) and optionally restrict with ``task_type`` and ``only=``. 2. **Ad-hoc** -- pass ``tasks=[Task(...), ...]`` to bypass the registry entirely. Useful for benchmarking custom datasets without mutating global state. ``category`` then defaults to ``"custom"`` and is used only to namespace ``root/`` cache directories. The runner trains every compatible (model, task) pair across each value in ``seeds`` (default ``(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)``) and aggregates the per-seed histories into a :class:`BenchmarkReport`. """ if models is None: msg = "run_benchmark requires `models` (a class, dict, or ModelSpec)" raise ValueError(msg) if task_type is not None and task_type not in TASK_TYPES: msg = f"Unknown task type {task_type!r}. Choices: {sorted(TASK_TYPES)}" raise ValueError(msg) if not isinstance(models, dict): spec = _spec_from(models) models = {spec.cls.__name__: spec} resolved = {name: _spec_from(value) for name, value in models.items()} seed_list = _normalize_seeds(seeds, seed) if tasks is not None: task_list = list(tasks) for t in task_list: if not isinstance(t, Task): msg = f"`tasks` must contain Task instances, got {type(t).__name__}" raise TypeError(msg) if t.task_type not in TASK_TYPES: msg = f"Task {t.name!r} has unknown task type {t.task_type!r}; choices: {sorted(TASK_TYPES)}" raise ValueError(msg) if task_type is not None: task_list = [t for t in task_list if t.task_type == task_type] if category is None: category = "custom" else: if category is None: msg = "run_benchmark requires either `category` or `tasks=`" raise ValueError(msg) if category not in BENCHMARK_TASKS: msg = f"Unknown category {category!r}. Choices: {sorted(BENCHMARK_TASKS)}" raise KeyError(msg) task_list = iter_benchmark_tasks(category=category, task_type=task_type) if only is not None: task_list = [t for t in task_list if t.name in only] tasks = task_list # the loop below treats this as the working list histories: dict[str, dict[str, list[dict[str, list[float]]]]] = {} total_combinations = sum( 1 for spec in resolved.values() for task in tasks if task.task_type in spec.task_type ) * len(seed_list) overall_pbar = tqdm( total=total_combinations, desc="Benchmark", unit="run", disable=not verbose, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) import inspect for task in tasks: try: seed_aware = "seed" in inspect.signature(task.loader).parameters except (TypeError, ValueError): seed_aware = False ds_cache: Any = None # for seed-agnostic loaders, load once histories[task.name] = {} for model_name, spec in resolved.items(): if task.task_type not in spec.task_type: continue histories[task.name][model_name] = [] for s in seed_list: _seed_all(s) if seed_aware: # Seed-aware loaders (e.g. synthetic combinatorial graphs) # produce a fresh dataset per seed, so cross-seed variance # captures data resampling rather than only model init. ds = task.loader(f"{root}/{category}/{task.name}/seed{s}", seed=s) else: if ds_cache is None: ds_cache = task.loader(f"{root}/{category}/{task.name}") ds = ds_cache history = _run_task( task, ds, spec, hidden_channels, epochs or task.epochs, verbose=verbose, device=device ) histories[task.name][model_name].append(history) # Update overall progress with latest metric last_metrics = {k: v[-1] for k, v in history.items() if v} metric_str = " ".join(f"{k[:3]}={v:.3f}" for k, v in last_metrics.items()) overall_pbar.set_postfix_str(f"{task.name}/{model_name}/s{s} | {metric_str}", refresh=False) overall_pbar.update(1) overall_pbar.close() from graphnetz.training import _resolve_device config = { "category": category, "task": task, "hidden_channels": hidden_channels, "epochs": epochs, "only": only, "device": str(_resolve_device(device)), } return BenchmarkReport(seeds=seed_list, histories=histories, config=config)
[docs] def plot_benchmark( results: BenchmarkReport | Mapping[str, Mapping[str, Mapping[str, list[float]]]], errors: Mapping[str, Mapping[str, float]] | None = None, ax: plt.Axes | None = None, title: str | None = None, annotate: bool = True, ci: float = 0.95, ) -> tuple[plt.Figure, plt.Axes]: """Grouped bar chart with mean ± CI error bars. Accepts a :class:`BenchmarkReport` (preferred) or the legacy dict form for a single seed. ``errors`` overrides the default t-CI half-width. """ if isinstance(results, BenchmarkReport): return results.plot(ax=ax, title=title, annotate=annotate, ci=ci) set_plot_style() values: dict[str, dict[str, float]] = {} metric_label: str | None = None for task_name, per_task in results.items(): per_value: dict[str, float] = {} for model_name, history in per_task.items(): metric, value = _final_metric(history) metric_label = metric_label or metric per_value[model_name] = value values[task_name] = per_value return plot_grouped_bars( values, errors=errors, ax=ax, title=title, ylabel=metric_label or "metric", annotate=annotate, )
__all__ = [ "BENCHMARK_TASKS", "TASK_TYPES", "BenchmarkReport", "ModelSpec", "Task", "iter_benchmark_tasks", "plot_benchmark", "register_model", "register_task", "run_benchmark", "save_figure", "task_from_dataset", "unregister_task", ]