SubgraphSampler

class dgl.graphbolt.SubgraphSampler(datapipe)

Bases: MiniBatchTransformer

A subgraph sampler that samples subgraphs from a larger graph, starting from a given set of seed nodes.

Functional name: sample_subgraph.
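A functional name lets a datapipe class be chained as a method on an upstream datapipe (for example `datapipe.sample_subgraph(...)`) instead of being constructed explicitly. The pure-Python sketch below imitates that registration pattern under stated assumptions: `DataPipe`, `register_functional`, `ListPipe`, and `ToySampler` are hypothetical stand-ins, not the PyTorch or GraphBolt classes.

```python
# Hypothetical imitation of functional-name registration. None of these
# classes or functions are the torch.utils.data or GraphBolt implementations.


class DataPipe:
    """Minimal iterable base class (hypothetical stand-in)."""

    def __iter__(self):
        raise NotImplementedError


def register_functional(name):
    """Attach `name` to DataPipe so `pipe.<name>(...)` builds `cls(pipe, ...)`."""

    def decorator(cls):
        def method(self, *args, **kwargs):
            # The upstream datapipe becomes the first constructor argument.
            return cls(self, *args, **kwargs)

        setattr(DataPipe, name, method)
        return cls

    return decorator


class ListPipe(DataPipe):
    """Source datapipe that yields the items of a list."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        return iter(self.items)


@register_functional("sample_subgraph")
class ToySampler(DataPipe):
    """Tags each upstream item; stands in for a subgraph sampler."""

    def __init__(self, datapipe, tag):
        self.datapipe = datapipe
        self.tag = tag

    def __iter__(self):
        for item in self.datapipe:
            yield (self.tag, item)


# Both spellings build the same pipeline.
pipe_a = ToySampler(ListPipe([1, 2]), "s")
pipe_b = ListPipe([1, 2]).sample_subgraph("s")
print(list(pipe_a), list(pipe_b))  # [('s', 1), ('s', 2)] [('s', 1), ('s', 2)]
```

The method form keeps long pipelines readable as a chain of calls, which is why each datapipe advertises its functional name.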

This is the base class of all subgraph samplers. Any subclass of SubgraphSampler should implement the sample_subgraphs() method.
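The division of labor between the base class and its subclasses can be sketched as an abstract base class that iterates over minibatches and delegates the sampling step to `sample_subgraphs()`. This is a dependency-free sketch under assumptions: `ToyMiniBatch`, `ToySubgraphSamplerBase`, and `IdentitySampler` are hypothetical, and GraphBolt's actual `MiniBatchTransformer` internals may differ.

```python
# Hypothetical sketch of the base-class contract: the base class walks the
# upstream minibatches and delegates sampling to sample_subgraphs(), which
# every concrete subclass must implement.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class ToyMiniBatch:
    """Hypothetical stand-in for a minibatch carrying seeds and outputs."""

    seeds: List[int]
    seeds_timestamp: Optional[List[int]] = None
    input_nodes: Optional[List[int]] = None
    sampled_subgraphs: List[Any] = field(default_factory=list)


class ToySubgraphSamplerBase(ABC):
    def __init__(self, datapipe):
        self.datapipe = datapipe  # any iterable of ToyMiniBatch

    def __iter__(self):
        for minibatch in self.datapipe:
            # Fill in both outputs from the subclass's sampling logic.
            input_nodes, subgraphs = self.sample_subgraphs(
                minibatch.seeds, minibatch.seeds_timestamp
            )
            minibatch.input_nodes = input_nodes
            minibatch.sampled_subgraphs = subgraphs
            yield minibatch

    @abstractmethod
    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Return (input_nodes, list_of_subgraphs)."""


class IdentitySampler(ToySubgraphSamplerBase):
    """Trivial subclass: samples nothing; input nodes are the unique seeds."""

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        return sorted(set(seeds)), []


batches = list(IdentitySampler([ToyMiniBatch(seeds=[3, 1, 3])]))
print(batches[0].input_nodes)  # [1, 3]
```

Because `sample_subgraphs` is abstract, instantiating the base class directly raises `TypeError`, which enforces the "subclass should implement" rule at construction time.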

Parameters:

datapipe (DataPipe) – The upstream datapipe that yields the minibatches to sample from.

sample_subgraphs(seeds, seeds_timestamp=None)

Sample subgraphs from the given seeds.

Any subclass of SubgraphSampler should implement this method.

Parameters:

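seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes. For heterogeneous graphs, a dictionary mapping node types to seed node IDs.

seeds_timestamp (optional) – Timestamps of the seed nodes, used for temporal sampling. Defaults to None.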

Returns:

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.
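The two return values are related: following the reversed-fanout loop used in the example in this section, the subgraph list runs from the layer farthest from the seeds to the layer nearest them, and the input nodes are the deduplicated nodes touched by all layers. The dependency-free sketch below illustrates that contract on an adjacency dict, using Python lists instead of tensors and taking the first `fanout` neighbors instead of sampling randomly so the output is reproducible. The function names and graph layout are hypothetical; it covers only the homogeneous (non-dict) case.

```python
# Hypothetical sketch of the sample_subgraphs() return contract on a toy
# adjacency dict (node -> list of in-neighbors). Not GraphBolt code.
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (source neighbor, destination node)


def sample_layer(adj: Dict[int, List[int]], seeds: List[int], fanout: int) -> List[Edge]:
    """Keep at most `fanout` in-neighbors for every seed (deterministically)."""
    edges = []
    for dst in seeds:
        for src in adj.get(dst, [])[:fanout]:
            edges.append((src, dst))
    return edges


def sample_subgraphs(adj, seeds, fanouts):
    subgraphs: List[List[Edge]] = []
    frontier = sorted(set(seeds))
    # Start from the layer nearest the seeds, which uses the last fanout.
    for fanout in reversed(fanouts):
        layer = sample_layer(adj, frontier, fanout)
        # Prepend so subgraphs[i] lines up with fanouts[i].
        subgraphs.insert(0, layer)
        # The next (outer) layer expands from every node touched so far.
        frontier = sorted(set(frontier) | {src for src, _ in layer})
    input_nodes = frontier
    return input_nodes, subgraphs


adj = {0: [1, 2, 3], 1: [4], 2: [5, 6]}
input_nodes, subgraphs = sample_subgraphs(adj, seeds=[0], fanouts=[1, 2])
print(input_nodes)  # [0, 1, 2, 4, 5]
print(subgraphs)    # [[(1, 0), (4, 1), (5, 2)], [(1, 0), (2, 0)]]
```

Here `subgraphs[1]` uses fanout 2 around seed 0, and `subgraphs[0]` uses fanout 1 around the expanded frontier. For heterogeneous graphs, both the seeds and the input nodes would instead be dictionaries keyed by node type.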

Examples

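>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one layer per fanout, starting next to the seeds.
...         # This sampler ignores seeds_timestamp.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             # Prepend so subgraphs[0] is the layer farthest from the seeds.
...             subgraphs.insert(0, subgraph)
...             subgraphs_nodes.append(subgraph.nodes)
...             # The nodes of this layer become the seeds of the next one.
...             seeds = subgraph.nodes
...         # Input nodes: the deduplicated union of all sampled nodes.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs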