.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/1_introduction.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_blitz_1_introduction.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_1_introduction.py:


Node Classification with DGL
============================

Graph neural networks (GNNs) are powerful tools for many machine learning
tasks on graphs. In this introductory tutorial, you will learn the basic
workflow of using GNNs for node classification, i.e., predicting the
category of each node in a graph.

By completing this tutorial, you will be able to:

-  Load a DGL-provided dataset.
-  Build a GNN model with DGL-provided neural network modules.
-  Train and evaluate a GNN model for node classification on either CPU
   or GPU.

This tutorial assumes that you have experience in building neural
networks with PyTorch.

(Time estimate: 13 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 23-33

.. code-block:: Python


    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.data
    import torch
    import torch.nn as nn
    import torch.nn.functional as F








.. GENERATED FROM PYTHON SOURCE LINES 34-64

Overview of Node Classification with GNN
----------------------------------------

One of the most popular and widely adopted tasks on graph data is node
classification, where a model needs to predict the ground truth category
of each node. Before graph neural networks, many methods used either
connectivity alone (such as DeepWalk or node2vec) or simple
combinations of connectivity and the node's own features. GNNs, by
contrast, obtain node representations by combining the connectivity and
features of a *local neighborhood*.
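
To make "combining the connectivity and features of a local neighborhood"
concrete, here is a minimal sketch in plain Python of one round of mean
aggregation on a hypothetical four-node graph. The adjacency list,
feature values, and the ``aggregate`` helper are made up for illustration.
Real GNN layers, including the GCN layer used later, also apply a
learnable weight matrix, a different normalization, and a nonlinearity,
all of which this sketch omits.

.. code-block:: python

    # A hypothetical undirected graph with 4 nodes, stored as an adjacency list.
    neighbors = {
        0: [1, 2],
        1: [0],
        2: [0, 3],
        3: [2],
    }

    # One made-up 2-dimensional feature vector per node.
    features = {
        0: [1.0, 0.0],
        1: [0.0, 1.0],
        2: [1.0, 1.0],
        3: [0.0, 0.0],
    }


    def aggregate(node):
        """Average the features of a node and its neighbors (its local neighborhood)."""
        group = [node] + neighbors[node]
        dim = len(features[node])
        return [sum(features[v][i] for v in group) / len(group) for i in range(dim)]


    # Every node computes its new representation from the old ones.
    new_features = {v: aggregate(v) for v in neighbors}
    for v, h in new_features.items():
        print(v, h)

Applying two such rounds lets information from nodes two hops away reach
each node, which matches the receptive field of the two-layer GCN built
later in this tutorial.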

`Kipf et al. <https://arxiv.org/abs/1609.02907>`__ formulate node
classification as a semi-supervised task: given labels for only a small
portion of the nodes, a graph neural network (GNN) can accurately predict
the categories of the remaining nodes.

This tutorial shows how to build such a GNN for semi-supervised node
classification with only a small number of labels on the Cora dataset,
a citation network with papers as nodes and citations as edges. The task
is to predict the category of a given paper. Each paper node has a
bag-of-words vector as its features, normalized so that its entries sum
to one, as described in Section 5.2 of
`the paper <https://arxiv.org/abs/1609.02907>`__.
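
Row normalization divides each paper's feature vector by the sum of its
entries. The following is a minimal sketch in plain Python on made-up
counts over a hypothetical 4-word vocabulary; ``row_normalize`` is an
illustrative helper, not a DGL function.

.. code-block:: python

    def row_normalize(rows):
        """Scale each row so its entries sum to one; all-zero rows stay zero."""
        normalized = []
        for row in rows:
            total = sum(row)
            if total > 0:
                normalized.append([x / total for x in row])
            else:
                # Avoid dividing by zero for a paper with no words.
                normalized.append([0.0] * len(row))
        return normalized


    # Hypothetical word counts for 2 papers over a 4-word vocabulary.
    counts = [
        [2, 0, 1, 1],
        [0, 3, 0, 0],
    ]
    print(row_normalize(counts))  # [[0.5, 0.0, 0.25, 0.25], [0.0, 1.0, 0.0, 0.0]]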

Loading Cora Dataset
--------------------


.. GENERATED FROM PYTHON SOURCE LINES 64-70

.. code-block:: Python



    dataset = dgl.data.CoraGraphDataset()
    print(f"Number of categories: {dataset.num_classes}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      NumNodes: 2708
      NumEdges: 10556
      NumFeats: 1433
      NumClasses: 7
      NumTrainingSamples: 140
      NumValidationSamples: 500
      NumTestSamples: 1000
    Done loading data from cached files.
    Number of categories: 7
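
The summary printed above shows how little supervision the
semi-supervised setting uses. As a quick worked check, using only the
split sizes from that output, the proportions are:

.. code-block:: python

    # Split sizes copied from the dataset summary printed above.
    num_nodes = 2708
    splits = {"train": 140, "validation": 500, "test": 1000}

    for name, size in splits.items():
        print(f"{name}: {size} nodes ({size / num_nodes:.1%} of the graph)")

    # Nodes that belong to none of the three splits.
    unused = num_nodes - sum(splits.values())
    print(f"nodes in no split: {unused}")  # 1068

Only about 5% of the nodes provide training labels. The nodes outside all
three splits contribute no labels here, but because the model runs on the
whole graph, their features still take part in neighbor aggregation.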




.. GENERATED FROM PYTHON SOURCE LINES 71-74

A DGL Dataset object may contain one or more graphs. The Cora
dataset used in this tutorial consists of a single graph.


.. GENERATED FROM PYTHON SOURCE LINES 74-78

.. code-block:: Python


    g = dataset[0]









.. GENERATED FROM PYTHON SOURCE LINES 79-96

A DGL graph can store node features and edge features in two
dictionary-like attributes called ``ndata`` and ``edata``.
In the DGL Cora dataset, the graph contains the following node features:

- ``train_mask``: A boolean tensor indicating whether the node is in the
  training set.

- ``val_mask``: A boolean tensor indicating whether the node is in the
  validation set.

- ``test_mask``: A boolean tensor indicating whether the node is in the
  test set.

- ``label``: The ground truth node category.

- ``feat``: The node features.
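
To see how ``ndata``, ``edata``, and the boolean masks fit together
without downloading anything, here is a minimal sketch on a hypothetical
four-node toy graph that mimics the Cora layout. The graph, its values,
and the edge feature name ``weight`` are illustrative assumptions (the
Cora graph itself has no edge features).

.. code-block:: python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import torch

    # A hypothetical 4-node graph with directed edges 0->1, 1->2, 2->3.
    src = torch.tensor([0, 1, 2])
    dst = torch.tensor([1, 2, 3])
    toy = dgl.graph((src, dst), num_nodes=4)

    # Node data mirrors the Cora layout: features, labels, and boolean masks.
    toy.ndata["feat"] = torch.randn(4, 5)
    toy.ndata["label"] = torch.tensor([0, 1, 0, 1])
    toy.ndata["train_mask"] = torch.tensor([True, True, False, False])
    toy.ndata["test_mask"] = ~toy.ndata["train_mask"]

    # Edge data works the same way, with one row per edge.
    toy.edata["weight"] = torch.ones(toy.num_edges())

    # Boolean masks select the rows that belong to each split.
    train_labels = toy.ndata["label"][toy.ndata["train_mask"]]
    test_feats = toy.ndata["feat"][toy.ndata["test_mask"]]
    print(train_labels)  # tensor([0, 1])
    print(test_feats.shape)  # torch.Size([2, 5])

The same indexing pattern, ``tensor[mask]``, is what the training loop
later in this tutorial uses to restrict the loss and accuracy to each split.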


.. GENERATED FROM PYTHON SOURCE LINES 96-103

.. code-block:: Python


    print("Node features")
    print(g.ndata)
    print("Edge features")
    print(g.edata)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Node features
    {'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'train_mask': tensor([ True,  True,  True,  ..., False, False, False])}
    Edge features
    {}




.. GENERATED FROM PYTHON SOURCE LINES 104-115

Defining a Graph Convolutional Network (GCN)
--------------------------------------------

This tutorial will build a two-layer `Graph Convolutional Network
(GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each
layer computes new node representations by aggregating neighbor
information.
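
For reference, the DGL API documentation for ``GraphConv`` describes the
layer's update with its default symmetric normalization roughly as
follows (paraphrased here), where :math:`\mathcal{N}(i)` is the set of
neighbors of node :math:`i`, :math:`W^{(l)}` and :math:`b^{(l)}` are the
learnable weight and bias of layer :math:`l`, and :math:`\sigma` is an
optional activation. The ``GCN`` class in this tutorial applies ReLU
between the two layers itself instead of passing an activation to
``GraphConv``.

.. math::

    h_i^{(l+1)} = \sigma\Big(b^{(l)} + \sum_{j \in \mathcal{N}(i)}
    \frac{1}{c_{ji}}\, h_j^{(l)} W^{(l)}\Big),
    \qquad c_{ji} = \sqrt{|\mathcal{N}(j)|}\,\sqrt{|\mathcal{N}(i)|}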

To build a multi-layer GCN, you can simply stack ``dgl.nn.GraphConv``
modules, which inherit from ``torch.nn.Module``.
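
One practical detail when stacking ``GraphConv``: by default it raises an
error if any node has no incoming edges, since such a node would receive
no messages. The sketch below, on a hypothetical directed toy graph,
shows the error and one common remedy, ``dgl.add_self_loop``. Passing
``allow_zero_in_degree=True`` to the ``GraphConv`` constructor is the
alternative that suppresses the check.

.. code-block:: python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import torch
    from dgl.nn import GraphConv

    # Hypothetical directed graph 0->1->2->3: node 0 has no incoming edges.
    toy = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])), num_nodes=4)
    feat = torch.randn(4, 5)
    conv = GraphConv(5, 2)

    raised = False
    try:
        conv(toy, feat)
    except dgl.DGLError:
        # GraphConv rejects graphs with 0-in-degree nodes by default.
        raised = True
    print("error on 0-in-degree node:", raised)

    # Adding a self-loop to every node gives each node at least one in-edge.
    toy = dgl.add_self_loop(toy)
    out = conv(toy, feat)
    print(out.shape)  # torch.Size([4, 2])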


.. GENERATED FROM PYTHON SOURCE LINES 115-136

.. code-block:: Python


    from dgl.nn import GraphConv


    class GCN(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(GCN, self).__init__()
            self.conv1 = GraphConv(in_feats, h_feats)
            self.conv2 = GraphConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h


    # Create the model with given dimensions
    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)









.. GENERATED FROM PYTHON SOURCE LINES 137-140

DGL provides implementations of many popular neighbor aggregation
modules, and each can be instantiated with a single line of code.
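
As an illustration, the sketch below keeps the two-layer layout of
``GCN`` but swaps ``GraphConv`` for ``dgl.nn.SAGEConv`` with the mean
aggregator. The class name ``GraphSAGE`` and the toy graph used for the
shape check are hypothetical; this is one possible substitution, not a
recommendation for Cora.

.. code-block:: python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn import SAGEConv


    class GraphSAGE(nn.Module):
        """Same layout as GCN, with GraphSAGE layers using the mean aggregator."""

        def __init__(self, in_feats, h_feats, num_classes):
            super().__init__()
            self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
            self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")

        def forward(self, g, in_feat):
            h = F.relu(self.conv1(g, in_feat))
            return self.conv2(g, h)


    # Shape check on a hypothetical 4-node cycle, made bidirectional.
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    toy = dgl.to_bidirected(dgl.graph((src, dst)))
    sage = GraphSAGE(5, 8, 3)
    logits = sage(toy, torch.randn(4, 5))
    print(logits.shape)  # torch.Size([4, 3])

Because ``GraphSAGE`` takes the same constructor arguments and forward
signature as ``GCN``, it could be passed to the ``train`` function in the
next section in place of ``GCN``; the resulting accuracy will differ.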


.. GENERATED FROM PYTHON SOURCE LINES 143-148

Training the GCN
----------------

Training this GCN is similar to training other PyTorch neural networks.


.. GENERATED FROM PYTHON SOURCE LINES 148-196

.. code-block:: Python



    def train(g, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        best_val_acc = 0
        best_test_acc = 0

        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        for e in range(100):
            # Forward
            logits = model(g, features)

            # Compute prediction
            pred = logits.argmax(1)

            # Compute loss
            # Note that you should only compute the losses of the nodes in the training set.
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])

            # Compute accuracy on training/validation/test
            train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

            # Save the best validation accuracy and the corresponding test accuracy.
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if e % 5 == 0:
                print(
                    f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
                )


    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    In epoch 0, loss: 1.947, val acc: 0.122 (best 0.122), test acc: 0.121 (best 0.121)
    In epoch 5, loss: 1.909, val acc: 0.512 (best 0.512), test acc: 0.523 (best 0.523)
    In epoch 10, loss: 1.842, val acc: 0.644 (best 0.646), test acc: 0.668 (best 0.663)
    In epoch 15, loss: 1.750, val acc: 0.714 (best 0.714), test acc: 0.733 (best 0.733)
    In epoch 20, loss: 1.632, val acc: 0.710 (best 0.720), test acc: 0.717 (best 0.729)
    In epoch 25, loss: 1.490, val acc: 0.696 (best 0.720), test acc: 0.717 (best 0.729)
    In epoch 30, loss: 1.329, val acc: 0.704 (best 0.720), test acc: 0.724 (best 0.729)
    In epoch 35, loss: 1.157, val acc: 0.714 (best 0.720), test acc: 0.731 (best 0.729)
    In epoch 40, loss: 0.984, val acc: 0.730 (best 0.730), test acc: 0.744 (best 0.744)
    In epoch 45, loss: 0.819, val acc: 0.748 (best 0.748), test acc: 0.748 (best 0.748)
    In epoch 50, loss: 0.672, val acc: 0.754 (best 0.754), test acc: 0.749 (best 0.749)
    In epoch 55, loss: 0.545, val acc: 0.764 (best 0.764), test acc: 0.757 (best 0.757)
    In epoch 60, loss: 0.441, val acc: 0.772 (best 0.772), test acc: 0.762 (best 0.761)
    In epoch 65, loss: 0.358, val acc: 0.778 (best 0.778), test acc: 0.760 (best 0.760)
    In epoch 70, loss: 0.292, val acc: 0.780 (best 0.780), test acc: 0.762 (best 0.762)
    In epoch 75, loss: 0.239, val acc: 0.782 (best 0.782), test acc: 0.765 (best 0.764)
    In epoch 80, loss: 0.198, val acc: 0.786 (best 0.786), test acc: 0.762 (best 0.763)
    In epoch 85, loss: 0.166, val acc: 0.780 (best 0.786), test acc: 0.763 (best 0.763)
    In epoch 90, loss: 0.140, val acc: 0.776 (best 0.786), test acc: 0.763 (best 0.763)
    In epoch 95, loss: 0.120, val acc: 0.776 (best 0.786), test acc: 0.763 (best 0.763)
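
The loop above measures validation and test accuracy on the same forward
pass that produces the training loss. A common PyTorch pattern is a
separate evaluation function that switches the model to evaluation mode
and disables gradient tracking. The ``evaluate`` helper below is a
hypothetical sketch that assumes the graph stores ``feat`` and ``label``
in ``ndata``, as the Cora graph does.

.. code-block:: python

    import torch


    def evaluate(g, model, mask):
        """Return the accuracy of ``model`` on the nodes selected by ``mask``.

        Assumes ``g.ndata`` holds "feat" and "label", as in the Cora graph.
        """
        model.eval()  # switch off training-only behavior such as dropout
        with torch.no_grad():  # evaluation needs no gradients
            logits = model(g, g.ndata["feat"])
            pred = logits.argmax(dim=1)
            labels = g.ndata["label"]
            return (pred[mask] == labels[mask]).float().mean().item()

For example, ``evaluate(g, model, g.ndata["val_mask"])`` returns the
validation accuracy as a Python float. Call ``model.train()`` before
resuming training. The ``GCN`` defined here has no dropout, so
``eval()`` does not change its output, but the pattern stays correct if
dropout is added later.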




.. GENERATED FROM PYTHON SOURCE LINES 197-209

Training on GPU
---------------

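Training on GPU requires moving both the model and the graph onto the GPU
with the ``to`` method, as you would in PyTorch. The node and edge
features stored in ``ndata`` and ``edata`` move together with the graph.
This requires a CUDA-enabled build of DGL.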

.. code:: python

   g = g.to('cuda')
   model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
   train(g, model)
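
To keep one script that runs with or without a GPU, you can pick the
device at runtime. The following is a minimal sketch that uses a
hypothetical toy graph in place of the Cora graph ``g``.

.. code-block:: python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import torch

    # Fall back to the CPU when no CUDA device is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical toy graph standing in for the Cora graph `g`.
    toy = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    toy.ndata["feat"] = torch.randn(3, 4)

    # `to` returns a copy of the graph on `device`; its ndata moves with it.
    toy = toy.to(device)
    print(toy.device, toy.ndata["feat"].device)

With this ``device``, the GPU snippet becomes ``g = g.to(device)`` and
``model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to(device)``,
followed by ``train(g, model)`` as before.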


.. GENERATED FROM PYTHON SOURCE LINES 212-225

What’s next?
------------

-  :doc:`How does DGL represent a graph <2_dglgraph>`?
-  :doc:`Write your own GNN module <3_message_passing>`.
-  :doc:`Link prediction (predicting the existence of edges) on the full
   graph <4_link_predict>`.
-  :doc:`Graph classification <5_graph_classification>`.
-  :doc:`Make your own dataset <6_load_data>`.
-  :ref:`The list of supported graph convolution
   modules <apinn-pytorch>`.
-  :ref:`The list of datasets provided by DGL <apidata>`.


.. GENERATED FROM PYTHON SOURCE LINES 225-229

.. code-block:: Python



    # Thumbnail credits: Stanford CS224W Notes
    # sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'








.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.837 seconds)


.. _sphx_glr_download_tutorials_blitz_1_introduction.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 1_introduction.ipynb <1_introduction.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 1_introduction.py <1_introduction.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 1_introduction.zip <1_introduction.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_