Source code for dgl.graphbolt.impl.uniform_negative_sampler
"""Uniform negative sampler for GraphBolt."""
import torch
from torch.utils.data import functional_datapipe
from ..negative_sampler import NegativeSampler
__all__ = ["UniformNegativeSampler"]
@functional_datapipe("sample_uniform_negative")
class UniformNegativeSampler(NegativeSampler):
    """Sample negative destination nodes for each source node based on a uniform
    distribution.
    Functional name: :obj:`sample_uniform_negative`.
    It's important to note that the term 'negative' refers to false negatives,
    indicating that the sampled pairs are not ensured to be absent in the graph.
    For each edge ``(u, v)``, it is supposed to generate `negative_ratio` pairs
    of negative edges ``(u, v')``, where ``v'`` is chosen uniformly from all
    the nodes in the graph.
    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    graph : FusedCSCSamplingGraph
        The graph on which to perform negative sampling.
    negative_ratio : int
        The proportion of negative samples to positive samples.

    Examples
    --------
    >>> from dgl import graphbolt as gb
    >>> indptr = torch.LongTensor([0, 1, 2, 3, 4])
    >>> indices = torch.LongTensor([1, 2, 3, 0])
    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
    >>> seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
    >>> item_set = gb.ItemSet(seeds, names="seeds")
    >>> item_sampler = gb.ItemSampler(
    ...     item_set, batch_size=4,)
    >>> neg_sampler = gb.UniformNegativeSampler(
    ...     item_sampler, graph, 2)
    >>> for minibatch in neg_sampler:
    ...       print(minibatch.seeds)
    ...       print(minibatch.labels)
    ...       print(minibatch.indexes)
    tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3], [1, 1], [1, 2],
        [2, 1], [2, 0], [3, 0], [3, 2]])
    tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
    tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3])
    """

    def __init__(
        self,
        datapipe,
        graph,
        negative_ratio,
    ):
        super().__init__(datapipe, negative_ratio)
        self.graph = graph

    def _sample_with_etype(self, seeds, etype=None):
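        """Generate negative pairs for ``seeds`` and label all pairs.

        Returns the positive pairs followed by the uniformly sampled negative
        pairs, the corresponding 1/0 labels, and, for every row, the index of
        the positive pair it belongs to.
        """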
        assert seeds.ndim == 2 and seeds.shape[1] == 2, (
            "Only a tensor of shape (N, 2) is supported for negative"
            + f" sampling, but got {seeds.shape}."
        )
        # Sample negative edges and concatenate them with the positive edges;
        # the result holds the positive pairs first, followed by the negatives.
        all_seeds = self.graph.sample_negative_edges_uniform(
            etype,
            seeds,
            self.negative_ratio,
        )
        # Construct indexes for all node pairs.
        pos_num = seeds.shape[0]
        negative_ratio = self.negative_ratio
        pos_indexes = torch.arange(0, pos_num, device=all_seeds.device)
        neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
        indexes = torch.cat((pos_indexes, neg_indexes))
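        # For example, with pos_num = 4 and negative_ratio = 2 (as in the
        # docstring example), indexes = [0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]:
        # each negative pair maps back to the positive pair it was drawn for.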
        # Construct labels for all node pairs.
        neg_num = all_seeds.shape[0] - pos_num
        labels = torch.empty(pos_num + neg_num, device=all_seeds.device)
        labels[:pos_num] = 1
        labels[pos_num:] = 0
        return all_seeds, labels, indexes
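
# The snippet below is an illustrative usage sketch, not part of the original
# module. Because the class is registered with
# ``@functional_datapipe("sample_uniform_negative")``, it can be chained onto
# a GraphBolt datapipe through that functional form. The graph and seeds
# mirror the docstring example above; it assumes ``dgl.graphbolt`` provides
# ``fused_csc_sampling_graph``, ``ItemSet`` and ``ItemSampler`` as used there.
if __name__ == "__main__":
    from dgl import graphbolt as gb

    indptr = torch.LongTensor([0, 1, 2, 3, 4])
    indices = torch.LongTensor([1, 2, 3, 0])
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
    item_set = gb.ItemSet(seeds, names="seeds")
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    # Functional form; equivalent to UniformNegativeSampler(datapipe, graph, 2).
    datapipe = datapipe.sample_uniform_negative(graph, 2)

    for minibatch in datapipe:
        # 4 positive pairs followed by 4 * 2 uniformly sampled negatives.
        print(minibatch.seeds)
        print(minibatch.labels)
        print(minibatch.indexes)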