SubgraphSampler
- class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)
Bases: MiniBatchTransformer
A subgraph sampler that samples subgraphs of a larger graph, starting from a given set of seed nodes.
Functional name: sample_subgraph.
This class is the base class of all subgraph samplers. Any subclass of SubgraphSampler should implement either the sample_subgraphs() method or the sampling_stages() method. Implementing sampling_stages() lets a subclass define fine-grained sampling stages that take advantage of optimizations provided by the GraphBolt DataLoader.
- Parameters:
datapipe (DataPipe): The datapipe.
args (Non-Keyword Arguments): Positional arguments passed to sampling_stages().
kwargs (Keyword Arguments): Keyword arguments passed to sampling_stages().
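The "implement one of two methods" contract can be illustrated with the plain-Python sketch below. It is not GraphBolt code. `MiniSubgraphSampler` and `OneHopSampler` are hypothetical names, a list of seed batches stands in for the datapipe, and an adjacency dict stands in for the graph. The base class forwards its extra arguments to sampling_stages(). Its default sampling_stages() defines a single lazy stage that calls sample_subgraphs() on each batch, so a subclass that only overrides sample_subgraphs() works without further changes.

```python
class MiniSubgraphSampler:
    """Hypothetical stand-in for SubgraphSampler: mirrors only the dispatch contract."""

    def __init__(self, datapipe, *args, **kwargs):
        # Extra arguments are forwarded to sampling_stages, as in the documented class.
        self.datapipe = self.sampling_stages(datapipe, *args, **kwargs)

    def sampling_stages(self, datapipe, *args, **kwargs):
        # Default: one lazy stage that delegates each batch of seeds to sample_subgraphs().
        return (self.sample_subgraphs(seeds, None) for seeds in datapipe)

    def sample_subgraphs(self, seeds, seeds_timestamp, seeds_pre_time_window=None):
        raise NotImplementedError("Override sample_subgraphs() or sampling_stages().")

    def __iter__(self):
        return iter(self.datapipe)


class OneHopSampler(MiniSubgraphSampler):
    """Hypothetical subclass: collects every neighbor of each seed from an adjacency dict."""

    def __init__(self, datapipe, adjacency):
        self.adjacency = adjacency
        super().__init__(datapipe)

    def sample_subgraphs(self, seeds, seeds_timestamp, seeds_pre_time_window=None):
        # Edges are (seed, neighbor) pairs; temporal arguments are ignored here.
        edges = [(s, n) for s in seeds for n in self.adjacency.get(s, [])]
        input_nodes = sorted(set(seeds) | {n for _, n in edges})
        return input_nodes, [edges]


# Two mini-batches of seed nodes over a tiny three-node graph.
results = list(OneHopSampler([[0], [1, 2]], {0: [1, 2], 1: [2], 2: [0]}))
```

Each item yielded is an `(input_nodes, subgraphs)` pair, which matches the return shape documented for sample_subgraphs().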
- sample_subgraphs(seeds, seeds_timestamp, seeds_pre_time_window=None)
Sample subgraphs from the given seeds, possibly with temporal constraints.
Any subclass of SubgraphSampler should implement this method.
- Parameters:
seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]): The seed nodes.
seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]): The timestamps of the seed nodes. If given, the sampled subgraphs should not contain any nodes or edges that are newer than the timestamps of the seed nodes. Default: None.
seeds_pre_time_window (Union[torch.Tensor, Dict[str, torch.Tensor]]): A time window that precedes seeds_timestamp. If provided, only neighbors and related edges whose timestamps fall within [seeds_timestamp - seeds_pre_time_window, seeds_timestamp] are kept. Default: None.
- Returns:
Union[torch.Tensor, Dict[str, torch.Tensor]]: The input nodes.
List[SampledSubgraph]: The sampled subgraphs.
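The temporal rule described for seeds_timestamp and seeds_pre_time_window can be restated as a small predicate. The plain-Python sketch below uses a hypothetical helper, `filter_neighbors_by_time`, and treats each candidate neighbor as a `(node, timestamp)` pair for a single seed. It does not reproduce GraphBolt's implementation, which works on tensors and may check node and edge timestamps separately. A neighbor is dropped if it is newer than the seed. When a window is given, a neighbor is also dropped if it is older than `seed_timestamp - window`.

```python
def filter_neighbors_by_time(neighbors, seed_timestamp, seed_pre_time_window=None):
    """Hypothetical helper: keep neighbors allowed by the temporal constraint of one seed."""
    kept = []
    for node, timestamp in neighbors:
        if timestamp > seed_timestamp:
            # Newer than the seed: never allowed.
            continue
        if seed_pre_time_window is not None and timestamp < seed_timestamp - seed_pre_time_window:
            # Older than the start of [seed_timestamp - window, seed_timestamp].
            continue
        kept.append(node)
    return kept


candidates = [("a", 3), ("b", 7), ("c", 10), ("d", 12)]
no_window = filter_neighbors_by_time(candidates, seed_timestamp=10)
with_window = filter_neighbors_by_time(candidates, seed_timestamp=10, seed_pre_time_window=4)
```

Both interval endpoints are inclusive, as in the bracketed range above. Without a window, `no_window` is `["a", "b", "c"]`. With a window of 4, only timestamps in [6, 10] survive, so `with_window` is `["b", "c"]`.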
Examples
>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>>     def __init__(self, datapipe, graph, fanouts):
>>>         super().__init__(datapipe)
>>>         self.graph = graph
>>>         self.fanouts = fanouts
>>>     def sample_subgraphs(self, seeds, seeds_timestamp, seeds_pre_time_window=None):
>>>         # Sample subgraphs from the given seeds, one per layer.
>>>         # The temporal arguments are ignored in this non-temporal example.
>>>         subgraphs = []
>>>         subgraphs_nodes = []
>>>         for fanout in reversed(self.fanouts):
>>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
>>>             # Prepend so that subgraphs[0] is the layer farthest from the seeds.
>>>             subgraphs.insert(0, subgraph)
>>>             subgraphs_nodes.append(subgraph.nodes)
>>>             seeds = subgraph.nodes
>>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>>         return subgraphs_nodes, subgraphs
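To show the layer ordering in the example without a real graph, the sketch below reproduces the same loop in plain Python. Several substitutions are assumed. `sample_layers` is a hypothetical name. An adjacency dict replaces the graph. Deterministic "first `fanout` neighbors" replaces random neighbor sampling. Lists of `(node, neighbor)` edges replace `SampledSubgraph` objects. The fanouts are consumed in reverse, so the last fanout applies to the hop nearest the seeds. Each new layer is prepended, so the returned list runs from the outermost hop inward.

```python
def sample_layers(adjacency, seeds, fanouts):
    """Hypothetical plain-Python analogue of the multi-layer loop in the example."""
    subgraphs = []
    frontier = sorted(set(seeds))
    all_nodes = set(frontier)
    for fanout in reversed(fanouts):
        # One hop: each frontier node keeps at most `fanout` neighbors (first-k, not random).
        edges = [(node, nbr) for node in frontier for nbr in adjacency.get(node, [])[:fanout]]
        # Prepend, so subgraphs[0] ends up being the hop farthest from the seeds.
        subgraphs.insert(0, edges)
        # The next hop starts from every node seen in this layer, seeds included.
        frontier = sorted(set(frontier) | {nbr for _, nbr in edges})
        all_nodes.update(frontier)
    # Union of all layers' nodes, mirroring torch.unique(torch.cat(...)) in the example.
    return sorted(all_nodes), subgraphs


adjacency = {0: [1, 2, 3], 1: [4], 2: [5, 6]}
input_nodes, subgraphs = sample_layers(adjacency, seeds=[0], fanouts=[1, 2])
```

With `fanouts=[1, 2]`, the first hop from seed 0 keeps two neighbors and the second hop keeps one neighbor per node. `subgraphs[0]` is the wider, input-side layer, and `input_nodes` is `[0, 1, 2, 4, 5]`.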
- sampling_stages(datapipe)
The sampling stages are defined here by chaining them onto the datapipe. The default implementation expects sample_subgraphs() to be implemented; to define fine-grained stages, override this method instead.
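Overriding sampling_stages() means returning a datapipe with several steps chained onto it instead of the single default step. The plain-Python sketch below uses the built-in `map` as a stand-in for GraphBolt's datapipe chaining. The class `StagedSampler` and its `_fetch_neighbors` and `_assemble` methods are hypothetical. Sampling is split into two lazy stages: one looks up neighbors and one assembles the `(input_nodes, subgraphs)` result. In this mock the stages simply run one after another. Splitting the work into stages is what gives the GraphBolt DataLoader room to apply its own optimizations, and that behavior is not modeled here.

```python
class StagedSampler:
    """Hypothetical sampler that overrides sampling_stages with two chained steps."""

    def __init__(self, datapipe, adjacency):
        self.adjacency = adjacency
        self.datapipe = self.sampling_stages(datapipe)

    def sampling_stages(self, datapipe):
        # Stage 1: look up the neighbors of every seed in the batch.
        datapipe = map(self._fetch_neighbors, datapipe)
        # Stage 2: turn the fetched neighbors into (input_nodes, subgraphs).
        return map(self._assemble, datapipe)

    def _fetch_neighbors(self, seeds):
        return seeds, {s: self.adjacency.get(s, []) for s in seeds}

    def _assemble(self, item):
        seeds, neighbors = item
        edges = [(s, n) for s in seeds for n in neighbors[s]]
        input_nodes = sorted(set(seeds) | {n for _, n in edges})
        return input_nodes, [edges]

    def __iter__(self):
        return iter(self.datapipe)


staged = list(StagedSampler([[0], [1, 2]], {0: [1, 2], 1: [2], 2: [0]}))
```

Each stage is a separate function, so a stage can be swapped or reordered without touching the others. The final output has the same shape as a single sample_subgraphs() call.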