# Source code for graphnetz.benchmark._runner

"""The benchmark driver: per-task training dispatch and the public runner."""

from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm.auto import tqdm

from graphnetz.benchmark._report import BenchmarkReport
from graphnetz.benchmark._specs import TASK_TYPES, ModelSpec, Task, _spec_from
from graphnetz.benchmark._stats import _final_metric
from graphnetz.benchmark._tasks import BENCHMARK_TASKS, iter_benchmark_tasks
from graphnetz.plotting import plot_grouped_bars, set_plot_style
from graphnetz.training import (
    train_graph_classification,
    train_graph_regression,
    train_link_prediction,
    train_node_classification,
)

# --------------------------------------------------------------------------- #
# Task runner
# --------------------------------------------------------------------------- #


def _run_task(
    task_type: Task,
    ds: Any,
    spec: ModelSpec,
    hidden: int,
    epochs: int,
    verbose: bool = False,
    device: torch.device | str | None = "auto",
) -> dict[str, list[float]]:
    """Train one model on one task and return the trainer's history.

    The branch taken depends on ``task_type.task_type``:

    * ``"node_cls"``  -- trains on ``ds[0]`` with ``train_node_classification``.
    * ``"graph_cls"`` -- shuffles ``ds``, splits it 80/20 into train and
      validation loaders (batch size 32) and calls
      ``train_graph_classification``.
    * ``"graph_reg"`` -- wraps the built model in an embedding layer for
      integer node ids and calls ``train_graph_regression`` (batch size 64).
    * ``"link_pred"`` -- relational data (``edge_type`` plus
      ``train_edge_index``) goes to ``train_relational_link_prediction``;
      everything else is split with ``RandomLinkSplit`` and goes to
      ``train_link_prediction``.

    Args:
        task_type: The task being run. Despite the parameter name this is the
            ``Task`` object itself; only its ``task_type`` attribute is read.
        ds: Whatever the task's loader returned. Indexed with ``ds[0]`` for
            node/link tasks; for ``"graph_reg"`` it may also be a
            ``(train_ds, val_ds)`` tuple.
        spec: Model specification; ``spec.build(in, hidden, out, task_type=...)``
            constructs the model for the selected branch.
        hidden: Hidden width handed to ``spec.build``.
        epochs: Number of epochs, forwarded to the trainer.
        verbose: Forwarded to the trainer.
        device: Forwarded to the trainer.

    Returns:
        The history object returned by the selected ``train_*`` function.

    Raises:
        ValueError: If ``task_type.task_type`` is not one of the four
            strings handled above.
    """
    if task_type.task_type == "node_cls":
        data = ds[0]
        model = spec.build(ds.num_features, hidden, ds.num_classes, task_type="node_cls")
        return train_node_classification(model, data, epochs=epochs, verbose=verbose, device=device)

    if task_type.task_type == "graph_cls":
        # Shuffle first, then take the leading 80% for training and the
        # remainder for validation.
        shuffled = ds.shuffle()
        split = int(0.8 * len(shuffled))
        train_loader = DataLoader(shuffled[:split], batch_size=32, shuffle=True)
        val_loader = DataLoader(shuffled[split:], batch_size=32)
        model = spec.build(shuffled.num_features, hidden, shuffled.num_classes, task_type="graph_cls")
        return train_graph_classification(
            model, train_loader, val_loader, epochs=epochs, verbose=verbose, device=device
        )

    if task_type.task_type == "graph_reg":
        # Loader may return either a single dataset (used for both train and
        # held-out -- e.g. synthetic tasks with no canonical split) or a
        # ``(train_ds, val_ds)`` tuple (real benchmarks like ZINC).
        if isinstance(ds, tuple):
            train_ds, val_ds = ds
        else:
            train_ds = val_ds = ds
        train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=64)
        # The inner model consumes ``hidden``-wide inputs because integer
        # node ids are embedded to that width by the wrapper below.
        inner = spec.build(hidden, hidden, 1, task_type="graph_reg")

        class _AtomEmbed(torch.nn.Module):
            """Embed integer node ids, then delegate to the built model."""

            def __init__(self, num_atoms: int = 32) -> None:
                super().__init__()
                self.embed = torch.nn.Embedding(num_atoms, hidden)
                self.inner = inner

            def forward(self, batch: Any) -> torch.Tensor:
                # Embed only integer atom-type ids (e.g. ZINC); pass-through
                # any float feature matrix unchanged so we never silently
                # truncate continuous features via .long().
                if not batch.x.dtype.is_floating_point:
                    # Clone so the caller's batch keeps its original ``x``.
                    batch = batch.clone()
                    batch.x = self.embed(batch.x.view(-1).long())
                return self.inner(batch)

        return train_graph_regression(
            _AtomEmbed(), train_loader, val_loader, epochs=epochs, verbose=verbose, device=device
        )

    if task_type.task_type == "link_pred":
        import math
        from typing import cast

        from torch_geometric.transforms import RandomLinkSplit
        from torch_geometric.utils import degree

        data = ds[0]

        def _fabricate_log_degree_features(d: Data, edge_index: torch.Tensor) -> Data:
            """Build a three-column node feature matrix from `edge_index` only.

            The columns are ``[log_deg, log_deg ** 2, 1]`` where
            ``log_deg = log1p(degree) / log(max(num_nodes, 2))``.

            Used as the fallback when a loader ships no node features. We
            keep the source of degree restricted to the caller-supplied
            `edge_index` so val/test edges never bleed into the feature
            matrix at training time. Returns a clone of ``d`` with ``x`` set.
            """
            n = int(d.num_nodes)
            deg = degree(edge_index[0], num_nodes=n, dtype=torch.float)
            log_deg = torch.log1p(deg) / math.log(max(n, 2))
            # ``ones_like`` keeps the constant column on the same device and
            # dtype as the degree columns so ``torch.stack`` cannot fail on a
            # device mismatch when ``edge_index`` is not on the CPU.
            ones = torch.ones_like(log_deg)
            out = d.clone()
            out.x = torch.stack([log_deg, log_deg.pow(2), ones], dim=1)
            return out

        # Relational link prediction (knowledge graphs with edge_type) is
        # detected on the *raw* data because PyG's RelLinkPredDataset
        # already restricts edge_index to training edges -- we only
        # fabricate features when missing, using train_edge_index.
        if hasattr(data, "edge_type") and hasattr(data, "train_edge_index"):
            from graphnetz.models._adapters import RelationalLinkPredWrapper
            from graphnetz.training import train_relational_link_prediction

            if getattr(data, "x", None) is None:
                data = _fabricate_log_degree_features(data, data.train_edge_index)

            num_relations = ds.num_relations if hasattr(ds, "num_relations") else int(data.edge_type.max()) + 1
            built = spec.build(data.num_features, hidden, hidden, task_type="link_pred")
            # spec.build returns a LinkPredWrapper for task_type="link_pred"; unwrap it
            # so RelationalLinkPredWrapper can drive the bare encoder directly
            # (otherwise its forward expects data.edge_label_index).
            encoder = cast(
                torch.nn.Module,
                built.encoder if hasattr(built, "encoder") else built,
            )
            model = RelationalLinkPredWrapper(encoder, hidden, num_relations)

            # Create separate Data objects for train/val/test splits; all
            # three share the same node features and node count.
            train_split = Data(
                x=data.x, edge_index=data.train_edge_index, edge_type=data.train_edge_type, num_nodes=data.num_nodes
            )
            val_split = Data(
                x=data.x, edge_index=data.valid_edge_index, edge_type=data.valid_edge_type, num_nodes=data.num_nodes
            )
            test_split = Data(
                x=data.x, edge_index=data.test_edge_index, edge_type=data.test_edge_type, num_nodes=data.num_nodes
            )

            return train_relational_link_prediction(
                model,
                train_split,
                val_split,
                test_split,
                epochs=epochs,
                verbose=verbose,
                device=device,
            )

        # Detect graph direction from the data itself instead of forcing
        # ``is_undirected=True`` -- on a directed graph the latter silently
        # de-duplicates reciprocal edges and halves the supervision signal.
        is_undirected = not bool(data.is_directed())
        transform = RandomLinkSplit(
            num_val=0.05,
            num_test=0.10,
            is_undirected=is_undirected,
            add_negative_train_samples=True,
            neg_sampling_ratio=1.0,
        )
        train_data, val_data, test_data = transform(data)
        # Fabricate features *after* the split so val/test edges never
        # leak into the node features the encoder consumes.  Use only the
        # training message-passing edges (edge_index, not edge_label_index)
        # for the degree statistic.
        if getattr(train_data, "x", None) is None:
            train_data = _fabricate_log_degree_features(train_data, train_data.edge_index)
            val_data = val_data.clone()
            val_data.x = train_data.x
            test_data = test_data.clone()
            test_data.x = train_data.x
        # ``spec.build(task_type="link_pred")`` returns a LinkPredWrapper, which
        # satisfies the ``_LinkPredLike`` protocol of the trainer; mypy
        # only sees the declared ``Module`` return so we narrow here.
        from graphnetz.training import _LinkPredLike

        lp_model = cast(_LinkPredLike, spec.build(train_data.num_features, hidden, hidden, task_type="link_pred"))
        return train_link_prediction(
            lp_model, train_data, val_data, test_data, epochs=epochs, verbose=verbose, device=device
        )

    msg = f"Unknown task task_type: {task_type.task_type}"
    raise ValueError(msg)


def _seed_all(seed: int) -> None:
    """Reset the Python, NumPy and PyTorch global RNGs to ``seed``.

    The stdlib ``random`` module, NumPy's global generator and PyTorch are
    seeded in that order; when CUDA is available every CUDA device is
    seeded as well.
    """
    import random as _stdlib_random

    for reseed in (_stdlib_random.seed, np.random.seed, torch.manual_seed):
        reseed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _normalize_seeds(
    seeds: int | Iterable[int] | None,
    seed: int | None,
) -> tuple[int, ...]:
    """Collapse the ``seeds`` / ``seed`` arguments into a tuple of ints.

    Precedence, first match wins:

    1. ``seed`` is given      -> ``(seed,)`` (``seeds`` is ignored).
    2. ``seeds`` is ``None``  -> the default ``(0, 1, ..., 9)``.
    3. ``seeds`` is a scalar  -> ``(seeds,)``.
    4. otherwise              -> every element of ``seeds`` cast to ``int``.

    Args:
        seeds: A single integer, an iterable of integers, or ``None``.
        seed: A single integer overriding ``seeds``, or ``None``.

    Returns:
        The seeds to run, as a tuple of built-in ``int`` values.
    """
    import numbers

    if seed is not None:
        return (int(seed),)
    if seeds is None:
        return tuple(range(10))
    # ``numbers.Integral`` also matches integer scalars that are not ``int``
    # subclasses (e.g. NumPy integers); a plain ``isinstance(seeds, int)``
    # would send those to the iteration branch and fail with
    # "object is not iterable".
    if isinstance(seeds, numbers.Integral):
        return (int(seeds),)
    return tuple(int(s) for s in seeds)


def run_benchmark(
    category: str | None = None,
    models: type | tuple[Any, ...] | ModelSpec | dict[str, type | tuple[Any, ...] | ModelSpec] | None = None,
    root: str = "data/benchmark",
    hidden_channels: int = 64,
    epochs: int | None = None,
    only: list[str] | None = None,
    verbose: bool = True,
    seeds: int | Iterable[int] | None = None,
    seed: int | None = None,
    task_type: str | None = None,
    tasks: Iterable[Task] | None = None,
    device: torch.device | str | None = "auto",
) -> BenchmarkReport:
    """Run a benchmark across one or more (model, task, seed) combinations.

    Two ways to choose tasks:

    1. **By category** (default) -- tasks come from :data:`BENCHMARK_TASKS`
       indexed as ``[category][task_type] -> list[Task]``. Pass
       ``category="social"`` (etc.) and optionally restrict with
       ``task_type`` and ``only=``.
    2. **Ad-hoc** -- pass ``tasks=[Task(...), ...]`` to bypass the registry
       entirely. Useful for benchmarking custom datasets without mutating
       global state. ``category`` then defaults to ``"custom"`` and is used
       only to namespace ``root/`` cache directories.

    The runner trains every compatible (model, task) pair across each value
    in ``seeds`` (default ``(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)``) and aggregates
    the per-seed histories into a :class:`BenchmarkReport`.

    Args:
        category: Registry category to run, or the cache namespace when
            ``tasks`` is given.
        models: A model class, tuple, :class:`ModelSpec`, or a dict mapping
            display names to any of those. Required.
        root: Base directory handed to task loaders as
            ``{root}/{category}/{task.name}``.
        hidden_channels: Hidden width passed to every model.
        epochs: Epoch count for every task; each task's own ``epochs`` is
            used when this is ``None`` (or ``0``).
        only: If given, keep only tasks whose ``name`` is in this list.
        verbose: Show the progress bar and forward verbosity to trainers.
        seeds: One seed or an iterable of seeds.
        seed: A single seed; takes precedence over ``seeds``.
        task_type: If given, keep only tasks of this type.
        tasks: Ad-hoc tasks to run instead of the registry.
        device: Forwarded to the trainers.

    Returns:
        A :class:`BenchmarkReport` holding the seed tuple, the nested
        ``histories[task][model] -> list of per-seed histories`` mapping and
        a ``config`` dict recording the call's settings.

    Raises:
        ValueError: If ``models`` is ``None``, ``task_type`` is unknown, a
            supplied task has an unknown type, or neither ``category`` nor
            ``tasks`` is given.
        TypeError: If ``tasks`` contains something that is not a ``Task``.
        KeyError: If ``category`` is not a key of ``BENCHMARK_TASKS``.
    """
    import inspect

    if models is None:
        msg = "run_benchmark requires `models` (a class, dict, or ModelSpec)"
        raise ValueError(msg)
    if task_type is not None and task_type not in TASK_TYPES:
        msg = f"Unknown task type {task_type!r}. Choices: {sorted(TASK_TYPES)}"
        raise ValueError(msg)

    if not isinstance(models, dict):
        spec = _spec_from(models)
        models = {spec.cls.__name__: spec}
    resolved = {name: _spec_from(value) for name, value in models.items()}
    seed_list = _normalize_seeds(seeds, seed)

    if tasks is not None:
        task_list = list(tasks)
        for t in task_list:
            if not isinstance(t, Task):
                msg = f"`tasks` must contain Task instances, got {type(t).__name__}"
                raise TypeError(msg)
            if t.task_type not in TASK_TYPES:
                msg = f"Task {t.name!r} has unknown task type {t.task_type!r}; choices: {sorted(TASK_TYPES)}"
                raise ValueError(msg)
        if task_type is not None:
            task_list = [t for t in task_list if t.task_type == task_type]
        if category is None:
            category = "custom"
    else:
        if category is None:
            msg = "run_benchmark requires either `category` or `tasks=`"
            raise ValueError(msg)
        if category not in BENCHMARK_TASKS:
            msg = f"Unknown category {category!r}. Choices: {sorted(BENCHMARK_TASKS)}"
            raise KeyError(msg)
        # Materialize: the task list is walked several times below (once per
        # model for the progress total, then by the training loop), which
        # would silently yield nothing on later passes for a one-shot iterator.
        task_list = list(iter_benchmark_tasks(category=category, task_type=task_type))
    if only is not None:
        task_list = [t for t in task_list if t.name in only]

    histories: dict[str, dict[str, list[dict[str, list[float]]]]] = {}
    total_combinations = sum(
        1 for spec in resolved.values() for t in task_list if t.task_type in spec.task_type
    ) * len(seed_list)

    overall_pbar = tqdm(
        total=total_combinations,
        desc="Benchmark",
        unit="run",
        disable=not verbose,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
    )
    try:
        for task in task_list:
            # A loader is "seed-aware" when its signature has a ``seed``
            # parameter; callables whose signature cannot be inspected are
            # treated as seed-agnostic.
            try:
                seed_aware = "seed" in inspect.signature(task.loader).parameters
            except (TypeError, ValueError):
                seed_aware = False
            ds_cache: Any = None  # for seed-agnostic loaders, load once
            histories[task.name] = {}
            for model_name, spec in resolved.items():
                if task.task_type not in spec.task_type:
                    continue
                histories[task.name][model_name] = []
                for s in seed_list:
                    _seed_all(s)
                    if seed_aware:
                        # Seed-aware loaders (e.g. synthetic combinatorial graphs)
                        # produce a fresh dataset per seed, so cross-seed variance
                        # captures data resampling rather than only model init.
                        ds = task.loader(f"{root}/{category}/{task.name}/seed{s}", seed=s)
                    else:
                        if ds_cache is None:
                            ds_cache = task.loader(f"{root}/{category}/{task.name}")
                        ds = ds_cache
                    history = _run_task(
                        task, ds, spec, hidden_channels, epochs or task.epochs, verbose=verbose, device=device
                    )
                    histories[task.name][model_name].append(history)
                    # Update overall progress with latest metric
                    last_metrics = {k: v[-1] for k, v in history.items() if v}
                    metric_str = " ".join(f"{k[:3]}={v:.3f}" for k, v in last_metrics.items())
                    overall_pbar.set_postfix_str(f"{task.name}/{model_name}/s{s} | {metric_str}", refresh=False)
                    overall_pbar.update(1)
    finally:
        # Close the bar even when a loader or trainer raises.
        overall_pbar.close()

    from graphnetz.training import _resolve_device

    config = {
        "category": category,
        # Record the task-type *filter argument*. The training loop's
        # variable would be the last Task object, and is unbound when the
        # filtered task list is empty.
        "task": task_type,
        "hidden_channels": hidden_channels,
        "epochs": epochs,
        "only": only,
        "device": str(_resolve_device(device)),
    }
    return BenchmarkReport(seeds=seed_list, histories=histories, config=config)
def plot_benchmark(
    results: BenchmarkReport | Mapping[str, Mapping[str, Mapping[str, list[float]]]],
    errors: Mapping[str, Mapping[str, float]] | None = None,
    ax: plt.Axes | None = None,
    title: str | None = None,
    annotate: bool = True,
    ci: float = 0.95,
) -> tuple[plt.Figure, plt.Axes]:
    """Draw benchmark results as a grouped bar chart.

    A :class:`BenchmarkReport` (preferred) is delegated to its own ``plot``
    method with ``ax``, ``title``, ``annotate`` and ``ci``. Any other value
    is treated as the legacy single-seed mapping
    ``{task: {model: history}}``: each history is reduced to one value with
    ``_final_metric`` and the result is drawn with ``plot_grouped_bars``,
    to which ``errors`` is forwarded. The y-axis label is the first
    non-empty metric name encountered, or ``"metric"`` if there is none.
    """
    if isinstance(results, BenchmarkReport):
        return results.plot(ax=ax, title=title, annotate=annotate, ci=ci)

    set_plot_style()

    values: dict[str, dict[str, float]] = {}
    metric_label: str | None = None
    for task_name, per_task in results.items():
        row: dict[str, float] = {}
        for model_name, history in per_task.items():
            metric, value = _final_metric(history)
            # Keep the first truthy metric name as the axis label.
            if not metric_label:
                metric_label = metric
            row[model_name] = value
        values[task_name] = row

    ylabel = metric_label if metric_label else "metric"
    return plot_grouped_bars(
        values,
        errors=errors,
        ax=ax,
        title=title,
        ylabel=ylabel,
        annotate=annotate,
    )