"""Base classes and functionalities for dataloaders"""importinspectfromcollections.abcimportMappingfrom..importbackendasFfrom..baseimportEID,NIDfrom..convertimportheterographfrom..frameimportLazyFeaturefrom..transformsimportcompact_graphsfrom..utilsimportcontext_of,recursive_applydef_set_lazy_features(x,xdata,feature_names):iffeature_namesisNone:returnifnotisinstance(feature_names,Mapping):xdata.update({k:LazyFeature(k)forkinfeature_names})else:fortype_,namesinfeature_names.items():x[type_].data.update({k:LazyFeature(k)forkinnames})
def set_node_lazy_features(g, feature_names):
    """Mark node features of the input graph as lazy for prefetching optimization.

    Every requested name is stored in the graph's ``ndata`` as a
    :class:`~dgl.LazyFeature` placeholder.  When used in a
    :class:`~dgl.dataloading.Sampler`, lazy features mark which data should be
    fetched before computation in model.  See
    :ref:`guide-minibatch-prefetching` for a detailed explanation.

    When :attr:`feature_names` is a list of names, this is equivalent to:

    .. code:: python

       g.ndata.update({k: LazyFeature(k) for k in feature_names})

    When :attr:`feature_names` is a dict keyed by node type, this is
    equivalent to:

    .. code:: python

       for type_, names in feature_names.items():
           g.nodes[type_].data.update({k: LazyFeature(k) for k in names})

    Parameters
    ----------
    g : DGLGraph
        The graph.
    feature_names : list[str] or dict[str, list[str]]
        The feature names to prefetch.

    See also
    --------
    dgl.LazyFeature
    """
    node_view, node_frame = g.nodes, g.ndata
    return _set_lazy_features(node_view, node_frame, feature_names)
def set_edge_lazy_features(g, feature_names):
    """Mark edge features of the input graph as lazy for prefetching optimization.

    Every requested name is stored in the graph's ``edata`` as a
    :class:`~dgl.LazyFeature` placeholder.  When used in a
    :class:`~dgl.dataloading.Sampler`, lazy features mark which data should be
    fetched before computation in model.  See
    :ref:`guide-minibatch-prefetching` for a detailed explanation.

    When :attr:`feature_names` is a list of names, this is equivalent to:

    .. code:: python

       g.edata.update({k: LazyFeature(k) for k in feature_names})

    When :attr:`feature_names` is a dict keyed by edge type, this is
    equivalent to:

    .. code:: python

       for type_, names in feature_names.items():
           g.edges[type_].data.update({k: LazyFeature(k) for k in names})

    Parameters
    ----------
    g : DGLGraph
        The graph.
    feature_names : list[str] or dict[etype, list[str]]
        The feature names to prefetch.  The ``etype`` key is either a string
        or a triplet.

    See also
    --------
    dgl.LazyFeature
    """
    edge_view, edge_frame = g.edges, g.edata
    return _set_lazy_features(edge_view, edge_frame, feature_names)
def set_src_lazy_features(g, feature_names):
    """Mark source-node features of the input graph as lazy for prefetching
    optimization.

    Every requested name is stored in the graph's ``srcdata`` as a
    :class:`~dgl.LazyFeature` placeholder.  When used in a
    :class:`~dgl.dataloading.Sampler`, lazy features mark which data should be
    fetched before computation in model.  See
    :ref:`guide-minibatch-prefetching` for a detailed explanation.

    When :attr:`feature_names` is a list of names, this is equivalent to:

    .. code:: python

       g.srcdata.update({k: LazyFeature(k) for k in feature_names})

    When :attr:`feature_names` is a dict keyed by node type, this is
    equivalent to:

    .. code:: python

       for type_, names in feature_names.items():
           g.srcnodes[type_].data.update({k: LazyFeature(k) for k in names})

    Parameters
    ----------
    g : DGLGraph
        The graph.
    feature_names : list[str] or dict[str, list[str]]
        The feature names to prefetch.

    See also
    --------
    dgl.LazyFeature
    """
    src_view, src_frame = g.srcnodes, g.srcdata
    return _set_lazy_features(src_view, src_frame, feature_names)
def set_dst_lazy_features(g, feature_names):
    """Mark destination-node features of the input graph as lazy for
    prefetching optimization.

    Every requested name is stored in the graph's ``dstdata`` as a
    :class:`~dgl.LazyFeature` placeholder.  When used in a
    :class:`~dgl.dataloading.Sampler`, lazy features mark which data should be
    fetched before computation in model.  See
    :ref:`guide-minibatch-prefetching` for a detailed explanation.

    When :attr:`feature_names` is a list of names, this is equivalent to:

    .. code:: python

       g.dstdata.update({k: LazyFeature(k) for k in feature_names})

    When :attr:`feature_names` is a dict keyed by node type, this is
    equivalent to:

    .. code:: python

       for type_, names in feature_names.items():
           g.dstnodes[type_].data.update({k: LazyFeature(k) for k in names})

    Parameters
    ----------
    g : DGLGraph
        The graph.
    feature_names : list[str] or dict[str, list[str]]
        The feature names to prefetch.

    See also
    --------
    dgl.LazyFeature
    """
    dst_view, dst_frame = g.dstnodes, g.dstdata
    return _set_lazy_features(dst_view, dst_frame, feature_names)
class Sampler:
    """Base class for graph samplers.

    Concrete samplers derive from this class and provide their own
    ``sample`` method, for instance:

    .. code:: python

       from dgl.dataloading import Sampler

       class SubgraphSampler(Sampler):
           def __init__(self):
               super().__init__()

           def sample(self, g, indices):
               return g.subgraph(indices)
    """

    def sample(self, g, indices):
        """Abstract sample method; the base implementation always raises.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        indices : object
            Any object representing the indices selected in the current
            minibatch.

        Raises
        ------
        NotImplementedError
            Always, since subclasses are expected to override this method.
        """
        raise NotImplementedError
class BlockSampler(Sampler):
    """Base class for sampling mini-batches in the form of Message-passing
    Flow Graphs (MFGs).

    It provides prefetching options to fetch the node features for the first
    MFG's ``srcdata``, the node labels for the last MFG's ``dstdata`` and the
    edge features of all MFG's ``edata``.

    Parameters
    ----------
    prefetch_node_feats : list[str] or dict[str, list[str]], optional
        The node data to prefetch for the first MFG.

        DGL will populate the first layer's MFG's ``srcnodes`` and ``srcdata``
        with the node data of the given names from the original graph.
    prefetch_labels : list[str] or dict[str, list[str]], optional
        The node data to prefetch for the last MFG.

        DGL will populate the last layer's MFG's ``dstnodes`` and ``dstdata``
        with the node data of the given names from the original graph.
    prefetch_edge_feats : list[str] or dict[etype, list[str]], optional
        The edge data names to prefetch for all the MFGs.

        DGL will populate every MFG's ``edges`` and ``edata`` with the edge
        data of the given names from the original graph.
    output_device : device, optional
        The device of the output subgraphs or MFGs.  Default is the same as
        the minibatch of seed nodes.
    """

    def __init__(
        self,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
        super().__init__()
        # Store the prefetch requests read by ``assign_lazy_features``.  A
        # missing request becomes an empty list so that no placeholder is
        # registered for it.
        self.prefetch_node_feats = prefetch_node_feats or []
        self.prefetch_labels = prefetch_labels or []
        self.prefetch_edge_feats = prefetch_edge_feats or []
        self.output_device = output_device

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Generates a list of blocks from the given seed nodes.

        This function must return a triplet where the first element is the
        input node IDs for the first GNN layer (a tensor or a dict of tensors
        for heterogeneous graphs), the second element is the output node IDs
        for the last GNN layer, and the third element is the said list of
        blocks.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    def assign_lazy_features(self, result):
        """Assign lazy features for prefetching.

        Parameters
        ----------
        result : tuple
            The ``(input_nodes, output_nodes, blocks)`` triplet produced by
            :meth:`sample_blocks`.

        Returns
        -------
        tuple
            The same three elements; the blocks are updated in place.
        """
        input_nodes, output_nodes, blocks = result
        # Input features live on the first block's source side, labels on the
        # last block's destination side, edge features on every block.
        set_src_lazy_features(blocks[0], self.prefetch_node_feats)
        set_dst_lazy_features(blocks[-1], self.prefetch_labels)
        for block in blocks:
            set_edge_lazy_features(block, self.prefetch_edge_feats)
        return input_nodes, output_nodes, blocks

    def sample(
        self, g, seed_nodes, exclude_eids=None
    ):  # pylint: disable=arguments-differ
        """Sample a list of blocks from the given seed nodes.

        Calls :meth:`sample_blocks` and then :meth:`assign_lazy_features` on
        its result.
        """
        result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
        return self.assign_lazy_features(result)
def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
    """Return ``eids`` concatenated with their reverse edge IDs.

    For a mapping of edge IDs the keys are first converted with
    ``g.to_canonical_etype`` and ``reverse_eid_map`` is indexed with the
    converted key; otherwise ``reverse_eid_map`` is used directly.
    """
    if isinstance(eids, Mapping):
        eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
        exclude_eids = {
            k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
            for k, v in eids.items()
        }
    else:
        exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
    return exclude_eids


def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
    """Return ``eids`` extended with the same IDs under the reverse edge types.

    ``eids`` must be a mapping.  Keys of both ``eids`` and
    ``reverse_etype_map`` are converted with ``g.to_canonical_etype`` before
    use.
    """
    exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
    reverse_etype_map = {
        g.to_canonical_etype(k): g.to_canonical_etype(v)
        for k, v in reverse_etype_map.items()
    }
    for k, v in reverse_etype_map.items():
        if k in exclude_eids:
            if v in exclude_eids:
                # The reverse type already has IDs to exclude: merge both sets.
                exclude_eids[v] = F.unique(
                    F.cat((exclude_eids[k], exclude_eids[v]), dim=0)
                )
            else:
                exclude_eids[v] = exclude_eids[k]
    return exclude_eids


def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
    """Dispatch on ``exclude_mode`` and return the edge IDs to exclude.

    Raises
    ------
    ValueError
        If ``exclude_mode`` is none of the supported values.
    """
    if exclude_mode is None:
        return None
    elif callable(exclude_mode):
        return exclude_mode(eids)
    elif F.is_tensor(exclude_mode) or (
        isinstance(exclude_mode, Mapping)
        and all(F.is_tensor(v) for v in exclude_mode.values())
    ):
        # Explicit edge IDs were given; use them as is.
        return exclude_mode
    elif exclude_mode == "self":
        return eids
    elif exclude_mode == "reverse_id":
        return _find_exclude_eids_with_reverse_id(
            g, eids, kwargs["reverse_eid_map"]
        )
    elif exclude_mode == "reverse_types":
        return _find_exclude_eids_with_reverse_types(
            g, eids, kwargs["reverse_etype_map"]
        )
    else:
        raise ValueError("unsupported mode {}".format(exclude_mode))


def find_exclude_eids(
    g,
    seed_edges,
    exclude,
    reverse_eids=None,
    reverse_etypes=None,
    output_device=None,
):
    """Find all edge IDs to exclude according to :attr:`exclude`.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    seed_edges : Tensor or dict[etype, Tensor]
        The edge IDs of the current minibatch.
    exclude :
        Can be either of the following,

        None
            Does not exclude any edge.

        'self'
            Exclude the given edges themselves but nothing else.

        'reverse_id'
            Exclude all edges specified in :attr:`seed_edges`, as well as
            their reverse edges of the same edge type.

            The mapping from each edge ID to its reverse edge ID is specified
            in the argument :attr:`reverse_eids`.

            This mode assumes that the reverse of an edge with ID ``e`` and
            type ``etype`` will have ID ``reverse_eids[e]`` and type
            ``etype``.

        'reverse_types'
            Exclude all edges specified in :attr:`seed_edges`, as well as
            their reverse edges of the corresponding edge types.
            :attr:`seed_edges` must be a dict in this mode.

            The mapping from each edge type to its reverse edge type is
            specified in the argument :attr:`reverse_etypes`.

            This mode assumes that the reverse of an edge with ID ``e`` and
            type ``etype`` will have ID ``e`` and type
            ``reverse_etypes[etype]``.

        callable
            Any function that takes in a single argument :attr:`seed_edges`
            and returns a tensor or dict of tensors.

        Tensor or dict of Tensors
            Used directly as the edge IDs to exclude.
    reverse_eids : Tensor or dict[etype, Tensor]
        The mapping from edge ID to its reverse edge ID.
    reverse_etypes : dict[etype, etype]
        The mapping from edge etype to its reverse edge type.
    output_device : device
        The device of the output edge IDs.

    Returns
    -------
    Tensor, dict[etype, Tensor] or None
        The edge IDs to exclude, copied to :attr:`output_device` when both the
        result and the device are not None.
    """
    exclude_eids = _find_exclude_eids(
        g,
        exclude,
        seed_edges,
        reverse_eid_map=reverse_eids,
        reverse_etype_map=reverse_etypes,
    )
    if exclude_eids is not None and output_device is not None:
        exclude_eids = recursive_apply(
            exclude_eids, lambda x: F.copy_to(x, output_device)
        )
    return exclude_eids


class EdgePredictionSampler(Sampler):
    """Sampler class that wraps an existing sampler for node classification
    into another one for edge classification or link prediction.

    See also
    --------
    as_edge_prediction_sampler
    """

    def __init__(
        self,
        sampler,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
        prefetch_labels=None,
    ):
        super().__init__()
        # Check if the sampler's sample method has an optional third argument.
        argspec = inspect.getfullargspec(sampler.sample)
        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
            raise TypeError(
                "This sampler does not support edge or link prediction; please add an "
                "optional third argument for edge IDs to exclude in its sample() method."
            )
        self.reverse_eids = reverse_eids
        self.reverse_etypes = reverse_etypes
        self.exclude = exclude
        self.sampler = sampler
        self.negative_sampler = negative_sampler
        self.prefetch_labels = prefetch_labels or []
        # The base ``Sampler`` class does not define ``output_device``, so fall
        # back to None instead of failing on samplers that never set it.
        self.output_device = getattr(sampler, "output_device", None)

    def _build_neg_graph(self, g, seed_edges):
        """Build a graph whose edges are the pairs returned by the negative
        sampler, with the same node types and node counts as ``g``."""
        neg_srcdst = self.negative_sampler(g, seed_edges)
        if not isinstance(neg_srcdst, Mapping):
            assert len(g.canonical_etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
            neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}

        dtype = F.dtype(list(neg_srcdst.values())[0][0])
        ctx = context_of(seed_edges) if seed_edges is not None else g.device
        # Edge types without negative pairs get empty source/destination tensors.
        neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                ),
            )
            for etype in g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        )
        return neg_pair_graph

    def assign_lazy_features(self, result):
        """Assign lazy features for prefetching.

        The second element of ``result`` (the pair graph) is updated in place
        and ``result`` itself is returned.
        """
        pair_graph = result[1]
        set_edge_lazy_features(pair_graph, self.prefetch_labels)
        # In-place updates
        return result

    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
        """Samples a list of blocks, as well as a subgraph containing the
        sampled edges from the original graph.

        If :attr:`negative_sampler` is given, also returns another graph
        containing the negative pairs as edges.

        Returns
        -------
        tuple
            ``(input_nodes, pair_graph, blocks)`` without a negative sampler,
            ``(input_nodes, pair_graph, neg_graph, blocks)`` with one.
        """
        if isinstance(seed_edges, Mapping):
            seed_edges = {
                g.to_canonical_etype(k): v for k, v in seed_edges.items()
            }
        exclude = self.exclude
        pair_graph = g.edge_subgraph(
            seed_edges, relabel_nodes=False, output_device=self.output_device
        )
        # Remember the edge IDs now; they are written back after compaction.
        eids = pair_graph.edata[EID]

        if self.negative_sampler is not None:
            neg_graph = self._build_neg_graph(g, seed_edges)
            pair_graph, neg_graph = compact_graphs([pair_graph, neg_graph])
        else:
            pair_graph = compact_graphs(pair_graph)

        pair_graph.edata[EID] = eids
        seed_nodes = pair_graph.ndata[NID]

        exclude_eids = find_exclude_eids(
            g,
            seed_edges,
            exclude,
            self.reverse_eids,
            self.reverse_etypes,
            self.output_device,
        )

        input_nodes, _, blocks = self.sampler.sample(
            g, seed_nodes, exclude_eids
        )

        if self.negative_sampler is None:
            return self.assign_lazy_features((input_nodes, pair_graph, blocks))
        else:
            return self.assign_lazy_features(
                (input_nodes, pair_graph, neg_graph, blocks)
            )
def as_edge_prediction_sampler(
    sampler,
    exclude=None,
    reverse_eids=None,
    reverse_etypes=None,
    negative_sampler=None,
    prefetch_labels=None,
):
    """Create an edge-wise sampler from a node-wise sampler.

    For each batch of edges, the sampler applies the provided node-wise
    sampler to their source and destination nodes to extract subgraphs.  It
    also generates negative edges if a negative sampler is provided, and
    extracts subgraphs for their incident nodes as well.

    For each iteration, the sampler will yield

    * A tensor of input nodes necessary for computing the representation on
      edges, or a dictionary of node type names and such tensors.

    * A subgraph that contains only the edges in the minibatch and their
      incident nodes.  Note that the graph has an identical metagraph with the
      original graph.

    * If a negative sampler is given, another graph that contains the
      "negative edges", connecting the source and destination nodes yielded
      from the given negative sampler.

    * The subgraphs or MFGs returned by the provided node-wise sampler,
      generated from the incident nodes of the edges in the minibatch (as well
      as those of the negative edges if applicable).

    Parameters
    ----------
    sampler : Sampler
        The node-wise sampler object.  It additionally requires that the
        :attr:`sample` method must have an optional third argument
        :attr:`exclude_eids` representing the edge IDs to exclude from
        neighborhood.  The argument will be either a tensor for homogeneous
        graphs or a dict of edge types and tensors for heterogeneous graphs.
    exclude : Union[str, callable], optional
        Whether and how to exclude dependencies related to the sampled edges
        in the minibatch.  Possible values are

        * None, for not excluding any edges.

        * ``self``, for excluding the edges in the current minibatch.

        * ``reverse_id``, for excluding not only the edges in the current
          minibatch but also their reverse edges according to the ID mapping
          in the argument :attr:`reverse_eids`.

        * ``reverse_types``, for excluding not only the edges in the current
          minibatch but also their reverse edges stored in another type
          according to the argument :attr:`reverse_etypes`.

        * User-defined exclusion rule.  It is a callable with edges in the
          current minibatch as a single argument and should return the edges
          to be excluded.
    reverse_eids : Tensor or dict[etype, Tensor], optional
        A tensor of reverse edge ID mapping.  The i-th element indicates the
        ID of the i-th edge's reverse edge.

        If the graph is heterogeneous, this argument requires a dictionary of
        edge types and the reverse edge ID mapping tensors.
    reverse_etypes : dict[etype, etype], optional
        The mapping from the original edge types to their reverse edge types.
    negative_sampler : callable, optional
        The negative sampler.
    prefetch_labels : list[str] or dict[etype, list[str]], optional
        The edge labels to prefetch for the returned positive pair graph.

        See :ref:`guide-minibatch-prefetching` for a detailed explanation of
        prefetching.

    Returns
    -------
    EdgePredictionSampler
        The edge-wise sampler wrapping :attr:`sampler`.

    Examples
    --------
    The following example shows how to train a 3-layer GNN for edge
    classification on a set of edges ``train_eid`` on a homogeneous undirected
    graph.  Each node takes messages from all neighbors.

    Given an array of source node IDs ``src`` and another array of destination
    node IDs ``dst``, the following code creates a bidirectional graph:

    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

    Edge :math:`i`'s reverse edge in the graph above is edge :math:`i + |E|`.
    Therefore, we can create a reverse edge mapping ``reverse_eids`` by:

    >>> E = len(src)
    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

    By passing ``reverse_eids`` to the edge sampler, the edges in the current
    mini-batch and their reversed edges will be excluded from the extracted
    subgraphs to avoid information leakage.

    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(
    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),
    ...     exclude='reverse_id', reverse_eids=reverse_eids)
    >>> dataloader = dgl.dataloading.DataLoader(
    ...     g, train_eid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    For link prediction, one can provide a negative sampler to sample negative
    edges.  The code below uses DGL's
    :class:`~dgl.dataloading.negative_sampler.Uniform` to generate 5 negative
    samples per edge:

    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(
    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),
    ...     exclude='reverse_id', reverse_eids=reverse_eids,
    ...     negative_sampler=neg_sampler)
    >>> dataloader = dgl.dataloading.DataLoader(
    ...     g, train_eid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    For heterogeneous graphs, reverse edges may belong to a different
    relation.  For example, the relations "user-click-item" and
    "item-clicked-by-user" in the graph below are mutual reverse.

    >>> g = dgl.heterograph({
    ...     ('user', 'click', 'item'): (user, item),
    ...     ('item', 'clicked-by', 'user'): (item, user)})

    To correctly exclude edges from each mini-batch, set
    ``exclude='reverse_types'`` and pass a dictionary
    ``{'click': 'clicked-by', 'clicked-by': 'click'}`` to the
    ``reverse_etypes`` argument.

    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(
    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),
    ...     exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
    >>> dataloader = dgl.dataloading.DataLoader(
    ...     g, {'click': train_eid}, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    For link prediction, provide a negative sampler to generate negative
    samples:

    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(
    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),
    ...     exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     negative_sampler=neg_sampler)
    >>> dataloader = dgl.dataloading.DataLoader(
    ...     g, {'click': train_eid}, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
    """
    # Everything except the wrapped sampler is forwarded by keyword.
    options = {
        "exclude": exclude,
        "reverse_eids": reverse_eids,
        "reverse_etypes": reverse_etypes,
        "negative_sampler": negative_sampler,
        "prefetch_labels": prefetch_labels,
    }
    return EdgePredictionSampler(sampler, **options)