# Source code for graphnetz.datasets.knowledge
"""Knowledge graph and language datasets.
Wraps PyG knowledge-graph benchmarks for relational link prediction.
"""
from torch_geometric.data import Data
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
from graphnetz.datasets._netz import Netz
[docs]
def fb15k_237(root: str) -> RelLinkPredDataset:
    """Build the FB15k-237 relational link prediction benchmark.

    Args:
        root: Forwarded unchanged as the ``root`` argument of
            :class:`RelLinkPredDataset` (presumably the on-disk dataset
            directory; confirm against the PyG documentation).

    Returns:
        A :class:`RelLinkPredDataset` constructed with ``name="FB15k-237"``.
    """
    dataset = RelLinkPredDataset(root=root, name="FB15k-237")
    return dataset
class _WordNet18RRRel:
    """Adapter exposing WN18-RR through :class:`RelLinkPredDataset`'s layout.

    PyG's :class:`WordNet18RR` stores every edge in a single ``edge_index``
    and marks the splits with edge-level ``train_mask`` / ``val_mask`` /
    ``test_mask``. The benchmark dispatcher's relational path instead reads
    ``train_edge_index`` / ``valid_edge_index`` / ``test_edge_index`` (plus the
    matching ``*_edge_type``) and ``num_relations``, as on
    :class:`RelLinkPredDataset`. This class slices the masked graph into that
    layout once, at construction time.

    Attributes:
        num_relations: Largest value in the unsplit ``edge_type``, plus one.
        num_features: Always ``0``.
    """

    def __init__(self, base: WordNet18RR) -> None:
        """Slice the first graph of ``base`` into per-split edge tensors.

        Args:
            base: Dataset whose first element provides ``edge_index``,
                ``edge_type``, ``num_nodes`` and the three edge-level masks
                ``train_mask`` / ``val_mask`` / ``test_mask``.
        """
        graph = base[0]
        all_edges = graph.edge_index
        all_types = graph.edge_type

        def take(mask):
            # Select the masked columns of the edge list and the relation
            # ids that belong to them.
            return all_edges[:, mask], all_types[mask]

        train_index, train_type = take(graph.train_mask)
        valid_index, valid_type = take(graph.val_mask)
        test_index, test_type = take(graph.test_mask)

        # ``edge_index`` / ``edge_type`` are the very same tensors as the
        # training split; validation and test edges are only reachable through
        # their own ``valid_*`` / ``test_*`` attributes.
        self._data = Data(
            edge_index=train_index,
            edge_type=train_type,
            train_edge_index=train_index,
            train_edge_type=train_type,
            valid_edge_index=valid_index,
            valid_edge_type=valid_type,
            test_edge_index=test_index,
            test_edge_type=test_type,
            num_nodes=int(graph.num_nodes),
        )
        # Taken over the unsplit ``edge_type``, so relation ids that occur
        # only in the validation or test edges are counted as well.
        highest_relation = int(all_types.max())
        self.num_relations = highest_relation + 1
        self.num_features = 0

    def __getitem__(self, idx: int) -> Data:
        """Return the single wrapped graph.

        Args:
            idx: Position to read; only ``0`` is accepted.

        Returns:
            The :class:`Data` object assembled in ``__init__``.

        Raises:
            IndexError: If ``idx`` is anything other than ``0``.
        """
        if idx == 0:
            return self._data
        raise IndexError(idx)

    def __len__(self) -> int:
        """Return ``1``: the wrapper always holds exactly one graph."""
        return 1
[docs]
def wordnet18rr(root: str) -> _WordNet18RRRel:
    """Build the WordNet18-RR relational link prediction benchmark.

    Args:
        root: Forwarded unchanged as the ``root`` argument of
            :class:`WordNet18RR`.

    Returns:
        The :class:`WordNet18RR` dataset wrapped in :class:`_WordNet18RRRel`.
    """
    base = WordNet18RR(root=root)
    return _WordNet18RRRel(base)
[docs]
def wordnet_netz(root: str) -> Netz:
    """Build the WordNet semantic graph (Netzschleuder).

    Args:
        root: Forwarded unchanged as the ``root`` argument of :class:`Netz`.

    Returns:
        A :class:`Netz` constructed with ``dataset_name="wordnet"`` and
        ``network_name="wordnet"``.
    """
    dataset = Netz(root=root, dataset_name="wordnet", network_name="wordnet")
    return dataset
# Public factory functions exported by ``from ... import *``.
__all__ = [
    "fb15k_237",
    "wordnet18rr",
    "wordnet_netz",
]