.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/1_introduction.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_1_introduction.py:


Node Classification with DGL
============================

GNNs are powerful tools for many machine learning tasks on graphs. In
this introductory tutorial, you will learn the basic workflow of using
GNNs for node classification, i.e. predicting the category of a node in
a graph.

By completing this tutorial, you will be able to:

-  Load a DGL-provided dataset.
-  Build a GNN model with DGL-provided neural network modules.
-  Train and evaluate a GNN model for node classification on either CPU
   or GPU.

This tutorial assumes that you have experience in building neural
networks with PyTorch.

(Time estimate: 13 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 23-33

.. code-block:: Python


    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.data
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 34-64

Overview of Node Classification with GNN
----------------------------------------

One of the most popular and widely adopted tasks on graph data is node
classification, where a model needs to predict the ground-truth category
of each node. Before graph neural networks, many proposed methods used
either connectivity alone (such as DeepWalk or node2vec) or simple
combinations of connectivity and the node's own features. GNNs, by
contrast, offer a way to obtain node representations by combining the
connectivity and features of a *local neighborhood*.

`Kipf et al., `__ is an example that formulates node
classification as a semi-supervised learning task.
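
To make "combining the connectivity and features of a local neighborhood" concrete, here is a minimal dense sketch of one graph-convolution step in the style of Kipf et al., :math:`\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}HW` with :math:`\hat{A} = A + I`. It is an illustration only, written in plain PyTorch on a hypothetical three-node path graph; the function name ``gcn_layer`` is made up for this sketch. DGL's ``GraphConv`` operates on the sparse graph via message passing, and its exact normalization and self-loop behavior are described in its own documentation.

```python
import torch


def gcn_layer(adj: torch.Tensor, h: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical dense GCN step: D^-1/2 (A + I) D^-1/2 H W (no nonlinearity)."""
    # Add self-loops so every node also keeps its own features.
    a_hat = adj + torch.eye(adj.shape[0], dtype=adj.dtype)
    # Degree of each node in A + I (row sums).
    deg = a_hat.sum(dim=1)
    d_inv_sqrt = deg.pow(-0.5)
    # Symmetric normalization: entry (i, j) is scaled by 1 / sqrt(deg_i * deg_j).
    norm_adj = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)
    # Aggregate neighbor features, then apply the linear transform.
    return norm_adj @ h @ weight


# Toy undirected path graph 0 - 1 - 2 with 2-dimensional node features.
adj = torch.tensor([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
h = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])
weight = torch.eye(2)  # identity weight, so the output is just the aggregation

out = gcn_layer(adj, h, weight)
print(out)
```

Each output row mixes a node's own features with those of its neighbors: the middle node 1 receives contributions from both endpoints, while nodes 0 and 2 only see node 1 and themselves. The tutorial's model stacks two ``GraphConv`` layers with a ReLU in between, which follows the same neighborhood-aggregation idea.
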
With the help of only a small portion of labeled nodes, a graph neural
network (GNN) can accurately predict the categories of the remaining
nodes.

This tutorial will show how to build such a GNN for semi-supervised node
classification with only a small number of labels on the Cora dataset,
a citation network with papers as nodes and citations as edges. The task
is to predict the category of a given paper. Each paper node has a
word-count vector as its features, normalized so that its entries sum to
one, as described in Section 5.2 of `the paper `__.

Loading Cora Dataset
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-70

.. code-block:: Python


    dataset = dgl.data.CoraGraphDataset()
    print(f"Number of categories: {dataset.num_classes}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...

Each GCN layer computes new node representations by aggregating neighbor
information. To build a multi-layer GCN, you can simply stack
``dgl.nn.GraphConv`` modules, which inherit from ``torch.nn.Module``.

.. GENERATED FROM PYTHON SOURCE LINES 115-136

.. code-block:: Python


    from dgl.nn import GraphConv


    class GCN(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(GCN, self).__init__()
            # First layer: input features -> hidden features.
            self.conv1 = GraphConv(in_feats, h_feats)
            # Second layer: hidden features -> one score (logit) per class.
            self.conv2 = GraphConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h


    # Create the model with given dimensions
    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)

.. GENERATED FROM PYTHON SOURCE LINES 137-140

DGL provides implementations of many popular neighbor aggregation
modules, and each one can be invoked with a single line of code.

.. GENERATED FROM PYTHON SOURCE LINES 143-148

Training the GCN
----------------

Training this GCN is similar to training other PyTorch neural networks.
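
The main difference from an ordinary supervised loop is that the model runs on every node of the graph, while the loss only sees the few labeled training nodes. The toy sketch below uses random tensors as hypothetical stand-ins for the model output and the labels (it does not load Cora) to show that boolean-mask indexing restricts both the loss and its gradients to the masked rows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: logits for 6 nodes and 3 classes, plus their labels.
logits = torch.randn(6, 3, requires_grad=True)
labels = torch.tensor([0, 2, 1, 1, 0, 2])
# Only the first two nodes count as labeled training nodes.
train_mask = torch.tensor([True, True, False, False, False, False])

# The forward output covers all nodes, but the loss uses only the masked rows.
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
loss.backward()

# Accuracy on the same subset, computed the same way as in the tutorial.
pred = logits.argmax(1)
train_acc = (pred[train_mask] == labels[train_mask]).float().mean()

print(loss.item(), train_acc.item())
print(logits.grad[~train_mask])  # all zeros: unlabeled rows get no gradient
```

In the tutorial's training function the same pattern is applied with ``g.ndata["train_mask"]``, while the validation and test masks are only used to measure accuracy.
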

.. GENERATED FROM PYTHON SOURCE LINES 148-196

.. code-block:: Python


    def train(g, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        best_val_acc = 0
        best_test_acc = 0

        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        for e in range(100):
            # Forward
            logits = model(g, features)

            # Compute prediction
            pred = logits.argmax(1)

            # Compute loss
            # Note that you should only compute the losses of the nodes in the training set.
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])

            # Compute accuracy on training/validation/test
            train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

            # Save the best validation accuracy and the corresponding test accuracy.
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if e % 5 == 0:
                print(
                    f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} "
                    f"(best {best_val_acc:.3f}), test acc: {test_acc:.3f} "
                    f"(best {best_test_acc:.3f})"
                )


    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    In epoch 0, loss: 1.945, val acc: 0.122 (best 0.122), test acc: 0.160 (best 0.160)
    In epoch 5, loss: 1.888, val acc: 0.560 (best 0.560), test acc: 0.545 (best 0.545)
    In epoch 10, loss: 1.803, val acc: 0.618 (best 0.632), test acc: 0.623 (best 0.620)
    In epoch 15, loss: 1.695, val acc: 0.652 (best 0.652), test acc: 0.639 (best 0.639)
    In epoch 20, loss: 1.563, val acc: 0.678 (best 0.680), test acc: 0.676 (best 0.668)
    In epoch 25, loss: 1.411, val acc: 0.708 (best 0.708), test acc: 0.700 (best 0.700)
    In epoch 30, loss: 1.246, val acc: 0.714 (best 0.714), test acc: 0.713 (best 0.709)
    In epoch 35, loss: 1.074, val acc: 0.728 (best 0.728), test acc: 0.742 (best 0.742)
    In epoch 40, loss: 0.905, val acc: 0.740 (best 0.742), test acc: 0.752 (best 0.748)
    In epoch 45, loss: 0.749, val acc: 0.750 (best 0.750), test acc: 0.759 (best 0.759)
    In epoch 50, loss: 0.612, val acc: 0.752 (best 0.752), test acc: 0.766 (best 0.760)
    In epoch 55, loss: 0.495, val acc: 0.748 (best 0.752), test acc: 0.768 (best 0.760)
    In epoch 60, loss: 0.400, val acc: 0.760 (best 0.760), test acc: 0.767 (best 0.766)
    In epoch 65, loss: 0.324, val acc: 0.762 (best 0.762), test acc: 0.771 (best 0.769)
    In epoch 70, loss: 0.263, val acc: 0.768 (best 0.768), test acc: 0.776 (best 0.773)
    In epoch 75, loss: 0.216, val acc: 0.766 (best 0.768), test acc: 0.776 (best 0.773)
    In epoch 80, loss: 0.179, val acc: 0.764 (best 0.768), test acc: 0.771 (best 0.773)
    In epoch 85, loss: 0.149, val acc: 0.764 (best 0.768), test acc: 0.769 (best 0.773)
    In epoch 90, loss: 0.126, val acc: 0.766 (best 0.768), test acc: 0.769 (best 0.773)
    In epoch 95, loss: 0.108, val acc: 0.764 (best 0.768), test acc: 0.768 (best 0.773)

.. GENERATED FROM PYTHON SOURCE LINES 197-209

Training on GPU
---------------

Training on GPU requires putting both the model and the graph onto the
GPU with the ``to`` method, similar to what you would do in PyTorch.
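
The ``to('cuda')`` calls in this section assume that a CUDA device is present. A common PyTorch idiom selects the device at runtime and falls back to the CPU. The sketch below is hedged: it uses a stand-in ``nn.Linear`` model and random features rather than the tutorial's GCN and Cora graph; in the tutorial, the DGL graph would be moved with its own ``to`` method in the same way.

```python
import torch
import torch.nn as nn

# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for the tutorial's GCN and node features (hypothetical sizes).
model = nn.Linear(4, 2).to(device)
features = torch.randn(5, 4).to(device)

# Model parameters and inputs now live on the same device.
logits = model(features)
print(logits.device)
```

Keeping the device in a single variable means the same script runs unchanged on CPU-only machines and on machines with a GPU.
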

.. code:: python

   g = g.to('cuda')
   model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
   train(g, model)

.. GENERATED FROM PYTHON SOURCE LINES 212-225

What’s next?
------------

-  :doc:`How does DGL represent a graph <2_dglgraph>`?
-  :doc:`Write your own GNN module <3_message_passing>`.
-  :doc:`Link prediction (predicting existence of edges) on full graph <4_link_predict>`.
-  :doc:`Graph classification <5_graph_classification>`.
-  :doc:`Make your own dataset <6_load_data>`.
-  :ref:`The list of supported graph convolution modules `.
-  :ref:`The list of datasets provided by DGL `.

.. GENERATED FROM PYTHON SOURCE LINES 225-229

.. code-block:: Python


    # Thumbnail credits: Stanford CS224W Notes
    # sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.076 seconds)


.. _sphx_glr_download_tutorials_blitz_1_introduction.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 1_introduction.ipynb <1_introduction.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 1_introduction.py <1_introduction.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 1_introduction.zip <1_introduction.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_