ItemSampler
- class dgl.graphbolt.ItemSampler(item_set: dgl.graphbolt.itemset.ItemSet | dgl.graphbolt.itemset.HeteroItemSet, batch_size: int, minibatcher: typing.Callable | None = minibatcher_default, drop_last: bool | None = False, shuffle: bool | None = False, seed: int | None = None)
Bases: IterDataPipe
A sampler to iterate over input items and create minibatches.
Input items can be node IDs, node pairs with or without labels, or node pairs with negative sources/destinations.
Note: ItemSampler is intentionally not decorated with torchdata.datapipes.functional_datapipe, so it cannot be invoked in functional (chained) form. However, any iterable datapipe from torchdata can be appended after it.
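As a rough mental model (not GraphBolt's implementation), an item set whose items carry several named fields, such as node pairs plus labels, can be pictured as parallel sequences; each minibatch takes the same slice from every field so that pairs and labels stay aligned. The sketch below uses plain Python lists in place of tensors, and the helper `slice_fields` is hypothetical.

```python
from typing import Dict, List, Sequence


def slice_fields(
    fields: Sequence[Sequence], names: Sequence[str], start: int, stop: int
) -> Dict[str, List]:
    """Take items [start, stop) from every field and key them by name.

    Hypothetical helper for illustration only; it is not part of the
    GraphBolt API. Slicing every field with the same bounds keeps the
    i-th seed pair aligned with the i-th label.
    """
    if len(fields) != len(names):
        raise ValueError("each field needs exactly one name")
    return {name: list(field[start:stop]) for name, field in zip(names, fields)}


# Node pairs (as 2-element lists) with one label per pair.
seeds = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
labels = [10, 11, 12, 13, 14]
first = slice_fields((seeds, labels), ("seeds", "labels"), 0, 4)
print(first)
# {'seeds': [[0, 1], [2, 3], [4, 5], [6, 7]], 'labels': [10, 11, 12, 13]}
```

This mirrors the "Node pairs and labels" example below, where the first MiniBatch holds the first four pairs together with labels 10 to 13.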
- Parameters:
item_set (Union[ItemSet, HeteroItemSet]) – Data to be sampled.
batch_size (int) – The size of each batch.
minibatcher (Optional[Callable]) – A callable that takes in a list of items and returns a MiniBatch. Defaults to minibatcher_default.
drop_last (bool) – Whether to drop the last batch if it is not full.
shuffle (bool) – Whether to shuffle the items before sampling.
seed (int) – The seed for reproducible stochastic shuffling. If None, a random seed is generated.
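The sketch below models what batch_size, drop_last, shuffle and seed do, using plain Python lists. It is a simplified illustration, not GraphBolt's implementation: the hypothetical `iter_batches` helper shuffles with `random.Random(seed)` and yields the same order on every call, whereas how GraphBolt varies the order across epochs is not modeled here.

```python
import random
from typing import Iterator, List, Optional, Sequence


def iter_batches(
    items: Sequence,
    batch_size: int,
    drop_last: bool = False,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> Iterator[List]:
    """Yield lists of at most ``batch_size`` items.

    Hypothetical, simplified model of the sampler's batching parameters.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(range(len(items)))
    if shuffle:
        # A fixed seed gives a reproducible permutation; seed=None lets
        # random.Random draw fresh entropy, so the order differs per call.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # the trailing partial batch is discarded
        yield [items[i] for i in chunk]


items = list(range(10))
print(list(iter_batches(items, 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(iter_batches(items, 4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
a = list(iter_batches(items, 4, shuffle=True, seed=0))
b = list(iter_batches(items, 4, shuffle=True, seed=0))
print(a == b)
# True
```

With 10 items and batch_size=4, the last batch has only 2 items, which is exactly the batch that drop_last=True removes.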
Examples
Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None,
    node_features=None, labels=None, input_nodes=None,
    indexes=None, edge_features=None, compacted_seeds=None,
    blocks=None,)
Node pairs.
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
...     names="seeds")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1],
                        [2, 3],
                        [4, 5],
                        [6, 7]]),
    sampled_subgraphs=None, node_features=None, labels=None,
    input_nodes=None, indexes=None, edge_features=None,
    compacted_seeds=None, blocks=None,)
Node pairs and labels.
>>> item_set = gb.ItemSet(
...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
...     names=("seeds", "labels")
... )
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1],
                        [2, 3],
                        [4, 5],
                        [6, 7]]),
    sampled_subgraphs=None, node_features=None,
    labels=tensor([10, 11, 12, 13]), input_nodes=None,
    indexes=None, edge_features=None, compacted_seeds=None,
    blocks=None,)
Node pairs, labels and indexes.
>>> seeds = torch.arange(0, 20).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
>>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
>>> item_set = gb.ItemSet((seeds, labels, indexes), names=("seeds",
...     "labels", "indexes"))
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1],
                        [2, 3],
                        [4, 5],
                        [6, 7]]),
    sampled_subgraphs=None, node_features=None,
    labels=tensor([1, 1, 0, 0]), input_nodes=None,
    indexes=tensor([0, 1, 0, 0]), edge_features=None,
    compacted_seeds=None, blocks=None,)
Further process batches with other datapipes such as torchdata.datapipes.iter.Mapper. Because this item set has no names, each batch is a plain tensor rather than a MiniBatch.
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> data_pipe = gb.ItemSampler(item_set, 4)
>>> def add_one(batch):
...     return batch + 1
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
Heterogeneous node IDs.
>>> ids = {
...     "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
...     "item": gb.ItemSet(torch.arange(0, 6), names="seeds"),
... }
>>> item_set = gb.HeteroItemSet(ids)
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None,
    node_features=None, labels=None, input_nodes=None, indexes=None,
    edge_features=None, compacted_seeds=None, blocks=None,)
Heterogeneous node pairs.
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet(
...         seeds_like, names="seeds"),
...     "user:follow:user": gb.ItemSet(
...         seeds_follow, names="seeds"),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item': tensor([[0, 1],
                                           [2, 3],
                                           [4, 5],
                                           [6, 7]])},
    sampled_subgraphs=None, node_features=None, labels=None,
    input_nodes=None, indexes=None, edge_features=None,
    compacted_seeds=None, blocks=None,)
Heterogeneous node pairs and labels.
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 5)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(5, 10)
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet((seeds_like, labels_like),
...         names=("seeds", "labels")),
...     "user:follow:user": gb.ItemSet((seeds_follow, labels_follow),
...         names=("seeds", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item': tensor([[0, 1],
                                           [2, 3],
                                           [4, 5],
                                           [6, 7]])},
    sampled_subgraphs=None, node_features=None,
    labels={'user:like:item': tensor([0, 1, 2, 3])},
    input_nodes=None, indexes=None, edge_features=None,
    compacted_seeds=None, blocks=None,)
Heterogeneous node pairs, labels and indexes.
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_like = torch.tensor([0, 1, 0, 0, 1])
>>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)
>>> labels_follow = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet((seeds_like, labels_like,
...         indexes_like), names=("seeds", "labels", "indexes")),
...     "user:follow:user": gb.ItemSet((seeds_follow, labels_follow,
...         indexes_follow), names=("seeds", "labels", "indexes")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item': tensor([[0, 1],
                                           [2, 3],
                                           [4, 5],
                                           [6, 7]])},
    sampled_subgraphs=None, node_features=None,
    labels={'user:like:item': tensor([1, 1, 0, 0])},
    input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},
    edge_features=None, compacted_seeds=None, blocks=None,)
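The heterogeneous examples above only show the first batch, which holds items of a single type. The sketch below is a simplified model that assumes, consistently with those outputs, that the per-type item sets are traversed in insertion order as one concatenated sequence, so a batch may span two types and only contains keys for the types present in it. The `iter_hetero_batches` helper is hypothetical and uses plain lists instead of tensors; it is not GraphBolt's implementation.

```python
from typing import Dict, Iterator, List


def iter_hetero_batches(
    item_sets: Dict[str, List], batch_size: int
) -> Iterator[Dict[str, List]]:
    """Batch a dict of per-type item lists as one concatenated sequence.

    Hypothetical, simplified model (assumption): types are visited in
    insertion order, and each batch only has keys for the types that
    actually appear in it.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Tag every item with its type so batches can be split back by type.
    flat = [(key, item) for key, items in item_sets.items() for item in items]
    for start in range(0, len(flat), batch_size):
        batch: Dict[str, List] = {}
        for key, item in flat[start:start + batch_size]:
            batch.setdefault(key, []).append(item)
        yield batch


ids = {"user": list(range(5)), "item": list(range(6))}
for batch in iter_hetero_batches(ids, batch_size=4):
    print(batch)
# {'user': [0, 1, 2, 3]}
# {'user': [4], 'item': [0, 1, 2]}
# {'item': [3, 4, 5]}
```

Under this model, the first batch matches the "Heterogeneous node IDs" output above, and the second batch mixes the last user with the first items.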