DataLoader
- class dgl.dataloading.DataLoader(graph, indices, graph_sampler, device=None, use_ddp=False, ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, use_prefetch_thread=None, use_alternate_streams=None, pin_prefetcher=None, use_uva=False, gpu_cache=None, **kwargs)[source]
Bases: torch.utils.data.DataLoader

Sampled graph data loader. Wrap a DGLGraph and a Sampler into an iterable over mini-batches of samples. DGL's DataLoader extends PyTorch's DataLoader by handling creation and transmission of graph samples. It supports iterating over a set of nodes, edges or any kind of indices to get samples in the form of DGLGraph, message flow graphs (MFGs), or any other structures necessary to train a graph neural network.

Parameters:
graph (DGLGraph) – The graph.
indices (Tensor or dict[ntype, Tensor]) –
The set of indices. It can either be a tensor of integer indices or a dictionary of types and indices.
The actual meaning of the indices is defined by the
sample()method ofgraph_sampler.graph_sampler (dgl.dataloading.Sampler) – The subgraph sampler.
device (device context, optional) – The device of the generated MFGs in each iteration, which should be a PyTorch device object (e.g., torch.device). By default this value is None. If use_uva is True, MFGs and graphs will be generated on torch.cuda.current_device(); otherwise they are generated on the same device as graph.

use_ddp (boolean, optional) – If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler. Overrides the sampler argument of torch.utils.data.DataLoader.

ddp_seed (int, optional) – The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler. Only effective when use_ddp is True.

use_uva (bool, optional) – Whether to use Unified Virtual Addressing (UVA) to directly sample the graph and slice the features from CPU into GPU. Setting it to True will pin the graph and feature tensors into pinned memory. If True, indices must be on the same device as the device argument. Default: False.
use_prefetch_thread (bool, optional) – (Advanced option) Spawns a new Python thread to perform feature slicing asynchronously. Can make things faster at the cost of GPU memory. Default: True if the graph is on CPU and device is CUDA; False otherwise.

use_alternate_streams (bool, optional) – (Advanced option) Whether to slice and transfer the features to GPU on a non-default stream. Default: True if the graph is on CPU, device is CUDA, and use_uva is False; False otherwise.

pin_prefetcher (bool, optional) – (Advanced option) Whether to pin the feature tensors into pinned memory. Default: True if the graph is on CPU and device is CUDA; False otherwise.

gpu_cache (dict[dict], optional) – Which node and edge features to cache using HugeCTR gpu_cache. For example, {"node": {"features": 500000}, "edge": {"types": 4000000}} indicates caching 500K of the node "features" and 4M of the edge "types" in GPU caches (see the sketch after this parameter list). Supported only on NVIDIA GPUs with compute capability 70 or above. The dictionary holds the keys of the features along with the corresponding cache sizes. See https://github.com/NVIDIA-Merlin/HugeCTR/blob/main/gpu_cache/ReadMe.md for further reference.
kwargs (dict) – Keyword arguments to be passed to the parent PyTorch torch.utils.data.DataLoader class. Common arguments are:

- batch_size (int): The number of indices in each batch.
- drop_last (bool): Whether to drop the last incomplete batch.
- shuffle (bool): Whether to randomly shuffle the indices at each epoch.
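A minimal construction sketch for gpu_cache. This is hedged: the feature name "features", the cache size, and placing the loader output on a CUDA device are illustrative assumptions taken from the example dictionary above, not requirements of the API.

>>> # Illustrative only: cache 500K node "features" rows on the GPU.
>>> # Requires an NVIDIA GPU with compute capability 70 or above.
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler, device='cuda',
...     gpu_cache={"node": {"features": 500000}},
...     batch_size=1024, shuffle=True)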
Examples
To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph, where each node takes messages from 15 neighbors on the first layer, 10 neighbors on the second, and 5 neighbors on the third (assume the backend is PyTorch):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(input_nodes, output_nodes, blocks)
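Because indices may also be a dictionary of node types to ID tensors, the same loop works on a heterogeneous graph. A minimal sketch, assuming a heterogeneous graph hg with a 'user' node type and a tensor train_nid of user IDs (both names are illustrative):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.DataLoader(
...     hg, {'user': train_nid}, sampler,
...     batch_size=1024, shuffle=True)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     # input_nodes and output_nodes are dicts keyed by node type here
...     train_on(input_nodes, output_nodes, blocks)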
Using with Distributed Data Parallel
If you are using PyTorch's distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler, use_ddp=True,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     for input_nodes, output_nodes, blocks in dataloader:
...         train_on(input_nodes, output_nodes, blocks)
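The snippet above assumes the default process group has already been initialized and the model wrapped for DDP. A hedged sketch of that surrounding setup using standard torch.distributed calls; launching via torchrun, the nccl backend, and an existing model object are all assumptions here:

>>> import torch
>>> import torch.distributed as dist
>>> from torch.nn.parallel import DistributedDataParallel
>>> dist.init_process_group('nccl')  # rank/world size come from torchrun env vars
>>> device = torch.device('cuda', dist.get_rank() % torch.cuda.device_count())
>>> model = DistributedDataParallel(model.to(device))  # wrap the GNN model for DDP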
Notes
Please refer to Minibatch Training Tutorials and User Guide Section 6 for usage.
Tips for selecting the proper device
If the input graph g is on GPU, the output device device must be the same GPU and num_workers must be zero. In this case, the sampling and subgraph construction will take place on the GPU. This is the recommended setting when using a single GPU and the whole graph fits in GPU memory.

If the input graph g is on CPU while the output device device is GPU, then depending on the value of use_uva:

- If use_uva is set to True, the sampling and subgraph construction will happen on the GPU even if the GPU itself cannot hold the entire graph. This is the recommended setting unless there are operations not supporting UVA. num_workers must be 0 in this case.
- Otherwise, both the sampling and subgraph construction will take place on the CPU.
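A minimal sketch of the two recommended configurations above, assuming g, train_nid, and sampler already exist:

>>> # Graph fits in GPU memory: move it there and sample on the GPU.
>>> dataloader = dgl.dataloading.DataLoader(
...     g.to('cuda'), train_nid.to('cuda'), sampler,
...     device='cuda', num_workers=0, batch_size=1024)

>>> # Graph stays on CPU; sample on the GPU through UVA.
>>> # Note: use_uva requires indices on the same device as `device`.
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid.to('cuda'), sampler,
...     device='cuda', use_uva=True, num_workers=0, batch_size=1024)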