"""Subgraph samplers"""
from collections import defaultdict
from functools import partial
from typing import Dict
import torch
import torch.distributed as thd
from torch.utils.data import functional_datapipe
from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch import MiniBatch
from .minibatch_transformer import MiniBatchTransformer
__all__ = [
"SubgraphSampler",
"all_to_all",
"convert_to_hetero",
"revert_to_homo",
]
class _NoOpWaiter:
    """Waiter that exposes an already-computed result behind the same
    ``wait()`` interface as an asynchronous handle."""
def __init__(self, result):
self.result = result
def wait(self):
"""Returns the stored value when invoked."""
result = self.result
# Ensure there is no memory leak.
self.result = None
return result
def _shift(inputs: list, group=None):
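    # ``inputs`` arrives ordered `rank, ..., world_size - 1, 0, ..., rank - 1`;
    # rotating at ``len(inputs) - rank`` restores `0, ..., world_size - 1`.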
cutoff = len(inputs) - thd.get_rank(group)
return inputs[cutoff:] + inputs[:cutoff]
def all_to_all(outputs, inputs, group=None, async_op=False):
"""Wrapper for thd.all_to_all that permuted outputs and inputs before
calling it. The arguments have the permutation
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
`0, world_size - 1` before calling `thd.all_to_all`."""
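
    A minimal doctest-style sketch of the calling convention (hypothetical
    tensor values; assumes an initialized process group with
    ``world_size == 2`` and that this snippet runs on rank 1):

    >>> # Lists start at our own rank and wrap: [to rank 1, to rank 0].
    >>> inputs = [torch.tensor([11]), torch.tensor([10])]
    >>> outputs = [torch.empty(1, dtype=torch.int64) for _ in inputs]
    >>> all_to_all(outputs, inputs)
    """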
shift_fn = partial(_shift, group=group)
outputs = shift_fn(list(outputs))
inputs = shift_fn(list(inputs))
if outputs[0].is_cuda:
return thd.all_to_all(outputs, inputs, group, async_op)
# gloo backend will be used.
outputs_single = torch.cat(outputs)
output_split_sizes = [o.size(0) for o in outputs]
handle = thd.all_to_all_single(
outputs_single,
torch.cat(inputs),
output_split_sizes,
[i.size(0) for i in inputs],
group,
async_op,
)
temp_outputs = outputs_single.split(output_split_sizes)
class _Waiter:
def __init__(self, handle, outputs, temp_outputs):
self.handle = handle
self.outputs = outputs
self.temp_outputs = temp_outputs
def wait(self):
"""Returns the stored value when invoked."""
handle = self.handle
outputs = self.outputs
temp_outputs = self.temp_outputs
# Ensure that there is no leak
self.handle = self.outputs = self.temp_outputs = None
if handle is not None:
handle.wait()
for output, temp_output in zip(outputs, temp_outputs):
output.copy_(temp_output)
post_processor = _Waiter(handle, outputs, temp_outputs)
return post_processor if async_op else post_processor.wait()
def revert_to_homo(d: dict):
"""Utility function to convert a dictionary that stores homogenous data."""
    is_homogeneous = len(d) == 1 and "_N" in d
    return list(d.values())[0] if is_homogeneous else d
def convert_to_hetero(item):
"""Utility function to convert homogenous data to heterogenous with a single
node type."""
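
    A minimal doctest-style sketch of the behavior:

    >>> convert_to_hetero(torch.arange(2))
    {'_N': tensor([0, 1])}
    >>> convert_to_hetero({"user": torch.arange(2)})  # already heterogeneous
    {'user': tensor([0, 1])}
    """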
    is_heterogeneous = isinstance(item, dict)
    return item if is_heterogeneous else {"_N": item}
@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph.
Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement either the :meth:`sample_subgraphs` method
or the :meth:`sampling_stages` method to define the fine-grained sampling
stages to take advantage of optimizations provided by the GraphBolt
DataLoader.
Parameters
----------
datapipe : DataPipe
The datapipe.
args : Non-Keyword Arguments
Arguments to be passed into sampling_stages.
    kwargs : Keyword Arguments
        Arguments to be passed into sampling_stages. The preprocessing stage
        makes use of the `asynchronous` and `cooperative` parameters before
        they are passed to the sampling stages.
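
    A hedged usage sketch (assumes a subclass registered as
    ``my_sample_subgraph`` via ``@functional_datapipe``, as in the example in
    :meth:`sample_subgraphs`):

    >>> datapipe = datapipe.my_sample_subgraph(graph, fanouts)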
"""
def __init__(
self,
datapipe,
*args,
**kwargs,
):
async_op = kwargs.get("asynchronous", False)
cooperative = kwargs.get("cooperative", False)
preprocess_fn = partial(
self._preprocess, cooperative=cooperative, async_op=async_op
)
datapipe = datapipe.transform(preprocess_fn)
if async_op:
fn = partial(self._wait_preprocess_future, cooperative=cooperative)
datapipe = datapipe.buffer().transform(fn)
if cooperative:
datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_1_wait_future
).buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_4)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe)
@staticmethod
def _postprocess(minibatch):
delattr(minibatch, "_seed_nodes")
delattr(minibatch, "_seeds_timestamp")
return minibatch
@staticmethod
def _preprocess(minibatch, cooperative: bool, async_op: bool):
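        # Normalize ``minibatch.seeds`` into seed nodes, timestamps and
        # compacted seeds; in cooperative mode, per-rank offsets are recorded
        # as well.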
if minibatch.seeds is None:
raise ValueError(
f"Invalid minibatch {minibatch}: `seeds` should have a value."
)
rank = thd.get_rank() if cooperative else 0
world_size = thd.get_world_size() if cooperative else 1
results = SubgraphSampler._seeds_preprocess(
minibatch, rank, world_size, async_op
)
if async_op:
minibatch._preprocess_future = results
else:
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = results
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch
@staticmethod
def _wait_preprocess_future(minibatch, cooperative: bool):
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = minibatch._preprocess_future.wait()
delattr(minibatch, "_preprocess_future")
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch
@staticmethod
def _seeds_cooperative_exchange_1(minibatch):
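        # Cooperative stage 1: launch an asynchronous sort of the seed ids by
        # their owning rank; the result is awaited in
        # ``_seeds_cooperative_exchange_1_wait_future``.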
rank = thd.get_rank()
world_size = thd.get_world_size()
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
if minibatch._seeds_offsets is None:
assert minibatch.compacted_seeds is None
minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(
list(seeds.values()), rank, world_size
)
return minibatch
@staticmethod
def _seeds_cooperative_exchange_1_wait_future(minibatch):
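        # Cooperative stage 1 (wait): collect the rank-sorted seeds and their
        # offsets, then exchange the per-rank seed counts with an asynchronous
        # ``all_to_all``.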
world_size = thd.get_world_size()
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
num_ntypes = len(seeds.keys())
if minibatch._seeds_offsets is None:
result = minibatch._rank_sort_future.wait()
delattr(minibatch, "_rank_sort_future")
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
            for seed_type, (
                typed_sorted_seeds,
                typed_index,
                typed_offsets,
            ) in zip(seeds.keys(), result):
sorted_seeds[seed_type] = typed_sorted_seeds
sorted_compacted[seed_type] = typed_index
sorted_offsets[seed_type] = typed_offsets
minibatch._seed_nodes = sorted_seeds
minibatch.compacted_seeds = revert_to_homo(sorted_compacted)
minibatch._seeds_offsets = sorted_offsets
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
delattr(minibatch, "_seeds_offsets")
counts_received = torch.empty_like(counts_sent)
minibatch._counts_future = all_to_all(
counts_received.split(num_ntypes),
counts_sent.split(num_ntypes),
async_op=True,
)
minibatch._counts_sent = counts_sent
minibatch._counts_received = counts_received
return minibatch
@staticmethod
def _seeds_cooperative_exchange_2(minibatch):
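        # Cooperative stage 2: now that the counts are known, exchange the
        # actual seed ids with one ``all_to_all`` per node type.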
world_size = thd.get_world_size()
seeds = minibatch._seed_nodes
minibatch._counts_future.wait()
delattr(minibatch, "_counts_future")
num_ntypes = len(seeds.keys())
seeds_received = {}
counts_sent = {}
counts_received = {}
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
typed_counts_sent = minibatch._counts_sent[idx].tolist()
typed_counts_received = minibatch._counts_received[idx].tolist()
typed_seeds_received = typed_seeds.new_empty(
sum(typed_counts_received)
)
all_to_all(
typed_seeds_received.split(typed_counts_received),
typed_seeds.split(typed_counts_sent),
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = seeds_received
minibatch._counts_sent = revert_to_homo(counts_sent)
minibatch._counts_received = revert_to_homo(counts_received)
return minibatch
@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
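        # Cooperative stage 3: deduplicate the received seeds asynchronously.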
nodes = {
ntype: [typed_seeds]
for ntype, typed_seeds in minibatch._seed_nodes.items()
}
minibatch._unique_future = unique_and_compact(
nodes, 0, 1, async_op=True
)
return minibatch
@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
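        # Cooperative stage 4: await the dedup result and record the unique
        # seeds, their sizes and the inverse mapping on the minibatch.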
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
delattr(minibatch, "_unique_future")
inverse_seeds = {
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
sizes = {
ntype: typed_seeds.size(0)
for ntype, typed_seeds in unique_seeds.items()
}
minibatch._seed_sizes = revert_to_homo(sizes)
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch
def _sample(self, minibatch):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(
minibatch._seed_nodes, minibatch._seeds_timestamp
)
return minibatch
    def sampling_stages(self, datapipe, *args, **kwargs):
"""The sampling stages are defined here by chaining to the datapipe. The
default implementation expects :meth:`sample_subgraphs` to be
implemented. To define fine-grained stages, this method should be
overridden.
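
        A hedged sketch of a fine-grained override (``_sample_one_layer`` is a
        hypothetical per-layer helper):

        >>> def sampling_stages(self, datapipe, graph, fanouts):
        >>>     # ``_sample_one_layer`` is assumed to sample a single layer.
        >>>     for fanout in reversed(fanouts):
        >>>         datapipe = datapipe.transform(
        >>>             partial(self._sample_one_layer, graph=graph, fanout=fanout)
        >>>         )
        >>>     return datapipe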
"""
return datapipe.transform(self._sample)
@staticmethod
def _seeds_preprocess(
minibatch: MiniBatch,
rank: int = 0,
world_size: int = 1,
async_op: bool = False,
):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
compacting seeds based on their types and timestamps. In heterogeneous
graph, `seeds` with same node type will be unqiued together.
Parameters
----------
minibatch: MiniBatch
The minibatch.
rank : int
The rank of the current process among cooperating processes.
        world_size : int
            The number of cooperating
            (`arXiv:2310.12403 <https://arxiv.org/abs/2310.12403>`__) processes.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
Returns
-------
unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
A tensor or a dictionary of tensors representing the unique seeds.
In heterogeneous graphs, seeds are returned for each node type.
        nodes_timestamp: None, torch.Tensor or Dict[str, torch.Tensor]
            Containing timestamps for each seed. This is only returned if
            `minibatch` includes timestamps and the graph is temporal.
        compacted_seeds: torch.Tensor or Dict[str, torch.Tensor]
            Representation of compacted seeds corresponding to `seeds`, where
            all node ids inside are compacted.
        offsets: None, torch.Tensor or Dict[str, torch.Tensor]
The unique nodes offsets tensor partitions the unique_nodes tensor.
Has size `world_size + 1` and
`unique_nodes[offsets[i]: offsets[i + 1]]` belongs to the rank
`(rank + i) % world_size`.
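
        A small illustration of the `offsets` semantics under stated
        assumptions (`world_size == 4`, `rank == 1`; hypothetical values):

        >>> offsets = torch.tensor([0, 3, 5, 9, 12])
        >>> # unique_nodes[0:3] belong to rank 1, unique_nodes[3:5] to rank 2,
        >>> # unique_nodes[5:9] to rank 3, unique_nodes[9:12] to rank 0.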
"""
use_timestamp = hasattr(minibatch, "timestamp")
assert (
not use_timestamp or world_size == 1
), "Temporal code path does not currently support Cooperative Minibatching"
seeds = minibatch.seeds
is_heterogeneous = isinstance(seeds, Dict)
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for seed_type, typed_seeds in seeds.items():
                # When typed_seeds is a one-dimensional tensor, it represents
                # seed nodes, which do not need to be uniqued and compacted.
if typed_seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
break
result = None
                assert typed_seeds.ndim == 2, (
                    "Only tensors of shape (N,) or (N, M) are "
                    + f"supported, but got {typed_seeds.shape}."
                )
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
if use_timestamp:
negative_ratio = (
typed_seeds.shape[0]
// minibatch.timestamp[seed_type].shape[0]
- 1
)
neg_timestamp = minibatch.timestamp[
seed_type
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(typed_seeds[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[seed_type]
)
nodes_timestamp[ntype].append(neg_timestamp)
class _Waiter:
def __init__(self, nodes, nodes_timestamp, seeds):
# Unique and compact the collected nodes.
if use_timestamp:
self.future = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
self.future = unique_and_compact(
nodes, rank, world_size, async_op
)
self.seeds = seeds
def wait(self):
"""Returns the stored value when invoked."""
if use_timestamp:
unique_seeds, nodes_timestamp, compacted = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
seeds = self.seeds
# Ensure there is no memory leak.
self.future = self.seeds = None
compacted_seeds = {}
                    # Map back in the same order as collected.
for seed_type, typed_seeds in seeds.items():
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
compacted_seed = []
for ntype in ntypes:
compacted_seed.append(compacted[ntype].pop(0))
compacted_seeds[seed_type] = (
torch.cat(compacted_seed).view(len(ntypes), -1).T
)
return (
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)
            # When typed_seeds was not one-dimensional, a _Waiter uniques and
            # compacts the collected nodes.
if result is None:
result = _Waiter(nodes, nodes_timestamp, seeds)
else:
            # When seeds is a one-dimensional tensor, it represents seed
            # nodes, which do not need to be uniqued and compacted.
if seeds.ndim == 1:
nodes_timestamp = (
minibatch.timestamp
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
else:
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
nodes_timestamp = None
if use_timestamp:
                # Timestamps for source and destination nodes are the same.
negative_ratio = (
seeds.shape[0] // minibatch.timestamp.shape[0] - 1
)
neg_timestamp = minibatch.timestamp.repeat_interleave(
negative_ratio
)
seeds_timestamp = torch.cat(
(minibatch.timestamp, neg_timestamp)
)
nodes_timestamp = [
seeds_timestamp for _ in range(seeds.shape[1])
]
class _Waiter:
def __init__(self, nodes, nodes_timestamp, seeds):
# Unique and compact the collected nodes.
if use_timestamp:
self.future = compact_temporal_nodes(
nodes, nodes_timestamp
)
else:
self.future = unique_and_compact(
nodes, async_op=async_op
)
self.seeds = seeds
def wait(self):
"""Returns the stored value when invoked."""
if use_timestamp:
(
unique_seeds,
nodes_timestamp,
compacted,
) = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
seeds = self.seeds
# Ensure there is no memory leak.
self.future = self.seeds = None
                    # Map back in the same order as collected.
compacted_seeds = compacted[0].view(seeds.shape)
return (
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)
result = _Waiter(nodes, nodes_timestamp, seeds)
return result if async_op else result.wait()
def sample_subgraphs(
self, seeds, seeds_timestamp, seeds_pre_time_window=None
):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
Any subclass of SubgraphSampler should implement this method.
Parameters
----------
seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
The seed nodes.
        seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The timestamps of the seed nodes. If given, the sampled subgraphs
            should not contain any nodes or edges that are newer than the
            timestamps of the seed nodes.
seeds_pre_time_window : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The time window of the nodes represents a period of time before
            `seeds_timestamp`. If provided, only neighbors and related edges
            whose timestamps fall within `[seeds_timestamp -
            seeds_pre_time_window, seeds_timestamp]` will be considered for
            sampling.
Returns
-------
Union[torch.Tensor, Dict[str, torch.Tensor]]
The input nodes.
List[SampledSubgraph]
The sampled subgraphs.
Examples
--------
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>> def __init__(self, datapipe, graph, fanouts):
>>> super().__init__(datapipe)
>>> self.graph = graph
>>> self.fanouts = fanouts
        >>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
>>> # Sample subgraphs from the given seeds.
>>> subgraphs = []
>>> subgraphs_nodes = []
>>> for fanout in reversed(self.fanouts):
>>> subgraph = self.graph.sample_neighbors(seeds, fanout)
>>> subgraphs.insert(0, subgraph)
>>> subgraphs_nodes.append(subgraph.nodes)
>>> seeds = subgraph.nodes
>>> subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>> return subgraphs_nodes, subgraphs
"""
raise NotImplementedError