"""Utility functions for external use."""

from functools import partial
from typing import Dict, Union

import torch
from torch.utils.data import functional_datapipe

from .minibatch import MiniBatch
from .minibatch_transformer import MiniBatchTransformer


@functional_datapipe("exclude_seed_edges")
class SeedEdgesExcluder(MiniBatchTransformer):
    """A mini-batch transformer that excludes the seed edges (and optionally
    their reverse edges) from the sampled subgraphs of each mini-batch.

    Functional name: :obj:`exclude_seed_edges`.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    include_reverse_edges : bool
        Whether reverse edges should be excluded as well. Default is False.
    reverse_etypes_mapping : Dict[str, str] = None
        The mapping from the original edge types to their reverse edge types.
    asynchronous : bool
        Boolean indicating whether edge exclusion stages should run on
        background threads to hide the latency of CPU GPU synchronization.
        Should be enabled only when sampling on the GPU. Default is False.
    """

    def __init__(
        self,
        datapipe,
        include_reverse_edges: bool = False,
        reverse_etypes_mapping: Dict[str, str] = None,
        asynchronous: bool = False,
    ):
        # Bind the exclusion options once; each mini-batch is then passed
        # through `exclude_seed_edges` by the transform stage.
        exclude_seed_edges_fn = partial(
            exclude_seed_edges,
            include_reverse_edges=include_reverse_edges,
            reverse_etypes_mapping=reverse_etypes_mapping,
            async_op=asynchronous,
        )
        datapipe = datapipe.transform(exclude_seed_edges_fn)
        if asynchronous:
            # With async_op=True the sampled subgraphs are pending results
            # that must be resolved with `wait()`. A buffer stage sits between
            # launching and waiting, presumably so the exclusion work overlaps
            # with later items (NOTE(review): confirm against `buffer`).
            datapipe = datapipe.buffer()
            datapipe = datapipe.transform(self._wait_for_sampled_subgraphs)
        super().__init__(datapipe)

    @staticmethod
    def _wait_for_sampled_subgraphs(minibatch):
        """Call `wait()` on each pending subgraph and store the results back
        into `minibatch.sampled_subgraphs`."""
        minibatch.sampled_subgraphs = [
            subgraph.wait() for subgraph in minibatch.sampled_subgraphs
        ]
        return minibatch
def add_reverse_edges(
    edges: Union[Dict[str, torch.Tensor], torch.Tensor],
    reverse_etypes_mapping: Dict[str, str] = None,
):
    r"""
    This function finds the reverse edges of the given `edges` and returns the
    composition of them. In a homogeneous graph, reverse edges have inverted
    source and destination node IDs. While in a heterogeneous graph, reversing
    also involves swapping node IDs and their types. This function could be
    used before `exclude_edges` function to help find targeting edges.

    Note: The found reverse edges may not actually exist in the original
    graph. Duplicate edges may also appear in the result, because some
    reverse edges may already be present in `edges`.

    Parameters
    ----------
    edges : Union[Dict[str, torch.Tensor], torch.Tensor]
        - If sampled subgraph is homogeneous, then `edges` should be an N*2
          tensor.
        - If sampled subgraph is heterogeneous, then `edges` should be a
          dictionary of edge types and the corresponding edges to exclude.
    reverse_etypes_mapping : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge types.
        Ignored when `edges` is a tensor; required when `edges` is a
        dictionary.

    Returns
    -------
    Union[Dict[str, torch.Tensor], torch.Tensor]
        The node pairs contain both the original edges and their reverse
        counterparts. For a tensor input, the original rows come first,
        followed by their flipped copies. For a dictionary input, a new
        dictionary is returned and the input dictionary is left unchanged.

    Raises
    ------
    ValueError
        If `edges` is a dictionary and `reverse_etypes_mapping` is None.

    Examples
    --------
    >>> edges = {"A:r:B": torch.tensor([[0, 1], [1, 2]])}
    >>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
    {'A:r:B': tensor([[0, 1],
            [1, 2]]), 'B:rr:A': tensor([[1, 0],
            [2, 1]])}
    >>> edges = torch.tensor([[0, 1], [1, 2]])
    >>> print(gb.add_reverse_edges(edges))
    tensor([[0, 1],
            [1, 2],
            [1, 0],
            [2, 1]])
    """
    if isinstance(edges, torch.Tensor):
        assert edges.ndim == 2 and edges.shape[1] == 2, (
            "Only tensor with shape N*2 is supported now, but got "
            + f"{edges.shape}."
        )
        reverse_edges = edges.flip(dims=(1,))
        return torch.cat((edges, reverse_edges))

    # Heterogeneous input: without a mapping the reverse edge types are
    # unknown, so fail with a clear message instead of `None.items()`.
    if reverse_etypes_mapping is None:
        raise ValueError(
            "`reverse_etypes_mapping` is required when `edges` is a "
            "dictionary of heterogeneous edges."
        )
    # A shallow copy suffices: tensors are never modified in place (torch.cat
    # allocates a new tensor), so only the dict must be fresh for the caller.
    combined_edges = edges.copy()
    for etype, reverse_etype in reverse_etypes_mapping.items():
        # Look up in the ORIGINAL dict so that reverse edges appended earlier
        # in this loop are never reversed a second time.
        if etype not in edges:
            continue
        etype_edges = edges[etype]
        assert etype_edges.ndim == 2 and etype_edges.shape[1] == 2, (
            "Only tensor with shape N*2 is supported now, but got "
            + f"{etype_edges.shape}."
        )
        flipped = etype_edges.flip(dims=(1,))
        if reverse_etype in combined_edges:
            combined_edges[reverse_etype] = torch.cat(
                (combined_edges[reverse_etype], flipped)
            )
        else:
            combined_edges[reverse_etype] = flipped
    return combined_edges
def exclude_seed_edges(
    minibatch: MiniBatch,
    include_reverse_edges: bool = False,
    reverse_etypes_mapping: Dict[str, str] = None,
    async_op: bool = False,
):
    """
    Remove the minibatch's seed edges from every one of its sampled
    subgraphs, optionally together with the reverse of those edges.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch; its ``sampled_subgraphs`` attribute is replaced.
    include_reverse_edges : bool
        If True, the reverse of every seed edge is removed as well (computed
        with ``add_reverse_edges``). Default is False.
    reverse_etypes_mapping : Dict[str, str] = None
        Maps each original edge type to its reverse edge type. Only consulted
        when ``include_reverse_edges`` is True.
    async_op: bool
        Whether each ``exclude_edges`` call is issued asynchronously. When
        True, call ``wait`` on each entry of the new ``sampled_subgraphs`` to
        obtain the resulting subgraph. Default is False.

    Returns
    -------
    MiniBatch
        The same ``minibatch`` object, with ``sampled_subgraphs`` replaced.
    """
    seeds = minibatch.seeds
    targets = (
        add_reverse_edges(seeds, reverse_etypes_mapping)
        if include_reverse_edges
        else seeds
    )
    # The same edge set is removed from every sampled subgraph.
    minibatch.sampled_subgraphs = [
        sg.exclude_edges(targets, async_op=async_op)
        for sg in minibatch.sampled_subgraphs
    ]
    return minibatch