.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/large/L1_large_node_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_large_L1_large_node_classification.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_large_L1_large_node_classification.py:

Training GNN with Neighbor Sampling for Node Classification
============================================================

This tutorial shows how to train a multi-layer GraphSAGE for node
classification on ``ogbn-arxiv`` provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset contains around 170
thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for node classification on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of
Neighbor Sampling for GNN Training <L0_neighbor_sampling_overview>`.

.. GENERATED FROM PYTHON SOURCE LINES 22-27

Loading Dataset
---------------

OGB has already prepared the data as a DGL graph.

.. GENERATED FROM PYTHON SOURCE LINES 27-40

.. code-block:: Python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import numpy as np
    import torch
    from ogb.nodeproppred import DglNodePropPredDataset

    dataset = DglNodePropPredDataset("ogbn-arxiv")
    device = "cpu"  # change to 'cuda' for GPU

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip
    (download and extraction progress output truncated)

.. GENERATED FROM PYTHON SOURCE LINES 40-85

An OGB dataset is a collection of graphs and their labels.
``ogbn-arxiv`` contains only a single graph, so you can simply get the
graph and its node labels like this:

.. code-block:: Python

    graph, node_labels = dataset[0]
    # Add reverse edges since ogbn-arxiv is unidirectional.
    graph = dgl.add_reverse_edges(graph)
    graph.ndata["label"] = node_labels[:, 0]

    node_features = graph.ndata["feat"]
    num_features = node_features.shape[1]
    num_classes = (node_labels.max() + 1).item()

You can get the training-validation-test split of the dataset:

.. code-block:: Python

    idx_split = dataset.get_idx_split()
    train_nids = idx_split["train"]
    valid_nids = idx_split["valid"]
    test_nids = idx_split["test"]

How DGL Handles Computation Dependency
--------------------------------------

In the :doc:`previous tutorial <L0_neighbor_sampling_overview>`, you
have seen that the computation dependency for message passing of a
single node can be described as a series of *message flow graphs*
(MFG).

|image1|

.. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif

.. GENERATED FROM PYTHON SOURCE LINES 85-112

Defining Neighbor Sampler and Data Loader in DGL
------------------------------------------------

DGL provides tools to iterate over the dataset in minibatches while
generating the computation dependencies needed to compute their outputs
with the MFGs above. For node classification, you can use
``dgl.dataloading.DataLoader`` to iterate over the dataset. It accepts
a sampler object that controls how the computation dependencies are
generated in the form of MFGs. DGL provides implementations of common
sampling algorithms such as ``dgl.dataloading.NeighborSampler``, which
randomly picks a fixed number of neighbors for each node.

.. note::

   To write your own neighbor sampler, please refer to :ref:`this user
   guide section <guide-minibatch-customizing-neighborhood-sampler>`.

The syntax of ``dgl.dataloading.DataLoader`` is mostly similar to a
PyTorch ``DataLoader``, with the addition that it needs a graph to
generate computation dependencies from, a set of node IDs to iterate
over, and the neighbor sampler you defined.

Let’s say that each node will gather messages from 4 neighbors on each
layer. The code defining the data loader and neighbor sampler looks
like the following.

.. GENERATED FROM PYTHON SOURCE LINES 112-128

.. code-block:: Python

    sampler = dgl.dataloading.NeighborSampler([4, 4])
    train_dataloader = dgl.dataloading.DataLoader(
        # The following arguments are specific to DGL's DataLoader.
        graph,  # The graph
        train_nids,  # The node IDs to iterate over in minibatches
        sampler,  # The neighbor sampler
        device=device,  # Put the sampled MFGs on CPU or GPU
        # The following arguments are inherited from PyTorch DataLoader.
        batch_size=1024,  # Batch size
        shuffle=True,  # Whether to shuffle the nodes for every epoch
        drop_last=False,  # Whether to drop the last incomplete batch
        num_workers=0,  # Number of sampler processes
    )
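The fanout list passed to ``NeighborSampler`` has one entry per GNN
layer, and the entries need not be equal. As a small illustrative
sketch (these samplers all exist in ``dgl.dataloading``, but the
specific fanout numbers below are made up for illustration and are not
used in the rest of this tutorial):

.. code-block:: Python

    # Illustrative alternatives (not used later in this tutorial).
    # One fanout entry per GNN layer; entries need not be equal.
    wider_sampler = dgl.dataloading.NeighborSampler([10, 25])

    # A fanout of -1 keeps all neighbors of each node on that layer.
    exact_sampler = dgl.dataloading.NeighborSampler([-1, -1])

    # Convenience class that takes every neighbor on all layers, handy
    # as a sanity check against full-graph training on small graphs.
    full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

Larger fanouts reduce the variance introduced by sampling at the cost
of larger MFGs per minibatch, so they trade memory and speed for
accuracy.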
.. GENERATED FROM PYTHON SOURCE LINES 129-135

.. note::

   Since DGL 0.7, neighborhood sampling on GPU has been supported.
   Please refer to :ref:`guide-minibatch-gpu-sampling` if you are
   interested.

.. GENERATED FROM PYTHON SOURCE LINES 138-140

You can iterate over the data loader and see what it yields.

.. GENERATED FROM PYTHON SOURCE LINES 140-152

.. code-block:: Python

    input_nodes, output_nodes, mfgs = example_minibatch = next(
        iter(train_dataloader)
    )
    print(example_minibatch)
    print(
        "To compute {} nodes' outputs, we need {} nodes' input features".format(
            len(output_nodes), len(input_nodes)
        )
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/ubuntu/regression_test/dgl/python/dgl/dataloading/dataloader.py:1149: DGLWarning: Dataloader CPU affinity opt is not enabled, consider switching it on (see enable_cpu_affinity() or CPU best practices for DGL [https://docs.dgl.ai/tutorials/cpu/cpu_best_practises.html])
      dgl_warning(
    [tensor([ 52794,  40221,  46723,  ...,  85632, 137642, 127797]), tensor([ 52794,  40221,  46723,  ..., 151851,  56563, 104106]), [Block(num_src_nodes=12846, num_dst_nodes=4100, num_edges=14797), Block(num_src_nodes=4100, num_dst_nodes=1024, num_edges=3265)]]
    To compute 1024 nodes' outputs, we need 12846 nodes' input features

.. GENERATED FROM PYTHON SOURCE LINES 153-162

DGL's ``DataLoader`` gives us three items per iteration.

-  An ID tensor for the input nodes, i.e., nodes whose input features
   are needed on the first GNN layer for this minibatch.
-  An ID tensor for the output nodes, i.e., nodes whose representations
   are to be computed.
-  A list of MFGs storing the computation dependencies for each GNN
   layer.

.. GENERATED FROM PYTHON SOURCE LINES 165-171

You can get the source and destination node IDs of the MFGs and verify
that the first few source nodes are always the same as the destination
nodes. As described in the :doc:`overview
<L0_neighbor_sampling_overview>`, destination nodes' own features from
the previous layer may also be necessary in the computation of the new
features.

.. GENERATED FROM PYTHON SOURCE LINES 171-179

.. code-block:: Python

    mfg_0_src = mfgs[0].srcdata[dgl.NID]
    mfg_0_dst = mfgs[0].dstdata[dgl.NID]
    print(mfg_0_src)
    print(mfg_0_dst)
    print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([ 52794,  40221,  46723,  ...,  85632, 137642, 127797])
    tensor([52794, 40221, 46723,  ..., 13129, 10962, 77071])
    True
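Because ``srcdata[dgl.NID]`` holds original node IDs, the node features
that the data loader attaches to an MFG are simply the corresponding
rows of ``graph.ndata["feat"]``. The sketch below is an illustrative
check added here (not original tutorial code); it assumes the example
minibatch above is still in scope.

.. code-block:: Python

    # Sketch: features attached to the first MFG are the rows of the
    # full feature matrix selected by the MFG's source node IDs.
    sampled_feats = mfgs[0].srcdata["feat"]
    gathered_feats = graph.ndata["feat"][mfg_0_src.cpu()]
    print(torch.equal(sampled_feats.cpu(), gathered_feats))  # expected: True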
.. GENERATED FROM PYTHON SOURCE LINES 180-186

Defining Model
--------------

Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
model can be written as follows:

.. GENERATED FROM PYTHON SOURCE LINES 186-213

.. code-block:: Python

    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn import SAGEConv


    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
            self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
            self.h_feats = h_feats

        def forward(self, mfgs, x):
            # Lines that are changed are marked with an arrow: "<---"

            h_dst = x[: mfgs[0].num_dst_nodes()]  # <---
            h = self.conv1(mfgs[0], (x, h_dst))  # <---
            h = F.relu(h)
            h_dst = h[: mfgs[1].num_dst_nodes()]  # <---
            h = self.conv2(mfgs[1], (h, h_dst))  # <---
            return h


    model = Model(num_features, 128, num_classes).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 214-258

If you compare against the code in the :doc:`introduction
<../blitz/1_introduction>`, you will notice several differences:

-  **DGL GNN layers on MFGs**. Instead of computing on the full graph:

   .. code:: python

      h = self.conv1(g, x)

   you only compute on the sampled MFG:

   .. code:: python

      h = self.conv1(mfgs[0], (x, h_dst))

   All DGL’s GNN modules support message passing on MFGs, where you
   supply a pair of features, one for source nodes and another for
   destination nodes.

-  **Feature slicing for self-dependency**. There are statements that
   perform slicing to obtain the previous-layer representation of the
   destination nodes:

   .. code:: python

      h_dst = x[:mfgs[0].num_dst_nodes()]

   The ``num_dst_nodes`` method works on MFGs and returns the number of
   destination nodes. Since the first few source nodes of the yielded
   MFG are always the same as the destination nodes, these statements
   obtain the representations of the destination nodes on the previous
   layer. They are then combined with neighbor aggregation in the
   ``dgl.nn.SAGEConv`` layer.

.. note::

   See the :doc:`custom message passing tutorial <L4_message_passing>`
   for more details on how to manipulate MFGs produced in this way,
   such as the usage of ``num_dst_nodes``.

.. GENERATED FROM PYTHON SOURCE LINES 261-266

Defining Training Loop
----------------------

The following initializes the model and defines the optimizer.

.. GENERATED FROM PYTHON SOURCE LINES 266-270

.. code-block:: Python

    opt = torch.optim.Adam(model.parameters())

.. GENERATED FROM PYTHON SOURCE LINES 271-275

When computing the validation score for model selection, you can
usually also use neighbor sampling. To do that, you need to define
another data loader.

.. GENERATED FROM PYTHON SOURCE LINES 275-290

.. code-block:: Python

    valid_dataloader = dgl.dataloading.DataLoader(
        graph,
        valid_nids,
        sampler,
        batch_size=1024,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        device=device,
    )

    import sklearn.metrics
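The validation pass iterates over ``valid_dataloader`` exactly the way
training iterates over ``train_dataloader``, only without gradients.
Factored into a helper, it would look like the following sketch
(``evaluate`` is a hypothetical name introduced here for illustration;
the training loop below inlines the same logic):

.. code-block:: Python

    # Sketch of a reusable evaluation helper (a hypothetical factoring
    # added for illustration; the training loop below inlines this).
    def evaluate(model, dataloader):
        model.eval()
        all_labels = []
        all_predictions = []
        with torch.no_grad():
            for input_nodes, output_nodes, mfgs in dataloader:
                inputs = mfgs[0].srcdata["feat"]
                all_labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
                all_predictions.append(
                    model(mfgs, inputs).argmax(1).cpu().numpy()
                )
        return sklearn.metrics.accuracy_score(
            np.concatenate(all_labels), np.concatenate(all_predictions)
        )

You would call it as ``evaluate(model, valid_dataloader)``, and the
same helper works for any node dataloader built with the same sampler.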
.. GENERATED FROM PYTHON SOURCE LINES 291-294

The following is a training loop that performs validation every epoch.
It also saves the model with the best validation accuracy into a file.

.. GENERATED FROM PYTHON SOURCE LINES 294-346

.. code-block:: Python

    import tqdm

    best_accuracy = 0
    best_model_path = "model.pt"
    for epoch in range(10):
        model.train()

        with tqdm.tqdm(train_dataloader) as tq:
            for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                # feature copy from CPU to GPU takes place here
                inputs = mfgs[0].srcdata["feat"]
                labels = mfgs[-1].dstdata["label"]

                predictions = model(mfgs, inputs)

                loss = F.cross_entropy(predictions, labels)
                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = sklearn.metrics.accuracy_score(
                    labels.cpu().numpy(),
                    predictions.argmax(1).detach().cpu().numpy(),
                )

                tq.set_postfix(
                    {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                    refresh=False,
                )

        model.eval()

        predictions = []
        labels = []
        with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
            for input_nodes, output_nodes, mfgs in tq:
                inputs = mfgs[0].srcdata["feat"]
                labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
                predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
            predictions = np.concatenate(predictions)
            labels = np.concatenate(labels)
            accuracy = sklearn.metrics.accuracy_score(labels, predictions)
            print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
            if best_accuracy < accuracy:
                best_accuracy = accuracy
                torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial does not train the whole model to the end.
        break

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (training and validation progress output truncated)

.. GENERATED FROM PYTHON SOURCE LINES 346-364

Conclusion
----------

In this tutorial, you have learned how to train a multi-layer GraphSAGE
with neighbor sampling.

What’s next?
------------

-  :doc:`Stochastic training of GNN for link prediction
   <L2_large_link_prediction>`.
-  :doc:`Adapting your custom GNN module for stochastic training
   <L4_message_passing>`.
-  During inference you may wish to disable neighbor sampling. If so,
   please refer to the :ref:`user guide on exact offline inference
   <guide-minibatch-inference>`.

.. GENERATED FROM PYTHON SOURCE LINES 364-368

.. code-block:: Python

    # Thumbnail credits: Stanford CS224W Notes
    # sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 8.151 seconds)

.. _sphx_glr_download_tutorials_large_L1_large_node_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: L1_large_node_classification.ipynb <L1_large_node_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: L1_large_node_classification.py <L1_large_node_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: L1_large_node_classification.zip <L1_large_node_classification.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_