import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs
class ZINCDataset(DGLBuiltinDataset):
    r"""ZINC dataset for the graph regression task.

    A subset (12K) of ZINC molecular graphs (250K) dataset is used to
    regress a molecular property known as the constrained solubility.
    For each molecular graph, the node features are the types of heavy
    atoms, between which the edge features are the types of bonds.
    Each graph contains 9-37 nodes and 16-84 edges.

    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_

    Statistics:

    - Train examples: 10,000
    - Valid examples: 1,000
    - Test examples: 1,000
    - Average number of nodes: 23.16
    - Average number of edges: 39.83
    - Number of atom types: 28
    - Number of bond types: 4

    Parameters
    ----------
    mode : str, optional
        Which split to load. Must be one of ``"train"``, ``"valid"`` or
        ``"test"``.
        Default: "train".
    raw_dir : str
        Directory the raw data is downloaded to, or that already contains
        the input data directory.
        Default: "~/.dgl/".
    force_reload : bool
        Whether to reload the dataset.
        Default: False.
    verbose : bool
        Whether to print out progress information.
        Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_atom_types : int
        Number of atom types.
    num_bond_types : int
        Number of bond types.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported splits.

    Examples
    --------
    >>> from dgl.data import ZINCDataset

    >>> training_set = ZINCDataset(mode="train")
    >>> training_set.num_atom_types
    28
    >>> len(training_set)
    10000
    >>> graph, label = training_set[0]
    >>> graph
    Graph(num_nodes=29, num_edges=64,
        ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})
    """

    # Splits for which a ``ZincDGL_<mode>.bin`` file is expected to exist.
    _VALID_MODES = ("train", "valid", "test")

    def __init__(
        self,
        mode="train",
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        # Fail fast with a clear message; otherwise a bad mode only shows up
        # later as a missing-file error when the graph file is opened.
        if mode not in self._VALID_MODES:
            raise ValueError(
                "mode should be one of {}, but got {!r}".format(
                    list(self._VALID_MODES), mode
                )
            )
        self._url = _get_dgl_url("dataset/ZINC12k.zip")
        # ``mode`` is assigned before the base initializer runs because
        # ``graph_path`` reads it.
        # NOTE(review): presumably the base class triggers has_cache()/load()
        # from its __init__ -- confirm against DGLBuiltinDataset.
        self.mode = mode
        super(ZINCDataset, self).__init__(
            name="zinc",
            url=self._url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Populate the dataset by loading the graph file for this split."""
        # No extra preprocessing is done here: processing is just loading.
        self.load()

    @property
    def graph_path(self):
        """str: Path of the binary graph file for the selected split."""
        return os.path.join(self.save_path, "ZincDGL_{}.bin".format(self.mode))

    def has_cache(self):
        """Return True if the graph file for the selected split exists."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graphs and their labels from :attr:`graph_path`."""
        self._graphs, self._labels = load_graphs(self.graph_path)

    @property
    def num_atom_types(self):
        """int: Number of atom types (28)."""
        return 28

    @property
    def num_bond_types(self):
        """int: Number of bond types (4)."""
        return 4

    def __len__(self):
        """Return the number of graphs in the loaded split."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get one example by index.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        dgl.DGLGraph
            Each graph contains:

            - ``ndata['feat']``: Types of heavy atoms as node features
            - ``edata['feat']``: Types of bonds as edge features
        Tensor
            Constrained solubility as graph label
        """
        graph = self._graphs[idx]
        label = self._labels["g_label"][idx]
        # The transform, when given, is applied on every access rather than
        # once at load time.
        if self._transform is not None:
            graph = self._transform(graph)
        return graph, label