"""Graphbolt dataset for legacy DGLDataset."""
from typing import List, Union
from dgl.data import AsNodePredDataset, DGLDataset
from ..base import etype_tuple_to_str
from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .basic_feature_store import BasicFeatureStore
from .fused_csc_sampling_graph import from_dglgraph
from .ondisk_dataset import OnDiskTask
from .torch_based_feature_store import TorchBasedFeature
class LegacyDataset(Dataset):
    """A Graphbolt dataset for legacy DGLDataset.

    Converts a legacy ``DGLDataset`` holding exactly one graph into the
    pieces exposed by the Graphbolt ``Dataset`` interface: a single node
    classification task, a sampling graph, a feature store, the dataset
    name and an item set covering all nodes.

    Parameters
    ----------
    legacy : DGLDataset
        The legacy dataset to convert. ``len(legacy)`` must be 1. Its single
        item is either a graph or a ``(graph, labels)`` tuple (the OGB
        layout). Homogeneous graphs are handled through
        ``AsNodePredDataset``; heterogeneous graphs are only supported when
        the dataset provides ``get_idx_split``.

    Raises
    ------
    AssertionError
        If the dataset does not contain exactly one graph.
    NotImplementedError
        If the graph is heterogeneous and the dataset has no
        ``get_idx_split`` method.
    """

    def __init__(self, legacy: DGLDataset):
        # Only supports single graph cases.
        assert (
            len(legacy) == 1
        ), "LegacyDataset only supports datasets with a single graph."
        graph = legacy[0]
        # Handle OGB Dataset, whose items are (graph, labels) tuples.
        if isinstance(graph, tuple):
            graph, _ = graph
        if graph.is_homogeneous:
            self._init_as_homogeneous_node_pred(legacy)
        else:
            self._init_as_heterogeneous_node_pred(legacy)

    @staticmethod
    def _to_feature(tensor):
        """Wrap ``tensor`` in a ``TorchBasedFeature``.

        1-D tensors are first viewed as a single column (``view(-1, 1)``);
        tensors of any other rank are wrapped unchanged.
        """
        if tensor.dim() == 1:
            tensor = tensor.view(-1, 1)
        return TorchBasedFeature(tensor)

    @staticmethod
    def _labeled_item_set(seed_nodes, labels):
        """Build an ``ItemSet`` pairing ``seed_nodes`` with their ``labels``."""
        return ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))

    def _init_as_heterogeneous_node_pred(self, legacy: DGLDataset):
        """Initialize from a heterogeneous node prediction dataset.

        The dataset must provide ``get_idx_split`` (as OGB datasets do) and
        its single item must unpack into ``(graph, labels)``. The split is
        indexed with the keys ``"train"``, ``"valid"`` and ``"test"``, each
        mapping a key to node indices; ``labels`` is indexed by the same
        keys.

        Raises
        ------
        NotImplementedError
            If ``legacy`` has no ``get_idx_split`` attribute.
        """
        # OGB Dataset has the idx split; nothing else is supported here.
        if not hasattr(legacy, "get_idx_split"):
            raise NotImplementedError(
                "Only support heterogeneous ogb node pred dataset"
            )

        graph, labels = legacy[0]
        split_idx = legacy.get_idx_split()

        def _split_item_set_dict(idx):
            # One labeled item set per key of this split.
            return ItemSetDict(
                {
                    key: self._labeled_item_set(
                        idx[key], labels[key][idx[key]]
                    )
                    for key in idx.keys()
                }
            )

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = _split_item_set_dict(split_idx["train"])
        validation_set = _split_item_set_dict(split_idx["valid"])
        test_set = _split_item_set_dict(split_idx["test"])
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        # Item set enumerating every node of every node type.
        self._all_nodes_set = ItemSetDict(
            {
                ntype: ItemSet(graph.num_nodes(ntype), names="seed_nodes")
                for ntype in graph.ntypes
            }
        )

        # Node features are keyed by node type; edge features by the
        # string form of the canonical edge type.
        features = {}
        for ntype in graph.ntypes:
            node_data = graph.nodes[ntype].data
            for name in node_data.keys():
                features[("node", ntype, name)] = self._to_feature(
                    node_data[name]
                )
        for etype in graph.canonical_etypes:
            edge_data = graph.edges[etype].data
            gb_etype = etype_tuple_to_str(etype)
            for name in edge_data.keys():
                features[("edge", gb_etype, name)] = self._to_feature(
                    edge_data[name]
                )
        self._feature = BasicFeatureStore(features)

        self._graph = from_dglgraph(graph, is_homogeneous=False)
        self._dataset_name = legacy.name

    def _init_as_homogeneous_node_pred(self, legacy: DGLDataset):
        """Initialize from a homogeneous node prediction dataset.

        The dataset is wrapped in ``AsNodePredDataset``; its ``train_idx``,
        ``val_idx`` and ``test_idx`` attributes and the ``"label"`` node
        data of its graph define the train/validation/test item sets.
        """
        legacy = AsNodePredDataset(legacy)
        # Fetch the graph and its labels once instead of on every access.
        graph = legacy[0]
        labels = graph.ndata["label"]

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = self._labeled_item_set(
            legacy.train_idx, labels[legacy.train_idx]
        )
        validation_set = self._labeled_item_set(
            legacy.val_idx, labels[legacy.val_idx]
        )
        test_set = self._labeled_item_set(
            legacy.test_idx, labels[legacy.test_idx]
        )
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        self._all_nodes_set = ItemSet(graph.num_nodes(), names="seed_nodes")

        # A homogeneous graph has no type names, hence the ``None`` slot.
        features = {}
        for name in graph.ndata.keys():
            features[("node", None, name)] = self._to_feature(
                graph.ndata[name]
            )
        for name in graph.edata.keys():
            features[("edge", None, name)] = self._to_feature(
                graph.edata[name]
            )
        self._feature = BasicFeatureStore(features)

        self._graph = from_dglgraph(graph, is_homogeneous=True)
        self._dataset_name = legacy.name

    @property
    def tasks(self) -> List[Task]:
        """Return the tasks."""
        return self._tasks

    @property
    def graph(self) -> SamplingGraph:
        """Return the graph."""
        return self._graph

    @property
    def feature(self) -> BasicFeatureStore:
        """Return the feature."""
        return self._feature

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return self._dataset_name

    @property
    def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the itemset containing all nodes."""
        return self._all_nodes_set