import torch

from ..base import DGLError
from ..random import choice
from ..sampling import pack_traces, random_walk
from .base import set_edge_lazy_features, set_node_lazy_features, Sampler


class SAINTSampler(Sampler):
    """Random node/edge/walk sampler from
    `GraphSAINT: Graph Sampling Based Inductive Learning Method
    <https://arxiv.org/abs/1907.04931>`__

    For each call, the sampler samples a node subset and then returns a node
    induced subgraph. There are three options for sampling node subsets:

    - For the :attr:`'node'` sampler, the probability of sampling a node is
      proportional to its out-degree.
    - The :attr:`'edge'` sampler first samples an edge subset and then uses the
      end nodes of those edges.
    - The :attr:`'walk'` sampler uses the nodes visited by random walks. It
      uniformly selects a number of root nodes and then performs a fixed-length
      random walk from each root node.

    Parameters
    ----------
    mode : str
        The sampler to use, which can be :attr:`'node'`, :attr:`'edge'`, or
        :attr:`'walk'`.
    budget : int or tuple[int]
        Sampler configuration.

        - For the :attr:`'node'` sampler, budget specifies the number of nodes
          in each sampled subgraph.
        - For the :attr:`'edge'` sampler, budget specifies the number of edges
          to sample for inducing a subgraph.
        - For the :attr:`'walk'` sampler, budget is a tuple. budget[0] specifies
          the number of root nodes to generate random walks. budget[1] specifies
          the length of a random walk.

    cache : bool, optional
        If False, the sampler will not cache the probability arrays used for
        sampling. Setting it to False is required if you want to use the sampler
        across different graphs.
    prefetch_ndata : list[str], optional
        The node data to prefetch for the subgraph.

        See :ref:`guide-minibatch-prefetching` for a detailed explanation of
        prefetching.
    prefetch_edata : list[str], optional
        The edge data to prefetch for the subgraph.

        See :ref:`guide-minibatch-prefetching` for a detailed explanation of
        prefetching.
    output_device : device, optional
        The device of the output subgraphs.

    Examples
    --------

    >>> import torch
    >>> from dgl.dataloading import SAINTSampler, DataLoader
    >>> num_iters = 1000
    >>> sampler = SAINTSampler(mode='node', budget=6000)
    >>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
    >>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
    >>> for subg in dataloader:
    ...     train_on(subg)
    """

    def __init__(
        self,
        mode,
        budget,
        cache=True,
        prefetch_ndata=None,
        prefetch_edata=None,
        output_device="cpu",
    ):
        super().__init__()
        self.budget = budget
        if mode == "node":
            self.sampler = self.node_sampler
        elif mode == "edge":
            self.sampler = self.edge_sampler
        elif mode == "walk":
            self.sampler = self.walk_sampler
        else:
            raise DGLError(
                f"Expect mode to be 'node', 'edge' or 'walk', got {mode}."
            )

        self.cache = cache
        self.prob = None
        self.prefetch_ndata = prefetch_ndata or []
        self.prefetch_edata = prefetch_edata or []
        self.output_device = output_device

    def node_sampler(self, g):
        """Node ID sampler for random node sampler"""
        # Alternatively, this can be realized by uniformly sampling an edge
        # subset and then taking the source nodes of the sampled edges.
        # However, the number of edges is typically much larger than the
        # number of nodes.
        if self.cache and self.prob is not None:
            prob = self.prob
        else:
            # Sampling probability proportional to the out-degree; clamp to 1
            # so that isolated nodes can still be sampled.
            prob = g.out_degrees().float().clamp(min=1)
            if self.cache:
                self.prob = prob
        return (
            torch.multinomial(prob, num_samples=self.budget, replacement=True)
            .unique()
            .type(g.idtype)
        )

    def edge_sampler(self, g):
        """Node ID sampler for random edge sampler"""
        src, dst = g.edges()
        if self.cache and self.prob is not None:
            prob = self.prob
        else:
            in_deg = g.in_degrees().float().clamp(min=1)
            out_deg = g.out_degrees().float().clamp(min=1)
            # Edge probability proportional to 1/deg(u) + 1/deg(v), as in the
            # GraphSAINT edge sampler. We could reduce the sample space by half
            # if graphs were always symmetric.
            prob = 1.0 / in_deg[dst.long()] + 1.0 / out_deg[src.long()]
            prob /= prob.sum()
            if self.cache:
                self.prob = prob
        sampled_edges = torch.unique(
            choice(len(prob), size=self.budget, prob=prob)
        )
        sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])
        return sampled_nodes.unique().type(g.idtype)

    def walk_sampler(self, g):
        """Node ID sampler for random walk sampler"""
        num_roots, walk_length = self.budget
        # Uniformly pick root nodes, then keep every node visited by the walks.
        sampled_roots = torch.randint(0, g.num_nodes(), (num_roots,))
        traces, types = random_walk(g, nodes=sampled_roots, length=walk_length)
        sampled_nodes, _, _, _ = pack_traces(traces, types)
        return sampled_nodes.unique().type(g.idtype)

    def sample(self, g, indices):
        """Sampling function

        Parameters
        ----------
        g : DGLGraph
            The graph to sample from.
        indices : Tensor
            Placeholder, not used.

        Returns
        -------
        DGLGraph
            The sampled subgraph.
        """
        node_ids = self.sampler(g)
        sg = g.subgraph(
            node_ids, relabel_nodes=True, output_device=self.output_device
        )
        set_node_lazy_features(sg, self.prefetch_ndata)
        set_edge_lazy_features(sg, self.prefetch_edata)
        return sg
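# --- Minimal usage sketch (not part of the module above) ---------------------
# The docstring example covers the 'node' mode; this sketch shows the 'walk'
# mode, whose budget is a (number of root nodes, walk length) tuple. The toy
# random graph, budget values and iteration count below are illustrative
# assumptions only.
if __name__ == "__main__":
    import dgl
    from dgl.dataloading import DataLoader

    g = dgl.rand_graph(1000, 10000)  # toy graph: 1000 nodes, 10000 edges
    sampler = SAINTSampler(mode="walk", budget=(200, 4))
    # Each index in the dataset triggers one call to `sample`, i.e. one subgraph.
    dataloader = DataLoader(g, torch.arange(100), sampler)
    for subg in dataloader:
        # `subg` is a node-induced subgraph with relabeled node IDs;
        # plug the training step in here.
        pass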