Source code for dgl.graphbolt.dataloader

"""Graph Bolt DataLoaders"""

import torch
import torch.utils.data as torch_data

from .base import CopyTo
from .datapipes import (
    datapipe_graph_to_adjlist,
    find_dps,
    replace_dp,
    traverse_dps,
)
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
from .impl.neighbor_sampler import SamplePerLayer
from .internal_utils import gb_warning
from .item_sampler import ItemSampler
from .minibatch_transformer import MiniBatchTransformer


__all__ = [
    "DataLoader",
]


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
    """Find parent of target_datapipe and wrap it with ."""
    datapipes = find_dps(
        datapipe_graph,
        target_datapipe,
    )
    datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
    for datapipe in datapipes:
        datapipe_id = id(datapipe)
        for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
            parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
            datapipe_graph = replace_dp(
                datapipe_graph,
                parent_datapipe,
                wrapper(parent_datapipe, **kwargs),
            )
    return datapipe_graph


def _set_worker_id(worked_id):
    torch.ops.graphbolt.set_worker_id(worked_id)


class MultiprocessingWrapper(torch_data.IterDataPipe):
    """Wraps a datapipe with multiprocessing.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    num_workers : int, optional
        The number of worker processes. Default is 0, meaning that there
        will be no multiprocessing.
    persistent_workers : bool, optional
        If True, the data loader will not shut down the worker processes after a
        dataset has been consumed once. This allows to maintain the workers
        instances alive.
    """

    def __init__(self, datapipe, num_workers=0, persistent_workers=True):
        self.datapipe = datapipe
        self.dataloader = torch_data.DataLoader(
            datapipe,
            batch_size=None,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0) and persistent_workers,
            worker_init_fn=_set_worker_id if num_workers > 0 else None,
        )

    def __iter__(self):
        yield from self.dataloader


[docs] class DataLoader(MiniBatchTransformer): """Multiprocessing DataLoader. Iterates over the data pipeline with everything before feature fetching (i.e. :class:`dgl.graphbolt.FeatureFetcher`) in subprocesses, and everything after feature fetching in the main process. The datapipe is modified in-place as a result. When the copy_to operation is placed earlier in the data pipeline, the num_workers argument is required to be 0 as utilizing CUDA in multiple worker processes is not supported. Parameters ---------- datapipe : DataPipe The data pipeline. num_workers : int, optional Number of worker processes. Default is 0. persistent_workers : bool, optional If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers instances alive. max_uva_threads : int, optional Limits the number of CUDA threads used for UVA copies so that the rest of the computations can run simultaneously with it. Setting it to a too high value will limit the amount of overlap while setting it too low may cause the PCI-e bandwidth to not get fully utilized. Manually tuned default is 10240, meaning around 5-7 Streaming Multiprocessors. """ def __init__( self, datapipe, num_workers=0, persistent_workers=True, max_uva_threads=10240, ): # Multiprocessing requires two modifications to the datapipe: # # 1. Insert a stage after ItemSampler to distribute the # minibatches evenly across processes. # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe # of the FeatureFetcher with a multiprocessing PyTorch DataLoader. datapipe = datapipe.mark_end() datapipe_graph = traverse_dps(datapipe) if num_workers > 0: # (1) Insert minibatch distribution. # TODO(BarclayII): Currently I'm using sharding_filter() as a # concept demonstration. Later on minibatch distribution should be # merged into ItemSampler to maximize efficiency. item_samplers = find_dps( datapipe_graph, ItemSampler, ) for item_sampler in item_samplers: datapipe_graph = replace_dp( datapipe_graph, item_sampler, item_sampler.sharding_filter(), ) # (2) Cut datapipe at FeatureFetcher and wrap. datapipe_graph = _find_and_wrap_parent( datapipe_graph, FeatureFetcherStartMarker, MultiprocessingWrapper, num_workers=num_workers, persistent_workers=persistent_workers, ) # (3) Limit the number of UVA threads used if the feature_fetcher # or any of the samplers have overlapping optimization enabled. if num_workers == 0 and torch.cuda.is_available(): feature_fetchers = find_dps( datapipe_graph, FeatureFetcher, ) for feature_fetcher in feature_fetchers: if feature_fetcher.max_num_stages > 0: # Overlap enabled. torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) if num_workers == 0 and torch.cuda.is_available(): samplers = find_dps( datapipe_graph, SamplePerLayer, ) for sampler in samplers: if sampler.overlap_fetch: torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching # before it. This enables enables non_blocking copies to the device. # Prefetching enables the data pipeline up to the CopyTo to run in a # separate thread. copiers = find_dps(datapipe_graph, CopyTo) if len(copiers) > 1: gb_warning( "Multiple CopyTo operations were found in the datapipe graph." " This case is not officially supported." ) for copier in copiers: # We enable the prefetch at all times for good CPU only performance. datapipe_graph = replace_dp( datapipe_graph, copier, # Add prefetch so that CPU and GPU can run concurrently. copier.datapipe.prefetch(2).copy_to( copier.device, non_blocking=True ), ) super().__init__(datapipe)