SAINTSampler
- class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[source]
Bases: Sampler
Random node/edge/walk sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.
For each call, the sampler samples a node subset and then returns the node-induced subgraph. There are three options for sampling the node subset:
- The 'node' sampler samples each node with probability proportional to its out-degree.
- The 'edge' sampler first samples an edge subset and then uses the end nodes of those edges.
- The 'walk' sampler uses the nodes visited by random walks. It uniformly selects a number of root nodes and then performs a fixed-length random walk from each root node.
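As a minimal sketch of the 'node' option's selection rule (not DGL's actual implementation — the out-degrees and budget below are toy values), drawing nodes with probability proportional to out-degree might look like:

```python
import numpy as np

# Toy out-degrees for a 5-node graph (illustrative values only).
out_degrees = np.array([1.0, 4.0, 2.0, 8.0, 5.0])

# Sampling probability is proportional to out-degree.
prob = out_degrees / out_degrees.sum()

budget = 3  # number of draws; mirrors the 'node' sampler's budget
rng = np.random.default_rng(0)
drawn = rng.choice(len(out_degrees), size=budget, replace=True, p=prob)

# The node-induced subgraph is built from the unique sampled nodes.
node_subset = np.unique(drawn)
```

Because draws are with replacement, the resulting subset can contain fewer than `budget` distinct nodes.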
- Parameters:
mode (str) – The sampler to use, which can be 'node', 'edge', or 'walk'.
budget (int or tuple[int]) – Sampler configuration.
- For the 'node' sampler, budget specifies the number of nodes in each sampled subgraph.
- For the 'edge' sampler, budget specifies the number of edges to sample for inducing a subgraph.
- For the 'walk' sampler, budget is a tuple: budget[0] specifies the number of root nodes from which to generate random walks, and budget[1] specifies the length of each random walk.
cache (bool, optional) – If False, it will not cache the probability arrays for sampling. Setting it to False is required if you want to use the sampler across different graphs.
prefetch_ndata (list[str], optional) – The node data to prefetch for the subgraph. See 6.8 Feature Prefetching for a detailed explanation of prefetching.
prefetch_edata (list[str], optional) – The edge data to prefetch for the subgraph. See 6.8 Feature Prefetching for a detailed explanation of prefetching.
output_device (device, optional) – The device of the output subgraphs.
Examples
>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000)
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)
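The 'walk' mode's node collection can be sketched without DGL as follows. This is a hedged illustration of the described behavior, not DGL's implementation: `adj` is a toy adjacency list, and `num_roots` and `walk_length` stand in for `budget[0]` and `budget[1]`.

```python
import random

# Toy undirected graph as an adjacency list (illustrative only).
adj = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}

num_roots, walk_length = 2, 3  # corresponds to budget=(2, 3)
random.seed(0)

visited = set()
for _ in range(num_roots):
    node = random.randrange(len(adj))  # roots are selected uniformly
    visited.add(node)
    for _ in range(walk_length):       # fixed-length walk from each root
        node = random.choice(adj[node])
        visited.add(node)

# The visited nodes form the subset that induces the subgraph.
node_subset = sorted(visited)
```

Each walk touches at most `walk_length + 1` nodes, so the subset size is bounded by `num_roots * (walk_length + 1)`; walks that revisit nodes yield smaller subsets.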