"""Superpixel datasets (MNIST/CIFAR10) from the benchmark-gnn collection."""
import os
import pickle
import numpy as np
from scipy.spatial.distance import cdist
from tqdm.auto import tqdm
from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, load_graphs, save_graphs, Subset
def sigma(dists, kth=8):
    """Compute a per-node scale as the mean distance to the ``kth``
    nearest neighbors."""
    num_nodes = dists.shape[0]
    if kth >= num_nodes:
        # Fallback for graphs with fewer than ``kth + 1`` nodes: use a
        # constant scale of 1 for every node. (``np.partition`` requires
        # ``kth < num_nodes``, hence the ``>=`` check.)
        sigma = np.ones((num_nodes, 1))
    else:
        # The ``kth + 1`` smallest entries per row include the zero
        # self-distance, so their sum is that of the kth nearest neighbors.
        knns = np.partition(dists, kth, axis=-1)[:, : kth + 1]
        sigma = knns.sum(axis=1).reshape((knns.shape[0], 1)) / kth
    return sigma + 1e-8  # Avoid division by zero downstream.
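# Illustrative sketch of ``sigma`` on a toy distance matrix (hypothetical
# input, not part of the dataset pipeline): each row of the result is the
# mean distance from that node to its ``kth`` nearest neighbors, used as a
# per-node bandwidth for the Gaussian kernel below.
#
#   >>> pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
#   >>> sigma(cdist(pts, pts), kth=2).shape
#   (3, 1)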
def compute_adjacency_matrix_images(coord, feat, use_feat=True):
    """Build a dense Gaussian-kernel adjacency matrix from superpixel
    coordinates and, optionally, superpixel features."""
    coord = coord.reshape(-1, 2)
    # Pairwise coordinate distances.
    c_dist = cdist(coord, coord)
    if use_feat:
        # Pairwise feature distances.
        f_dist = cdist(feat, feat)
        # Gaussian kernel over both distances, each with its own
        # per-node scale.
        A = np.exp(
            -((c_dist / sigma(c_dist)) ** 2) - (f_dist / sigma(f_dist)) ** 2
        )
    else:
        A = np.exp(-((c_dist / sigma(c_dist)) ** 2))
    # Symmetrize (the per-node scales break symmetry) and zero the
    # diagonal to remove self-loops.
    A = 0.5 * (A + A.T)
    A[np.diag_indices_from(A)] = 0
    return A
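# Illustrative sketch (toy random input assumed): the result is a dense,
# symmetric similarity matrix with a zero diagonal.
#
#   >>> coord = np.random.rand(4, 2)
#   >>> feat = np.random.rand(4, 3)
#   >>> A = compute_adjacency_matrix_images(coord, feat)
#   >>> A.shape
#   (4, 4)
#   >>> bool(np.allclose(A, A.T)), float(A.diagonal().max())
#   (True, 0.0)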
def compute_edges_list(A, kth=9):
    """Return, for each node, the indices and values of its strongest
    neighbors in ``A`` (``kth - 1`` per node when the graph has more than
    ``kth`` nodes; otherwise all other nodes)."""
    num_nodes = A.shape[0]
    new_kth = num_nodes - kth
    if num_nodes > kth:
        # Partition so the ``kth`` largest similarities per row occupy
        # the trailing columns, then keep ``kth - 1`` of them (the last
        # column is dropped, as in the benchmark-gnn reference code).
        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
        knn_values = np.partition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
    else:
        # Handling for graphs with no more than ``kth`` nodes.
        # In such cases, the resulting graph is fully connected.
        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(
            num_nodes, num_nodes
        )
        knn_values = A
        # Remove self-loops (unless the graph has a single node).
        if num_nodes != 1:
            knn_values = A[knns != np.arange(num_nodes)[:, None]].reshape(
                num_nodes, -1
            )
            knns = knns[knns != np.arange(num_nodes)[:, None]].reshape(
                num_nodes, -1
            )
    return knns, knn_values
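# Illustrative sketch (toy random input assumed): with the default
# ``kth=9``, each node keeps its 8 strongest neighbors.
#
#   >>> A = compute_adjacency_matrix_images(
#   ...     np.random.rand(12, 2), None, use_feat=False
#   ... )
#   >>> knns, knn_values = compute_edges_list(A)
#   >>> knns.shape, knn_values.shape
#   ((12, 8), (12, 8))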
class SuperPixelDataset(DGLDataset):
    r"""Base class shared by :class:`MNISTSuperPixelDataset` and
    :class:`CIFAR10SuperPixelDataset`; see those subclasses for details."""
    def __init__(
        self,
        raw_dir=None,
        name="MNIST",
        split="train",
        use_feature=False,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        assert split in ["train", "test"], f"Invalid split: {split}."
        assert name in ["MNIST", "CIFAR10"], f"Invalid dataset name: {name}."
        self.use_feature = use_feature
        self.split = split
        self._dataset_name = name
        self.graphs = []
        self.labels = []
        super().__init__(
            name="Superpixel",
            raw_dir=raw_dir,
            url="""
            https://www.dropbox.com/s/y2qwa77a0fxem47/superpixels.zip?dl=1
            """,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )
    @property
    def img_size(self):
        r"""Size of dataset image."""
        if self._dataset_name == "MNIST":
            return 28
        return 32
    @property
    def save_path(self):
        r"""Directory to save the processed dataset."""
        return os.path.join(self.raw_path, "processed")
    @property
    def raw_data_path(self):
        r"""Path to save the raw dataset file."""
        return os.path.join(self.raw_path, "superpixels.zip")
    @property
    def graph_path(self):
        r"""Path to save the processed dataset file."""
        if self.use_feature:
            return os.path.join(
                self.save_path,
                f"use_feat_{self._dataset_name}_{self.split}.pkl",
            )
        return os.path.join(
            self.save_path, f"{self._dataset_name}_{self.split}.pkl"
        )
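    # For example, with default arguments the processed MNIST train split
    # is cached at ``<raw_path>/processed/MNIST_train.pkl``; the file name
    # gains a ``use_feat_`` prefix when ``use_feature=True``.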
    def download(self):
        path = download(self.url, path=self.raw_data_path)
        extract_archive(path, target_dir=self.raw_path, overwrite=True)
    def process(self):
        if self._dataset_name == "MNIST":
            pkl_file = "mnist_75sp"
        elif self._dataset_name == "CIFAR10":
            pkl_file = "cifar10_150sp"
        with open(
            os.path.join(
                self.raw_path, "superpixels", f"{pkl_file}_{self.split}.pkl"
            ),
            "rb",
        ) as f:
            self.labels, self.sp_data = pickle.load(f)
            self.labels = F.tensor(self.labels)
        self.Adj_matrices = []
        self.node_features = []
        self.edges_lists = []
        self.edge_features = []
        for sample in tqdm(
            self.sp_data, desc=f"Processing {self.split} dataset"
        ):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size
            if self.use_feature:
                A = compute_adjacency_matrix_images(
                    coord, mean_px
                )  # using super-pixel locations + features
            else:
                A = compute_adjacency_matrix_images(
                    coord, mean_px, False
                )  # using only super-pixel locations
            edges_list, edge_values_list = compute_edges_list(A)
            N_nodes = A.shape[0]
            mean_px = mean_px.reshape(N_nodes, -1)
            coord = coord.reshape(N_nodes, 2)
            x = np.concatenate((mean_px, coord), axis=1)
            edge_values_list = edge_values_list.reshape(-1)
            self.node_features.append(x)
            self.edge_features.append(edge_values_list)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)
        for index in tqdm(
            range(len(self.sp_data)), desc=f"Dump {self.split} dataset"
        ):
            N = self.node_features[index].shape[0]
            src_nodes = []
            dst_nodes = []
            for src, dsts in enumerate(self.edges_lists[index]):
                # Handling for a single-node graph, where the self-loop
                # is the only edge.
                if N == 1:
                    src_nodes.append(src)
                    dst_nodes.append(dsts)
                else:
                    dsts = dsts[dsts != src]
                    srcs = [src] * len(dsts)
                    src_nodes.extend(srcs)
                    dst_nodes.extend(dsts)
            src_nodes = F.tensor(src_nodes)
            dst_nodes = F.tensor(dst_nodes)
            g = dgl_graph((src_nodes, dst_nodes), num_nodes=N)
            g.ndata["feat"] = F.zerocopy_from_numpy(
                self.node_features[index]
            ).to(F.float32)
            g.edata["feat"] = (
                F.zerocopy_from_numpy(self.edge_features[index])
                .to(F.float32)
                .unsqueeze(1)
            )
            self.graphs.append(g)
    def load(self):
        self.graphs, label_dict = load_graphs(self.graph_path)
        self.labels = label_dict["labels"]
    def save(self):
        save_graphs(
            self.graph_path, self.graphs, labels={"labels": self.labels}
        )
    def has_cache(self):
        return os.path.exists(self.graph_path)
    def __len__(self):
        return len(self.graphs)
    def __getitem__(self, idx):
        """Get the idx-th sample.
        Parameters
        ----------
        idx : int or tensor
            The sample index.
            A 1-D tensor is allowed as ``idx`` when transform is None.
        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
            Graph with node feature stored in ``feat`` field and its label.
        or
        :class:`dgl.data.utils.Subset`
            Subset of the dataset at the specified indices.
        """
        if F.is_tensor(idx) and idx.dim() == 1:
            if self._transform is None:
                return Subset(self, idx.cpu())
            raise ValueError(
                "Tensor idx not supported when transform is not None."
            )
        if self._transform is None:
            return self.graphs[idx], self.labels[idx]
        return self._transform(self.graphs[idx]), self.labels[idx]
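# End-to-end sketch of the graph construction used in ``process`` above
# (hypothetical toy superpixels; the real pipeline reads them from the
# downloaded pickle files):
#
#   >>> coord = np.random.rand(5, 2)    # superpixel centroids in [0, 1]
#   >>> mean_px = np.random.rand(5, 1)  # mean intensity per superpixel
#   >>> A = compute_adjacency_matrix_images(coord, mean_px)
#   >>> edges_list, edge_values = compute_edges_list(A, kth=3)
#   >>> edges_list.shape                # kth - 1 = 2 neighbors per node
#   (5, 2)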
class MNISTSuperPixelDataset(SuperPixelDataset):
    r"""MNIST superpixel dataset for the graph classification task.
    DGL dataset of MNIST in the benchmark-gnn, which contains graphs
    converted from the original MNIST images.
    Reference `<http://arxiv.org/abs/2003.00982>`_
    Statistics:
        - Train examples: 60,000
        - Test examples: 10,000
        - Size of dataset images: 28
    Parameters
    ----------
    raw_dir : str
        Directory to store all the downloaded raw datasets.
        Default: "~/.dgl/".
    split : str
        Should be chosen from ["train", "test"].
        Default: "train".
    use_feature : bool
        - True: adjacency matrix defined from superpixel locations + features
        - False: adjacency matrix defined from superpixel locations only
        Default: False.
    force_reload : bool
        Whether to reload the dataset.
        Default: False.
    verbose : bool
        Whether to print out progress information.
        Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    Examples
    --------
    >>> from dgl.data import MNISTSuperPixelDataset
    >>> # MNIST dataset
    >>> train_dataset = MNISTSuperPixelDataset(split="train")
    >>> len(train_dataset)
    60000
    >>> graph, label = train_dataset[0]
    >>> graph
    Graph(num_nodes=71, num_edges=568,
        ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
    >>> # support tensor to be index when transform is None
    >>> # see details in __getitem__ function
    >>> import torch
    >>> idx = torch.tensor([0, 1, 2])
    >>> train_dataset_subset = train_dataset[idx]
    >>> train_dataset_subset[0]
    Graph(num_nodes=71, num_edges=568,
        ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
    """
    def __init__(
        self,
        raw_dir=None,
        split="train",
        use_feature=False,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        super().__init__(
            raw_dir=raw_dir,
            name="MNIST",
            split=split,
            use_feature=use_feature,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        ) 
class CIFAR10SuperPixelDataset(SuperPixelDataset):
    r"""CIFAR10 superpixel dataset for the graph classification task.
    DGL dataset of CIFAR10 in the benchmark-gnn, which contains graphs
    converted from the original CIFAR10 images.
    Reference `<http://arxiv.org/abs/2003.00982>`_
    Statistics:
        - Train examples: 50,000
        - Test examples: 10,000
        - Size of dataset images: 32
    Parameters
    ----------
    raw_dir : str
        Directory to store all the downloaded raw datasets.
        Default: "~/.dgl/".
    split : str
        Should be chosen from ["train", "test"].
        Default: "train".
    use_feature : bool
        - True: adjacency matrix defined from superpixel locations + features
        - False: adjacency matrix defined from superpixel locations only
        Default: False.
    force_reload : bool
        Whether to reload the dataset.
        Default: False.
    verbose : bool
        Whether to print out progress information.
        Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    Examples
    --------
    >>> from dgl.data import CIFAR10SuperPixelDataset
    >>> # CIFAR10 dataset
    >>> train_dataset = CIFAR10SuperPixelDataset(split="train")
    >>> len(train_dataset)
    50000
    >>> graph, label = train_dataset[0]
    >>> graph
    Graph(num_nodes=123, num_edges=984,
        ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}
        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
    >>> # support tensor to be index when transform is None
    >>> # see details in __getitem__ function
    >>> import torch
    >>> idx = torch.tensor([0, 1, 2])
    >>> train_dataset_subset = train_dataset[idx]
    >>> train_dataset_subset[0]
    Graph(num_nodes=123, num_edges=984,
        ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}
        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
    """
    def __init__(
        self,
        raw_dir=None,
        split="train",
        use_feature=False,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        super().__init__(
            raw_dir=raw_dir,
            name="CIFAR10",
            split=split,
            use_feature=use_feature,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )