
"""Graphbolt sampled subgraph."""

# pylint: disable= invalid-name
from typing import Dict, NamedTuple, Tuple, Union

import torch

from .base import (
    apply_to,
    CSCFormatBase,
    etype_str_to_tuple,
    expand_indptr,
    is_object_pinned,
    isin,
)

from .internal_utils import recursive_apply


__all__ = ["SampledSubgraph"]


class _ExcludeEdgesWaiter:
    def __init__(self, sampled_subgraph, index):
        self.sampled_subgraph = sampled_subgraph
        self.index = index

    def wait(self):
        """Returns the stored value when invoked."""
        sampled_subgraph = self.sampled_subgraph
        index = self.index
        # Ensure there is no memory leak.
        self.sampled_subgraph = self.index = None

        if isinstance(index, dict):
            for k in list(index.keys()):
                index[k] = index[k].wait()
        else:
            index = index.wait()

        return type(sampled_subgraph)(*_slice_subgraph(sampled_subgraph, index))
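
# A minimal sketch (assumed usage, not from the library docs) of how the
# waiter above is obtained and consumed: `exclude_edges(..., async_op=True)`
# returns an `_ExcludeEdgesWaiter`, and calling `wait()` yields the sliced
# subgraph. `subgraph` and `edges_to_exclude` are placeholder names.
#
#     future = subgraph.exclude_edges(edges_to_exclude, async_op=True)
#     ...  # overlap other host/GPU work here
#     excluded_subgraph = future.wait()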


class PyGLayerData(NamedTuple):
    """A named tuple class to represent homogenous inputs to a PyG model layer.
    The fields are x (input features), edge_index and size
    (source and destination sizes).
    """

    x: torch.Tensor
    edge_index: torch.Tensor
    size: Tuple[int, int]


class PyGLayerHeteroData(NamedTuple):
    """A named tuple class to represent heterogenous inputs to a PyG model
    layer. The fields are x (input features), edge_index and size
    (source and destination sizes), and all fields are dictionaries.
    """

    x: Dict[str, torch.Tensor]
    edge_index: Dict[str, torch.Tensor]
    size: Dict[str, Tuple[int, int]]
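
# An illustrative sketch (not part of the library) of how these named tuples
# are typically consumed by a PyG layer, assuming `torch_geometric` is
# available and `conv` is e.g. a `SAGEConv((in_size, in_size), out_size)`:
#
#     data = subgraph.to_pyg(x)  # PyGLayerData
#     h = conv(data.x, data.edge_index, size=data.size)
#
# `data.x` is a `(src_feats, dst_feats)` pair, matching PyG's bipartite
# message-passing convention.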


class SampledSubgraph:
    r"""An abstract class for sampled subgraphs. In the context of a
    heterogeneous graph, each field should be of `Dict` type. Otherwise,
    for homogeneous graphs, each field should correspond to its respective
    value type."""

    @property
    def sampled_csc(
        self,
    ) -> Union[
        CSCFormatBase,
        Dict[str, CSCFormatBase],
    ]:
        """Returns the node pairs representing edges in csc format.

        - If `sampled_csc` is a CSCFormatBase: It should be in the csc
          format. `indptr` stores the index in the data array where each
          column starts. `indices` stores the row indices of the non-zero
          elements.
        - If `sampled_csc` is a dictionary: The keys should be edge types and
          the values should be the corresponding node pairs. The IDs inside
          are heterogeneous IDs.

        Examples
        --------
        1. Homogeneous graph.

        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> sampled_csc = gb.CSCFormatBase(
        ...     indptr=torch.tensor([0, 1, 2, 3]),
        ...     indices=torch.tensor([0, 1, 2]))
        >>> print(sampled_csc)
        CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
                      indices=tensor([0, 1, 2]),
        )

        2. Heterogeneous graph.

        >>> sampled_csc = {"A:relation:B": gb.CSCFormatBase(
        ...     indptr=torch.tensor([0, 1, 2, 3]),
        ...     indices=torch.tensor([0, 1, 2]))}
        >>> print(sampled_csc)
        {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
                      indices=tensor([0, 1, 2]),
        )}
        """
        raise NotImplementedError

    @property
    def original_column_node_ids(
        self,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Returns the corresponding reverse column node IDs in the original
        graph. A graph structure can be treated as a coordinated row and
        column pair, and these are the mapped IDs of the column.

        - If `original_column_node_ids` is a tensor: It represents the
          original node IDs.
        - If `original_column_node_ids` is a dictionary: The keys should be
          node types and the values should be the corresponding original
          heterogeneous node IDs.

        If present, it means column IDs are compacted, and `sampled_csc`
        column IDs match these compacted ones.
        """
        return None

    @property
    def original_row_node_ids(
        self,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Returns the corresponding reverse row node IDs in the original
        graph. A graph structure can be treated as a coordinated row and
        column pair, and these are the mapped IDs of the row.

        - If `original_row_node_ids` is a tensor: It represents the original
          node IDs.
        - If `original_row_node_ids` is a dictionary: The keys should be node
          types and the values should be the corresponding original
          heterogeneous node IDs.

        If present, it means row IDs are compacted, and `sampled_csc` row IDs
        match these compacted ones.
        """
        return None

    @property
    def original_edge_ids(
        self,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Returns the corresponding reverse edge IDs in the original graph.
        This is useful when edge features are needed.

        - If `original_edge_ids` is a tensor: It represents the original edge
          IDs.
        - If `original_edge_ids` is a dictionary: The keys should be edge
          types and the values should be the corresponding original
          heterogeneous edge IDs.
        """
        return None
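
    # Illustrative note (assumed example, not from the docstrings above):
    # when `original_row_node_ids` is present, `sampled_csc.indices` holds
    # compacted row IDs, and the original IDs can be recovered by indexing,
    # which is what `_to_reverse_ids` below does:
    #
    #     original_rows = subgraph.original_row_node_ids[
    #         subgraph.sampled_csc.indices
    #     ]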

    def exclude_edges(
        self,
        edges: Union[
            Dict[str, torch.Tensor],
            torch.Tensor,
        ],
        assume_num_node_within_int32: bool = True,
        async_op: bool = False,
    ):
        r"""Exclude edges from the sampled subgraph.

        This function can be used with sampled subgraphs, regardless of
        whether they have compacted row/column nodes or not. If the original
        subgraph has compacted row or column nodes, the corresponding row or
        column nodes in the returned subgraph will also be compacted.

        Parameters
        ----------
        self : SampledSubgraph
            The sampled subgraph.
        edges : Union[torch.Tensor, Dict[str, torch.Tensor]]
            Edges to exclude. If the sampled subgraph is homogeneous, then
            `edges` should be an N*2 tensor representing the edges to
            exclude. If the sampled subgraph is heterogeneous, then `edges`
            should be a dictionary of edge types and the corresponding edges
            to exclude.
        assume_num_node_within_int32 : bool
            If True, assumes the values of node IDs in the provided `edges`
            fall within the int32 range, which can significantly enhance
            computation speed. Default: True.
        async_op : bool
            Boolean indicating whether the call is asynchronous. If so, the
            result can be obtained by calling wait on the returned future.

        Returns
        -------
        SampledSubgraph
            An instance of a class that inherits from `SampledSubgraph`.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> import torch
        >>> sampled_csc = {"A:relation:B": gb.CSCFormatBase(
        ...     indptr=torch.tensor([0, 1, 2, 3]),
        ...     indices=torch.tensor([0, 1, 2]))}
        >>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])}
        >>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])}
        >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
        >>> subgraph = gb.SampledSubgraphImpl(
        ...     sampled_csc=sampled_csc,
        ...     original_column_node_ids=original_column_node_ids,
        ...     original_row_node_ids=original_row_node_ids,
        ...     original_edge_ids=original_edge_ids
        ... )
        >>> edges_to_exclude = {"A:relation:B": torch.tensor([[14, 11], [15, 12]])}
        >>> result = subgraph.exclude_edges(edges_to_exclude)
        >>> print(result.sampled_csc)
        {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
                      indices=tensor([0]),
        )}
        >>> print(result.original_column_node_ids)
        {'B': tensor([10, 11, 12])}
        >>> print(result.original_row_node_ids)
        {'A': tensor([13, 14, 15])}
        >>> print(result.original_edge_ids)
        {'A:relation:B': tensor([19])}
        """
        # TODO: Add support for values > int32, then remove this line.
        assert (
            assume_num_node_within_int32
        ), "Values > int32 are not supported yet."
        assert (isinstance(self.sampled_csc, CSCFormatBase)) == isinstance(
            edges, torch.Tensor
        ), (
            "The sampled subgraph and the edges to exclude should be both "
            "homogeneous or both heterogeneous."
        )
        # Get type of calling class.
        calling_class = type(self)
        # Three steps to exclude edges:
        # 1. Convert the node pairs to the original IDs if they are compacted.
        # 2. Exclude the edges and get the index of the edges to keep.
        # 3. Slice the subgraph according to the index.
        if isinstance(self.sampled_csc, CSCFormatBase):
            reverse_edges = _to_reverse_ids(
                self.sampled_csc,
                self.original_row_node_ids,
                self.original_column_node_ids,
            )
            index = _exclude_homo_edges(
                reverse_edges, edges, assume_num_node_within_int32, async_op
            )
        else:
            index = {}
            for etype, pair in self.sampled_csc.items():
                if etype not in edges:
                    # No edges need to be excluded.
                    index[etype] = None
                    continue
                src_type, _, dst_type = etype_str_to_tuple(etype)
                original_row_node_ids = (
                    None
                    if self.original_row_node_ids is None
                    else self.original_row_node_ids.get(src_type)
                )
                original_column_node_ids = (
                    None
                    if self.original_column_node_ids is None
                    else self.original_column_node_ids.get(dst_type)
                )
                reverse_edges = _to_reverse_ids(
                    pair,
                    original_row_node_ids,
                    original_column_node_ids,
                )
                index[etype] = _exclude_homo_edges(
                    reverse_edges,
                    edges[etype],
                    assume_num_node_within_int32,
                    async_op,
                )
        if async_op:
            return _ExcludeEdgesWaiter(self, index)
        else:
            return calling_class(*_slice_subgraph(self, index))
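
    # A hedged, doctest-style sketch of the homogeneous `exclude_edges` path
    # (the docstring above shows the heterogeneous case). The tensors below
    # are assumed example values:
    #
    #     >>> subgraph = gb.SampledSubgraphImpl(
    #     ...     sampled_csc=gb.CSCFormatBase(
    #     ...         indptr=torch.tensor([0, 1, 2, 3]),
    #     ...         indices=torch.tensor([0, 1, 2])),
    #     ...     original_row_node_ids=torch.tensor([13, 14, 15]),
    #     ...     original_column_node_ids=torch.tensor([10, 11, 12]))
    #     >>> result = subgraph.exclude_edges(torch.tensor([[14, 11]]))
    #
    # Here `edges` is an N*2 tensor of (row, column) original IDs; the edge
    # (14, 11) is dropped and the other two edges are kept.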

    def to_pyg(
        self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> Union[PyGLayerData, PyGLayerHeteroData]:
        """
        Process layer inputs so that they can be consumed by a PyG model
        layer.

        Parameters
        ----------
        x : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input node features to the GNN layer.

        Returns
        -------
        Union[PyGLayerData, PyGLayerHeteroData]
            A named tuple class with `x`, `edge_index` and `size` fields.
            Typically, a PyG GNN layer's forward method will accept these as
            arguments.
        """
        if isinstance(x, torch.Tensor):
            # Homogeneous.
            src = self.sampled_csc.indices
            dst = expand_indptr(
                self.sampled_csc.indptr,
                dtype=src.dtype,
                output_size=src.size(0),
            )
            edge_index = torch.stack([src, dst], dim=0).long()
            dst_size = self.sampled_csc.indptr.size(0) - 1
            # x and x[:dst_size] correspond to source and destination
            # features, respectively.
            return PyGLayerData(
                (x, x[:dst_size]), edge_index, (x.size(0), dst_size)
            )
        else:
            # Heterogeneous.
            x_dst_dict = {}
            edge_index_dict = {}
            sizes_dict = {}
            for etype, sampled_csc in self.sampled_csc.items():
                src = sampled_csc.indices
                dst = expand_indptr(
                    sampled_csc.indptr,
                    dtype=src.dtype,
                    output_size=src.size(0),
                )
                edge_index = torch.stack([src, dst], dim=0).long()
                dst_size = sampled_csc.indptr.size(0) - 1
                # x[src_ntype] and x[dst_ntype][:dst_size] correspond to
                # source and destination features, respectively.
                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
                x_dst_dict[dst_ntype] = x[dst_ntype][:dst_size]
                edge_index_dict[etype] = edge_index
                sizes_dict[etype] = (x[src_ntype].size(0), dst_size)
            return PyGLayerHeteroData(
                (x, x_dst_dict), edge_index_dict, sizes_dict
            )

    def to(
        self, device: torch.device, non_blocking=False
    ) -> None:  # pylint: disable=invalid-name
        """Copy `SampledSubgraph` to the specified device using reflection."""
        for attr in dir(self):
            # Only copy member variables.
            if not callable(getattr(self, attr)) and not attr.startswith(
                "__"
            ):
                setattr(
                    self,
                    attr,
                    recursive_apply(
                        getattr(self, attr),
                        apply_to,
                        device,
                        non_blocking=non_blocking,
                    ),
                )
        return self

    def pin_memory(self):
        """Copy `SampledSubgraph` to the pinned memory using reflection."""
        return self.to("pinned")

    def is_pinned(self) -> bool:
        """Check whether `SampledSubgraph` is pinned using reflection."""
        return is_object_pinned(self)
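

# A minimal sketch (assumed usage, not from the library docs) of the copy
# helpers above: pinning a CPU subgraph and then moving it to a CUDA device
# with a non-blocking copy, e.g. inside a dataloading pipeline:
#
#     subgraph = subgraph.pin_memory()
#     subgraph = subgraph.to("cuda", non_blocking=True)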


def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
    indptr = node_pair.indptr
    indices = node_pair.indices
    if original_row_node_ids is not None:
        indices = torch.index_select(
            original_row_node_ids, dim=0, index=indices
        )
    indptr = expand_indptr(
        indptr, indices.dtype, original_column_node_ids, len(indices)
    )
    return (indices, indptr)


def _relabel_two_arrays(lhs_array, rhs_array):
    """Relabel two arrays into a consecutive range starting from 0."""
    concated = torch.cat([lhs_array, rhs_array])
    _, mapping = torch.unique(concated, return_inverse=True)
    return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :]


def _exclude_homo_edges(
    edges: Tuple[torch.Tensor, torch.Tensor],
    edges_to_exclude: torch.Tensor,
    assume_num_node_within_int32: bool,
    async_op: bool,
):
    """Return the indices of edges to be included."""
    if assume_num_node_within_int32:
        val = edges[0].long() << 32 | edges[1].long()
        edges_to_exclude_trans = edges_to_exclude.T
        val_to_exclude = (
            edges_to_exclude_trans[0].long() << 32
            | edges_to_exclude_trans[1].long()
        )
    else:
        # TODO: Add support for values > int32.
        raise NotImplementedError(
            "Values outside the int32 range are not supported yet."
        )
    if async_op:
        return torch.ops.graphbolt.is_not_in_index_async(val, val_to_exclude)
    else:
        mask = ~isin(val, val_to_exclude)
        return torch.nonzero(mask, as_tuple=True)[0]


def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
    """Slice the subgraph according to the index."""

    def _index_select(obj, index):
        if obj is None:
            return None
        if index is None:
            return obj
        if isinstance(obj, CSCFormatBase):
            new_indices = obj.indices[index]
            new_indptr = torch.searchsorted(index, obj.indptr)
            return CSCFormatBase(
                indptr=new_indptr,
                indices=new_indices,
            )
        if isinstance(obj, torch.Tensor):
            return obj[index]
        # Handle the case when obj is a dictionary.
        assert isinstance(obj, dict)
        assert isinstance(index, dict)
        ret = {}
        for k, v in obj.items():
            ret[k] = _index_select(v, index[k])
        return ret

    return (
        _index_select(subgraph.sampled_csc, index),
        subgraph.original_column_node_ids,
        subgraph.original_row_node_ids,
        _index_select(subgraph.original_edge_ids, index),
    )
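

# A worked illustration (assumed example values) of the packing trick used in
# `_exclude_homo_edges` when `assume_num_node_within_int32` is True: each
# (row, column) pair is packed into a single int64 key, so edge membership
# reduces to a 1-D `isin` test over the packed keys.
#
#     >>> rows = torch.tensor([13, 14, 15])
#     >>> cols = torch.tensor([10, 11, 12])
#     >>> keys = rows.long() << 32 | cols.long()
#     >>> exclude = torch.tensor([[14, 11]]).T
#     >>> exclude_keys = exclude[0].long() << 32 | exclude[1].long()
#     >>> torch.nonzero(~torch.isin(keys, exclude_keys), as_tuple=True)[0]
#     tensor([0, 2])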