.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/multi/2_node_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_multi_2_node_classification.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_multi_2_node_classification.py:

Single Machine Multi-GPU Minibatch Node Classification
========================================================

In this tutorial, you will learn how to use multiple GPUs to train a
graph neural network (GNN) for node classification.

(Time estimate: 8 minutes)

This tutorial assumes that you have read the :doc:`Training GNN with Neighbor
Sampling for Node Classification <../large/L1_large_node_classification>`
tutorial. It also assumes that you know the basics of training general models
on multiple GPUs with ``DistributedDataParallel``.

.. note::

   See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
   from PyTorch for general multi-GPU training with
   ``DistributedDataParallel``.  Also, see the first section of :doc:`the
   multi-GPU graph classification tutorial <1_graph_classification>` for an
   overview of using ``DistributedDataParallel`` with DGL.

.. GENERATED FROM PYTHON SOURCE LINES 27-35

Loading Dataset
----------------

OGB has already prepared the data as a ``DGLGraph`` object. The following
code is copied from the :doc:`Training GNN with Neighbor Sampling for Node
Classification <../large/L1_large_node_classification>` tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 35-66

.. code-block:: Python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import numpy as np
    import sklearn.metrics
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import tqdm
    from dgl.nn import SAGEConv
    from ogb.nodeproppred import DglNodePropPredDataset

    dataset = DglNodePropPredDataset("ogbn-arxiv")

    graph, node_labels = dataset[0]
    # Add reverse edges since ogbn-arxiv is unidirectional.
    graph = dgl.add_reverse_edges(graph)
    graph.ndata["label"] = node_labels[:, 0]

    node_features = graph.ndata["feat"]
    num_features = node_features.shape[1]
    num_classes = (node_labels.max() + 1).item()

    idx_split = dataset.get_idx_split()
    train_nids = idx_split["train"]
    valid_nids = idx_split["valid"]
    test_nids = idx_split["test"]  # Test node IDs, not used in this tutorial.

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip
    [... download and extraction progress output truncated ...]

.. GENERATED FROM PYTHON SOURCE LINES 66-74

Defining Model
---------------

The model will again be identical to the one in the :doc:`Training GNN with
Neighbor Sampling for Node Classification
<../large/L1_large_node_classification>` tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 74-92

.. code-block:: Python

    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
            self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
            self.h_feats = h_feats

        def forward(self, mfgs, x):
            h_dst = x[: mfgs[0].num_dst_nodes()]
            h = self.conv1(mfgs[0], (x, h_dst))
            h = F.relu(h)
            h_dst = h[: mfgs[1].num_dst_nodes()]
            h = self.conv2(mfgs[1], (h, h_dst))
            return h

.. GENERATED FROM PYTHON SOURCE LINES 93-107

Defining Training Procedure
----------------------------

The training procedure will be slightly different from what you saw
previously, in that you will need to do the following (a condensed sketch
follows this list):

* Initialize a distributed training context with ``torch.distributed``.
* Wrap your model with ``torch.nn.parallel.DistributedDataParallel``.
* Add a ``use_ddp=True`` argument to the DGL dataloader you wish to run
  together with DDP.
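
Put together, the scaffolding for one trainer process looks roughly like the
sketch below. This is a minimal sketch, not a definitive recipe: it reuses
``Model``, ``graph``, ``train_nids``, ``num_features``, and ``num_classes``
defined above, assumes the process already knows its ``rank`` and the total
``world_size``, and picks ``127.0.0.1:12345`` as an arbitrary rendezvous
address for a single machine. The full ``run`` function below does the same
thing and adds the actual training loop.

.. code:: python

    # Minimal sketch of the three steps above, for one trainer process.
    # ``rank``, ``world_size``, and the address 127.0.0.1:12345 are assumptions
    # made for illustration; the ``run`` function below is what the tutorial
    # actually uses.
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    use_gpu = torch.cuda.device_count() > 0

    # 1. Initialize the distributed context (NCCL for GPUs, Gloo for CPU).
    dist.init_process_group(
        backend="nccl" if use_gpu else "gloo",
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )
    device = torch.device(f"cuda:{rank}") if use_gpu else torch.device("cpu")

    # 2. Wrap the model so that gradients are averaged across processes
    #    after every backward pass.
    model = Model(num_features, 128, num_classes).to(device)
    model = DistributedDataParallel(model, device_ids=[rank] if use_gpu else None)

    # 3. use_ddp=True makes each process iterate over a disjoint shard of the
    #    training node IDs, so the effective batch size is
    #    batch_size * world_size.
    sampler = dgl.dataloading.NeighborSampler([4, 4])
    train_dataloader = dgl.dataloading.DataLoader(
        graph, train_nids, sampler,
        device=device, use_ddp=True, batch_size=1024, shuffle=True,
    )
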
You will also need to wrap the training loop inside a function so that you
can spawn subprocesses to run it.

.. GENERATED FROM PYTHON SOURCE LINES 107-232

.. code-block:: Python

    def run(proc_id, devices):
        # Initialize distributed training context.
        dev_id = devices[proc_id]
        dist_init_method = "tcp://{master_ip}:{master_port}".format(
            master_ip="127.0.0.1", master_port="12345"
        )
        if torch.cuda.device_count() < 1:
            device = torch.device("cpu")
            torch.distributed.init_process_group(
                backend="gloo",
                init_method=dist_init_method,
                world_size=len(devices),
                rank=proc_id,
            )
        else:
            torch.cuda.set_device(dev_id)
            device = torch.device("cuda:" + str(dev_id))
            torch.distributed.init_process_group(
                backend="nccl",
                init_method=dist_init_method,
                world_size=len(devices),
                rank=proc_id,
            )

        # Define training and validation dataloader, copied from the previous
        # tutorial but with one line of difference: use_ddp to enable
        # distributed data parallel data loading.
        sampler = dgl.dataloading.NeighborSampler([4, 4])
        train_dataloader = dgl.dataloading.DataLoader(
            # The following arguments are specific to DataLoader.
            graph,  # The graph
            train_nids,  # The node IDs to iterate over in minibatches
            sampler,  # The neighbor sampler
            device=device,  # Put the sampled MFGs on CPU or GPU
            use_ddp=True,  # Make it work with distributed data parallel
            # The following arguments are inherited from PyTorch DataLoader.
            batch_size=1024,  # Per-device batch size.
            # The effective batch size is this number times the number of GPUs.
            shuffle=True,  # Whether to shuffle the nodes for every epoch
            drop_last=False,  # Whether to drop the last incomplete batch
            num_workers=0,  # Number of sampler processes
        )
        valid_dataloader = dgl.dataloading.DataLoader(
            graph,
            valid_nids,
            sampler,
            device=device,
            use_ddp=False,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )

        model = Model(num_features, 128, num_classes).to(device)
        # Wrap the model with distributed data parallel module.
        if device == torch.device("cpu"):
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=None, output_device=None
            )
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )

        # Define optimizer
        opt = torch.optim.Adam(model.parameters())

        best_accuracy = 0
        best_model_path = "./model.pt"

        # Copied from previous tutorial with changes highlighted.
        for epoch in range(10):
            model.train()

            with tqdm.tqdm(train_dataloader) as tq:
                for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                    # feature copy from CPU to GPU takes place here
                    inputs = mfgs[0].srcdata["feat"]
                    labels = mfgs[-1].dstdata["label"]

                    predictions = model(mfgs, inputs)

                    loss = F.cross_entropy(predictions, labels)
                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                    accuracy = sklearn.metrics.accuracy_score(
                        labels.cpu().numpy(),
                        predictions.argmax(1).detach().cpu().numpy(),
                    )

                    tq.set_postfix(
                        {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                        refresh=False,
                    )

            model.eval()

            # Evaluate on only the first GPU.
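            # The validation dataloader was created with use_ddp=False, so every
            # process holds the full validation set; evaluating on all ranks would
            # just repeat the same work.  Rank 0 therefore runs validation and
            # saves the checkpoint.  Because the model is wrapped in DDP, the keys
            # in model.state_dict() are prefixed with "module.", so load the saved
            # checkpoint into another DDP-wrapped model or strip that prefix first.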
            if proc_id == 0:
                predictions = []
                labels = []
                with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
                    for input_nodes, output_nodes, mfgs in tq:
                        inputs = mfgs[0].srcdata["feat"]
                        labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
                        predictions.append(
                            model(mfgs, inputs).argmax(1).cpu().numpy()
                        )
                    predictions = np.concatenate(predictions)
                    labels = np.concatenate(labels)
                    accuracy = sklearn.metrics.accuracy_score(labels, predictions)
                    print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
                    if best_accuracy < accuracy:
                        best_accuracy = accuracy
                        torch.save(model.state_dict(), best_model_path)

            # Note that this tutorial does not train the whole model to the end.
            break

.. GENERATED FROM PYTHON SOURCE LINES 233-250

Spawning Trainer Processes
---------------------------

A typical scenario for multi-GPU training with DDP is to replicate the model
once per GPU and spawn one trainer process per GPU.

Normally, DGL maintains only one sparse matrix representation (usually COO)
for each graph, and creates additional formats on demand when certain APIs
need them for efficiency.  For instance, calling ``in_degrees`` will create a
CSC representation for the graph, and calling ``out_degrees`` will create a
CSR representation.  A consequence is that if a graph is shared with trainer
processes via copy-on-write *before* its CSC/CSR is created, each trainer
will create its own CSC/CSR replica once ``in_degrees`` or ``out_degrees`` is
called.  To avoid this, you need to create all sparse matrix representations
beforehand using the ``create_formats_`` method:

.. GENERATED FROM PYTHON SOURCE LINES 250-254

.. code-block:: Python

    graph.create_formats_()

.. GENERATED FROM PYTHON SOURCE LINES 255-265

Then you can spawn the subprocesses to train with multiple GPUs.

.. code:: python

    # Say you have four GPUs.
    if __name__ == '__main__':
        num_gpus = 4
        import torch.multiprocessing as mp
        mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)

.. GENERATED FROM PYTHON SOURCE LINES 265-268

.. code-block:: Python

    # Thumbnail credits: Stanford CS224W Notes
    # sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.136 seconds)


.. _sphx_glr_download_tutorials_multi_2_node_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 2_node_classification.ipynb <2_node_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 2_node_classification.py <2_node_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 2_node_classification.zip <2_node_classification.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_