Source code for graphnetz.datasets.vision
"""Geometry and vision datasets.
Coverage:
- Image-derived superpixel graphs: ``MNISTSuperpixels``, ``CIFAR10`` (GNN benchmark).
- Meshes / point clouds: PyG ``ModelNet`` (10/40 classes), ``ShapeNet`` part segmentation.
"""
from collections.abc import Callable

from torch_geometric.datasets import GNNBenchmarkDataset, MNISTSuperpixels, ModelNet, ShapeNet
[docs]
def mnist_superpixels(
    root: str,
    train: bool = True,
    *,
    transform: Callable | None = None,
    pre_transform: Callable | None = None,
) -> MNISTSuperpixels:
    """MNIST images converted to 75-superpixel graphs.

    Args:
        root: Directory where the dataset is downloaded and cached.
        train: Load the training split if ``True``, otherwise the test split.
        transform: Optional callable applied to each graph every time it is
            accessed (e.g. ``torch_geometric.transforms.Cartesian()``).
        pre_transform: Optional callable applied to each graph once, before
            the processed dataset is saved to disk.

    Returns:
        The PyG ``MNISTSuperpixels`` dataset.
    """
    return MNISTSuperpixels(
        root=root,
        train=train,
        transform=transform,
        pre_transform=pre_transform,
    )
[docs]
def cifar10_superpixels(
    root: str,
    split: str = "train",
    *,
    transform: Callable | None = None,
    pre_transform: Callable | None = None,
) -> GNNBenchmarkDataset:
    """CIFAR10 superpixel graphs (GNN benchmark suite).

    Args:
        root: Directory where the dataset is downloaded and cached.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable applied to each graph every time it is
            accessed.
        pre_transform: Optional callable applied to each graph once, before
            the processed dataset is saved to disk.

    Returns:
        The PyG ``GNNBenchmarkDataset`` for ``name="CIFAR10"``.

    Raises:
        ValueError: If ``split`` is not one of the known splits.
    """
    # PyG only rejects an unknown split after downloading/processing the
    # whole benchmark, so fail fast here on a typo.
    valid_splits = ("train", "val", "test")
    if split not in valid_splits:
        raise ValueError(
            f"Unknown split {split!r} for CIFAR10; expected one of {valid_splits}."
        )
    return GNNBenchmarkDataset(
        root=root,
        name="CIFAR10",
        split=split,
        transform=transform,
        pre_transform=pre_transform,
    )
[docs]
def modelnet10(
    root: str,
    train: bool = True,
    *,
    transform: Callable | None = None,
    pre_transform: Callable | None = None,
) -> ModelNet:
    """ModelNet10 3D shapes (10 classes).

    PyG stores ModelNet shapes as meshes; a ``transform`` such as
    ``torch_geometric.transforms.SamplePoints(1024)`` is the usual way to get
    point clouds, optionally with ``pre_transform=NormalizeScale()``.

    Args:
        root: Directory where the dataset is downloaded and cached.
        train: Load the training split if ``True``, otherwise the test split.
        transform: Optional callable applied to each sample every time it is
            accessed.
        pre_transform: Optional callable applied to each sample once, before
            the processed dataset is saved to disk.

    Returns:
        The PyG ``ModelNet`` dataset with ``name="10"``.
    """
    return ModelNet(
        root=root,
        name="10",
        train=train,
        transform=transform,
        pre_transform=pre_transform,
    )
[docs]
def modelnet40(
    root: str,
    train: bool = True,
    *,
    transform: Callable | None = None,
    pre_transform: Callable | None = None,
) -> ModelNet:
    """ModelNet40 3D shapes (40 classes).

    PyG stores ModelNet shapes as meshes; a ``transform`` such as
    ``torch_geometric.transforms.SamplePoints(1024)`` is the usual way to get
    point clouds, optionally with ``pre_transform=NormalizeScale()``.

    Args:
        root: Directory where the dataset is downloaded and cached.
        train: Load the training split if ``True``, otherwise the test split.
        transform: Optional callable applied to each sample every time it is
            accessed.
        pre_transform: Optional callable applied to each sample once, before
            the processed dataset is saved to disk.

    Returns:
        The PyG ``ModelNet`` dataset with ``name="40"``.
    """
    return ModelNet(
        root=root,
        name="40",
        train=train,
        transform=transform,
        pre_transform=pre_transform,
    )
[docs]
def shapenet(
    root: str,
    categories: str | list[str] | None = None,
    *,
    split: str = "trainval",
    transform: Callable | None = None,
    pre_transform: Callable | None = None,
) -> ShapeNet:
    """ShapeNet point clouds with part-segmentation labels.

    Pass ``categories=['Chair']`` (etc.) to limit to a subset.

    Args:
        root: Directory where the dataset is downloaded and cached.
        categories: A category name or list of names to load; ``None`` loads
            every category.
        split: One of ``"train"``, ``"val"``, ``"test"`` or ``"trainval"``
            (the default, matching the previous behavior of this wrapper).
        transform: Optional callable applied to each sample every time it is
            accessed.
        pre_transform: Optional callable applied to each sample once, before
            the processed dataset is saved to disk.

    Returns:
        The PyG ``ShapeNet`` dataset.

    Raises:
        ValueError: If ``split`` is not one of the known splits.
    """
    # PyG only rejects an unknown split after downloading/processing the
    # dataset, so fail fast here on a typo.
    valid_splits = ("train", "val", "test", "trainval")
    if split not in valid_splits:
        raise ValueError(
            f"Unknown split {split!r} for ShapeNet; expected one of {valid_splits}."
        )
    return ShapeNet(
        root=root,
        categories=categories,
        split=split,
        transform=transform,
        pre_transform=pre_transform,
    )
__all__ = [
    # Superpixel graphs derived from images.
    "cifar10_superpixels", "mnist_superpixels",
    # 3D meshes / point clouds.
    "modelnet10", "modelnet40", "shapenet",
]